Compare commits

9 Commits

Author SHA1 Message Date
506c74e53a Small update 2026-01-16 10:39:46 +02:00
eefda5b878 Adding metadata to tiled images 2026-01-16 10:39:14 +02:00
31cb6a6c8e Using 8bit images 2026-01-16 10:38:34 +02:00
0c19ea2557 Updating 2026-01-16 10:30:13 +02:00
89e47591db Formatting 2026-01-16 10:27:15 +02:00
69cde09e53 Changing alpha value 2026-01-16 10:26:25 +02:00
fcbd5fb16d correcting label writing and formatting code 2026-01-16 10:24:19 +02:00
ca52312925 Adding LIKE option for filtering queries 2026-01-16 10:18:48 +02:00
0a93bf797a Adding auto zoom when result is loaded 2026-01-12 14:15:02 +02:00
9 changed files with 238 additions and 357 deletions

View File

@@ -60,9 +60,7 @@ class DatabaseManager:
cursor = conn.cursor() cursor = conn.cursor()
# Check if annotations table exists # Check if annotations table exists
cursor.execute( cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'")
"SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'"
)
if not cursor.fetchone(): if not cursor.fetchone():
# Table doesn't exist yet, no migration needed # Table doesn't exist yet, no migration needed
return return
@@ -242,9 +240,7 @@ class DatabaseManager:
return cursor.lastrowid return cursor.lastrowid
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
# Image already exists, return its ID # Image already exists, return its ID
cursor.execute( cursor.execute("SELECT id FROM images WHERE relative_path = ?", (relative_path,))
"SELECT id FROM images WHERE relative_path = ?", (relative_path,)
)
row = cursor.fetchone() row = cursor.fetchone()
return row["id"] if row else None return row["id"] if row else None
finally: finally:
@@ -255,17 +251,13 @@ class DatabaseManager:
conn = self.get_connection() conn = self.get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute("SELECT * FROM images WHERE relative_path = ?", (relative_path,))
"SELECT * FROM images WHERE relative_path = ?", (relative_path,)
)
row = cursor.fetchone() row = cursor.fetchone()
return dict(row) if row else None return dict(row) if row else None
finally: finally:
conn.close() conn.close()
def get_or_create_image( def get_or_create_image(self, relative_path: str, filename: str, width: int, height: int) -> int:
self, relative_path: str, filename: str, width: int, height: int
) -> int:
"""Get existing image or create new one.""" """Get existing image or create new one."""
existing = self.get_image_by_path(relative_path) existing = self.get_image_by_path(relative_path)
if existing: if existing:
@@ -355,16 +347,8 @@ class DatabaseManager:
bbox[2], bbox[2],
bbox[3], bbox[3],
det["confidence"], det["confidence"],
( (json.dumps(det.get("segmentation_mask")) if det.get("segmentation_mask") else None),
json.dumps(det.get("segmentation_mask")) (json.dumps(det.get("metadata")) if det.get("metadata") else None),
if det.get("segmentation_mask")
else None
),
(
json.dumps(det.get("metadata"))
if det.get("metadata")
else None
),
), ),
) )
conn.commit() conn.commit()
@@ -409,15 +393,16 @@ class DatabaseManager:
if filters: if filters:
conditions = [] conditions = []
for key, value in filters.items(): for key, value in filters.items():
if ( if key.startswith("d.") or key.startswith("i.") or key.startswith("m."):
key.startswith("d.") if "like" in value.lower():
or key.startswith("i.") conditions.append(f"{key} LIKE ?")
or key.startswith("m.") params.append(value.split(" ")[1])
): else:
conditions.append(f"{key} = ?") conditions.append(f"{key} = ?")
params.append(value)
else: else:
conditions.append(f"d.{key} = ?") conditions.append(f"d.{key} = ?")
params.append(value) params.append(value)
query += " WHERE " + " AND ".join(conditions) query += " WHERE " + " AND ".join(conditions)
query += " ORDER BY d.detected_at DESC" query += " ORDER BY d.detected_at DESC"
@@ -442,18 +427,14 @@ class DatabaseManager:
finally: finally:
conn.close() conn.close()
def get_detections_for_image( def get_detections_for_image(self, image_id: int, model_id: Optional[int] = None) -> List[Dict]:
self, image_id: int, model_id: Optional[int] = None
) -> List[Dict]:
"""Get all detections for a specific image.""" """Get all detections for a specific image."""
filters = {"image_id": image_id} filters = {"image_id": image_id}
if model_id: if model_id:
filters["model_id"] = model_id filters["model_id"] = model_id
return self.get_detections(filters) return self.get_detections(filters)
def delete_detections_for_image( def delete_detections_for_image(self, image_id: int, model_id: Optional[int] = None) -> int:
self, image_id: int, model_id: Optional[int] = None
) -> int:
"""Delete detections tied to a specific image and optional model.""" """Delete detections tied to a specific image and optional model."""
conn = self.get_connection() conn = self.get_connection()
try: try:
@@ -524,9 +505,7 @@ class DatabaseManager:
""", """,
params, params,
) )
class_counts = { class_counts = {row["class_name"]: row["count"] for row in cursor.fetchall()}
row["class_name"]: row["count"] for row in cursor.fetchall()
}
# Average confidence # Average confidence
cursor.execute( cursor.execute(
@@ -583,9 +562,7 @@ class DatabaseManager:
# ==================== Export Operations ==================== # ==================== Export Operations ====================
def export_detections_to_csv( def export_detections_to_csv(self, output_path: str, filters: Optional[Dict] = None) -> bool:
self, output_path: str, filters: Optional[Dict] = None
) -> bool:
"""Export detections to CSV file.""" """Export detections to CSV file."""
try: try:
detections = self.get_detections(filters) detections = self.get_detections(filters)
@@ -614,9 +591,7 @@ class DatabaseManager:
for det in detections: for det in detections:
row = {k: det[k] for k in fieldnames if k in det} row = {k: det[k] for k in fieldnames if k in det}
# Convert segmentation mask list to JSON string for CSV # Convert segmentation mask list to JSON string for CSV
if row.get("segmentation_mask") and isinstance( if row.get("segmentation_mask") and isinstance(row["segmentation_mask"], list):
row["segmentation_mask"], list
):
row["segmentation_mask"] = json.dumps(row["segmentation_mask"]) row["segmentation_mask"] = json.dumps(row["segmentation_mask"])
writer.writerow(row) writer.writerow(row)
@@ -625,9 +600,7 @@ class DatabaseManager:
print(f"Error exporting to CSV: {e}") print(f"Error exporting to CSV: {e}")
return False return False
def export_detections_to_json( def export_detections_to_json(self, output_path: str, filters: Optional[Dict] = None) -> bool:
self, output_path: str, filters: Optional[Dict] = None
) -> bool:
"""Export detections to JSON file.""" """Export detections to JSON file."""
try: try:
detections = self.get_detections(filters) detections = self.get_detections(filters)
@@ -785,17 +758,13 @@ class DatabaseManager:
conn = self.get_connection() conn = self.get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute("SELECT * FROM object_classes WHERE class_name = ?", (class_name,))
"SELECT * FROM object_classes WHERE class_name = ?", (class_name,)
)
row = cursor.fetchone() row = cursor.fetchone()
return dict(row) if row else None return dict(row) if row else None
finally: finally:
conn.close() conn.close()
def add_object_class( def add_object_class(self, class_name: str, color: str, description: Optional[str] = None) -> int:
self, class_name: str, color: str, description: Optional[str] = None
) -> int:
""" """
Add a new object class. Add a new object class.
@@ -928,8 +897,7 @@ class DatabaseManager:
if not split_map[required]: if not split_map[required]:
raise ValueError( raise ValueError(
"Unable to determine %s image directory under %s. Provide it " "Unable to determine %s image directory under %s. Provide it "
"explicitly via the 'splits' argument." "explicitly via the 'splits' argument." % (required, dataset_root_path)
% (required, dataset_root_path)
) )
yaml_splits: Dict[str, str] = {} yaml_splits: Dict[str, str] = {}
@@ -955,11 +923,7 @@ class DatabaseManager:
if yaml_splits.get("test"): if yaml_splits.get("test"):
payload["test"] = yaml_splits["test"] payload["test"] = yaml_splits["test"]
output_path_obj = ( output_path_obj = Path(output_path).expanduser() if output_path else dataset_root_path / "data.yaml"
Path(output_path).expanduser()
if output_path
else dataset_root_path / "data.yaml"
)
output_path_obj.parent.mkdir(parents=True, exist_ok=True) output_path_obj.parent.mkdir(parents=True, exist_ok=True)
with open(output_path_obj, "w", encoding="utf-8") as handle: with open(output_path_obj, "w", encoding="utf-8") as handle:
@@ -1019,15 +983,9 @@ class DatabaseManager:
for split_name, options in patterns.items(): for split_name, options in patterns.items():
for relative in options: for relative in options:
candidate = (dataset_root / relative).resolve() candidate = (dataset_root / relative).resolve()
if ( if candidate.exists() and candidate.is_dir() and self._directory_has_images(candidate):
candidate.exists()
and candidate.is_dir()
and self._directory_has_images(candidate)
):
try: try:
inferred[split_name] = candidate.relative_to( inferred[split_name] = candidate.relative_to(dataset_root).as_posix()
dataset_root
).as_posix()
except ValueError: except ValueError:
inferred[split_name] = candidate.as_posix() inferred[split_name] = candidate.as_posix()
break break

View File

@@ -35,9 +35,7 @@ logger = get_logger(__name__)
class ResultsTab(QWidget): class ResultsTab(QWidget):
"""Results tab showing detection history and preview overlays.""" """Results tab showing detection history and preview overlays."""
def __init__( def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
super().__init__(parent) super().__init__(parent)
self.db_manager = db_manager self.db_manager = db_manager
self.config_manager = config_manager self.config_manager = config_manager
@@ -71,24 +69,12 @@ class ResultsTab(QWidget):
left_layout.addLayout(controls_layout) left_layout.addLayout(controls_layout)
self.results_table = QTableWidget(0, 5) self.results_table = QTableWidget(0, 5)
self.results_table.setHorizontalHeaderLabels( self.results_table.setHorizontalHeaderLabels(["Image", "Model", "Detections", "Classes", "Last Updated"])
["Image", "Model", "Detections", "Classes", "Last Updated"] self.results_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
) self.results_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.Stretch)
self.results_table.horizontalHeader().setSectionResizeMode( self.results_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeToContents)
0, QHeaderView.Stretch self.results_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.Stretch)
) self.results_table.horizontalHeader().setSectionResizeMode(4, QHeaderView.ResizeToContents)
self.results_table.horizontalHeader().setSectionResizeMode(
1, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
2, QHeaderView.ResizeToContents
)
self.results_table.horizontalHeader().setSectionResizeMode(
3, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
4, QHeaderView.ResizeToContents
)
self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows) self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows)
self.results_table.setSelectionMode(QAbstractItemView.SingleSelection) self.results_table.setSelectionMode(QAbstractItemView.SingleSelection)
self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers) self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
@@ -106,6 +92,8 @@ class ResultsTab(QWidget):
preview_layout = QVBoxLayout() preview_layout = QVBoxLayout()
self.preview_canvas = AnnotationCanvasWidget() self.preview_canvas = AnnotationCanvasWidget()
# Auto-zoom so newly loaded images fill the available preview viewport.
self.preview_canvas.set_auto_fit_to_view(True)
self.preview_canvas.set_polyline_enabled(False) self.preview_canvas.set_polyline_enabled(False)
self.preview_canvas.set_show_bboxes(True) self.preview_canvas.set_show_bboxes(True)
preview_layout.addWidget(self.preview_canvas) preview_layout.addWidget(self.preview_canvas)
@@ -119,9 +107,7 @@ class ResultsTab(QWidget):
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes) self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
self.show_confidence_checkbox = QCheckBox("Show Confidence") self.show_confidence_checkbox = QCheckBox("Show Confidence")
self.show_confidence_checkbox.setChecked(False) self.show_confidence_checkbox.setChecked(False)
self.show_confidence_checkbox.stateChanged.connect( self.show_confidence_checkbox.stateChanged.connect(self._apply_detection_overlays)
self._apply_detection_overlays
)
toggles_layout.addWidget(self.show_masks_checkbox) toggles_layout.addWidget(self.show_masks_checkbox)
toggles_layout.addWidget(self.show_bboxes_checkbox) toggles_layout.addWidget(self.show_bboxes_checkbox)
toggles_layout.addWidget(self.show_confidence_checkbox) toggles_layout.addWidget(self.show_confidence_checkbox)
@@ -169,8 +155,7 @@ class ResultsTab(QWidget):
"image_id": det["image_id"], "image_id": det["image_id"],
"model_id": det["model_id"], "model_id": det["model_id"],
"image_path": det.get("image_path"), "image_path": det.get("image_path"),
"image_filename": det.get("image_filename") "image_filename": det.get("image_filename") or det.get("image_path"),
or det.get("image_path"),
"model_name": det.get("model_name", ""), "model_name": det.get("model_name", ""),
"model_version": det.get("model_version", ""), "model_version": det.get("model_version", ""),
"last_detected": det.get("detected_at"), "last_detected": det.get("detected_at"),
@@ -183,8 +168,7 @@ class ResultsTab(QWidget):
entry["count"] += 1 entry["count"] += 1
if det.get("detected_at") and ( if det.get("detected_at") and (
not entry.get("last_detected") not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
or str(det.get("detected_at")) > str(entry.get("last_detected"))
): ):
entry["last_detected"] = det.get("detected_at") entry["last_detected"] = det.get("detected_at")
if det.get("class_name"): if det.get("class_name"):
@@ -214,9 +198,7 @@ class ResultsTab(QWidget):
for row, entry in enumerate(self.detection_summary): for row, entry in enumerate(self.detection_summary):
model_label = f"{entry['model_name']} {entry['model_version']}".strip() model_label = f"{entry['model_name']} {entry['model_version']}".strip()
class_list = ( class_list = ", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
)
items = [ items = [
QTableWidgetItem(entry.get("image_filename", "")), QTableWidgetItem(entry.get("image_filename", "")),

View File

@@ -10,6 +10,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import yaml import yaml
import numpy as np
from PySide6.QtCore import Qt, QThread, Signal from PySide6.QtCore import Qt, QThread, Signal
from PySide6.QtWidgets import ( from PySide6.QtWidgets import (
QWidget, QWidget,
@@ -91,10 +92,7 @@ class TrainingWorker(QThread):
}, },
} }
] ]
computed_total = sum( computed_total = sum(max(0, int((stage.get("params") or {}).get("epochs", 0))) for stage in self.stage_plan)
max(0, int((stage.get("params") or {}).get("epochs", 0)))
for stage in self.stage_plan
)
self.total_epochs = total_epochs if total_epochs else computed_total or epochs self.total_epochs = total_epochs if total_epochs else computed_total or epochs
self._stop_requested = False self._stop_requested = False
@@ -201,9 +199,7 @@ class TrainingWorker(QThread):
class TrainingTab(QWidget): class TrainingTab(QWidget):
"""Training tab for model training.""" """Training tab for model training."""
def __init__( def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
super().__init__(parent) super().__init__(parent)
self.db_manager = db_manager self.db_manager = db_manager
self.config_manager = config_manager self.config_manager = config_manager
@@ -337,18 +333,14 @@ class TrainingTab(QWidget):
self.model_version_edit = QLineEdit("v1") self.model_version_edit = QLineEdit("v1")
form_layout.addRow("Version:", self.model_version_edit) form_layout.addRow("Version:", self.model_version_edit)
default_base_model = self.config_manager.get( default_base_model = self.config_manager.get("models.default_base_model", "yolov8s-seg.pt")
"models.default_base_model", "yolov8s-seg.pt"
)
base_model_choices = self.config_manager.get("models.base_model_choices", []) base_model_choices = self.config_manager.get("models.base_model_choices", [])
self.base_model_combo = QComboBox() self.base_model_combo = QComboBox()
self.base_model_combo.addItem("Custom path…", "") self.base_model_combo.addItem("Custom path…", "")
for choice in base_model_choices: for choice in base_model_choices:
self.base_model_combo.addItem(choice, choice) self.base_model_combo.addItem(choice, choice)
self.base_model_combo.currentIndexChanged.connect( self.base_model_combo.currentIndexChanged.connect(self._on_base_model_preset_changed)
self._on_base_model_preset_changed
)
form_layout.addRow("Base Model Preset:", self.base_model_combo) form_layout.addRow("Base Model Preset:", self.base_model_combo)
base_model_layout = QHBoxLayout() base_model_layout = QHBoxLayout()
@@ -434,12 +426,8 @@ class TrainingTab(QWidget):
group_layout = QVBoxLayout() group_layout = QVBoxLayout()
self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune") self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
two_stage_defaults = ( two_stage_defaults = training_defaults.get("two_stage", {}) if training_defaults else {}
training_defaults.get("two_stage", {}) if training_defaults else {} self.two_stage_checkbox.setChecked(bool(two_stage_defaults.get("enabled", False)))
)
self.two_stage_checkbox.setChecked(
bool(two_stage_defaults.get("enabled", False))
)
self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled) self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
group_layout.addWidget(self.two_stage_checkbox) group_layout.addWidget(self.two_stage_checkbox)
@@ -501,9 +489,7 @@ class TrainingTab(QWidget):
stage2_group.setLayout(stage2_form) stage2_group.setLayout(stage2_form)
controls_layout.addWidget(stage2_group) controls_layout.addWidget(stage2_group)
helper_label = QLabel( helper_label = QLabel("When enabled, staged hyperparameters override the global epochs/patience/lr.")
"When enabled, staged hyperparameters override the global epochs/patience/lr."
)
helper_label.setWordWrap(True) helper_label.setWordWrap(True)
controls_layout.addWidget(helper_label) controls_layout.addWidget(helper_label)
@@ -548,9 +534,7 @@ class TrainingTab(QWidget):
if normalized == preset_value: if normalized == preset_value:
target_index = idx target_index = idx
break break
if normalized.endswith(f"/{preset_value}") or normalized.endswith( if normalized.endswith(f"/{preset_value}") or normalized.endswith(f"\\{preset_value}"):
f"\\{preset_value}"
):
target_index = idx target_index = idx
break break
self.base_model_combo.blockSignals(True) self.base_model_combo.blockSignals(True)
@@ -638,9 +622,7 @@ class TrainingTab(QWidget):
def _browse_dataset(self): def _browse_dataset(self):
"""Open a file dialog to manually select data.yaml.""" """Open a file dialog to manually select data.yaml."""
start_dir = self.config_manager.get( start_dir = self.config_manager.get("training.last_dataset_dir", "data/datasets")
"training.last_dataset_dir", "data/datasets"
)
start_path = Path(start_dir).expanduser() start_path = Path(start_dir).expanduser()
if not start_path.exists(): if not start_path.exists():
start_path = Path.cwd() start_path = Path.cwd()
@@ -676,9 +658,7 @@ class TrainingTab(QWidget):
return return
except Exception as exc: except Exception as exc:
logger.exception("Unexpected error while generating data.yaml") logger.exception("Unexpected error while generating data.yaml")
self._display_dataset_error( self._display_dataset_error("Unexpected error while generating data.yaml. Check logs for details.")
"Unexpected error while generating data.yaml. Check logs for details."
)
QMessageBox.critical( QMessageBox.critical(
self, self,
"data.yaml Generation Failed", "data.yaml Generation Failed",
@@ -755,13 +735,9 @@ class TrainingTab(QWidget):
self.selected_dataset = info self.selected_dataset = info
self.dataset_root_label.setText(info["root"]) # type: ignore[arg-type] self.dataset_root_label.setText(info["root"]) # type: ignore[arg-type]
self.train_count_label.setText( self.train_count_label.setText(self._format_split_info(info["splits"].get("train")))
self._format_split_info(info["splits"].get("train"))
)
self.val_count_label.setText(self._format_split_info(info["splits"].get("val"))) self.val_count_label.setText(self._format_split_info(info["splits"].get("val")))
self.test_count_label.setText( self.test_count_label.setText(self._format_split_info(info["splits"].get("test")))
self._format_split_info(info["splits"].get("test"))
)
self.num_classes_label.setText(str(info["num_classes"])) self.num_classes_label.setText(str(info["num_classes"]))
class_names = ", ".join(info["class_names"]) or "" class_names = ", ".join(info["class_names"]) or ""
self.class_names_label.setText(class_names) self.class_names_label.setText(class_names)
@@ -815,18 +791,12 @@ class TrainingTab(QWidget):
if split_path.exists(): if split_path.exists():
split_info["count"] = self._count_images(split_path) split_info["count"] = self._count_images(split_path)
if split_info["count"] == 0: if split_info["count"] == 0:
warnings.append( warnings.append(f"No images found for {split_name} split at {split_path}")
f"No images found for {split_name} split at {split_path}"
)
else: else:
warnings.append( warnings.append(f"{split_name.capitalize()} path does not exist: {split_path}")
f"{split_name.capitalize()} path does not exist: {split_path}"
)
else: else:
if split_name in ("train", "val"): if split_name in ("train", "val"):
warnings.append( warnings.append(f"{split_name.capitalize()} split missing in data.yaml")
f"{split_name.capitalize()} split missing in data.yaml"
)
splits[split_name] = split_info splits[split_name] = split_info
names_list = self._normalize_class_names(data.get("names")) names_list = self._normalize_class_names(data.get("names"))
@@ -844,9 +814,7 @@ class TrainingTab(QWidget):
if not names_list and nc_value: if not names_list and nc_value:
names_list = [f"class_{idx}" for idx in range(int(nc_value))] names_list = [f"class_{idx}" for idx in range(int(nc_value))]
elif nc_value and len(names_list) not in (0, int(nc_value)): elif nc_value and len(names_list) not in (0, int(nc_value)):
warnings.append( warnings.append(f"Number of class names ({len(names_list)}) does not match nc={nc_value}")
f"Number of class names ({len(names_list)}) does not match nc={nc_value}"
)
dataset_name = data.get("name") or base_path.name dataset_name = data.get("name") or base_path.name
@@ -898,16 +866,12 @@ class TrainingTab(QWidget):
class_index_map = self._build_class_index_map(dataset_info) class_index_map = self._build_class_index_map(dataset_info)
if not class_index_map: if not class_index_map:
self._append_training_log( self._append_training_log("Skipping label export: dataset classes do not match database entries.")
"Skipping label export: dataset classes do not match database entries."
)
return return
dataset_root_str = dataset_info.get("root") dataset_root_str = dataset_info.get("root")
dataset_yaml_path = dataset_info.get("yaml_path") dataset_yaml_path = dataset_info.get("yaml_path")
dataset_yaml = ( dataset_yaml = Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
)
dataset_root: Optional[Path] dataset_root: Optional[Path]
if dataset_root_str: if dataset_root_str:
dataset_root = Path(dataset_root_str).resolve() dataset_root = Path(dataset_root_str).resolve()
@@ -941,7 +905,9 @@ class TrainingTab(QWidget):
if stats["registered_images"]: if stats["registered_images"]:
message += f" {stats['registered_images']} image(s) had database-backed annotations." message += f" {stats['registered_images']} image(s) had database-backed annotations."
if stats["missing_records"]: if stats["missing_records"]:
message += f" {stats['missing_records']} image(s) had no database entry; empty label files were written." message += (
f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
)
split_messages.append(message) split_messages.append(message)
for msg in split_messages: for msg in split_messages:
@@ -973,9 +939,7 @@ class TrainingTab(QWidget):
continue continue
processed_images += 1 processed_images += 1
label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix( label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(".txt")
".txt"
)
label_path.parent.mkdir(parents=True, exist_ok=True) label_path.parent.mkdir(parents=True, exist_ok=True)
found, annotation_entries = self._fetch_annotations_for_image( found, annotation_entries = self._fetch_annotations_for_image(
@@ -991,25 +955,23 @@ class TrainingTab(QWidget):
for entry in annotation_entries: for entry in annotation_entries:
polygon = entry.get("polygon") or [] polygon = entry.get("polygon") or []
if polygon: if polygon:
print(image_file, polygon[:4], polygon[-2:], entry.get("bbox"))
# coords = " ".join(f"{value:.6f}" for value in entry.get("bbox"))
# coords += " "
coords = " ".join(f"{value:.6f}" for value in polygon) coords = " ".join(f"{value:.6f}" for value in polygon)
handle.write(f"{entry['class_idx']} {coords}\n") handle.write(f"{entry['class_idx']} {coords}\n")
annotations_written += 1 annotations_written += 1
elif entry.get("bbox"): elif entry.get("bbox"):
x_center, y_center, width, height = entry["bbox"] x_center, y_center, width, height = entry["bbox"]
handle.write( handle.write(f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")
f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
)
annotations_written += 1 annotations_written += 1
total_annotations += annotations_written total_annotations += annotations_written
cache_reset_root = labels_dir.parent cache_reset_root = labels_dir.parent
self._invalidate_split_cache(cache_reset_root) self._invalidate_split_cache(cache_reset_root)
if processed_images == 0: if processed_images == 0:
self._append_training_log( self._append_training_log(f"[{split_name}] No images found to export labels for.")
f"[{split_name}] No images found to export labels for."
)
return None return None
return { return {
@@ -1135,6 +1097,10 @@ class TrainingTab(QWidget):
xs.append(x_val) xs.append(x_val)
ys.append(y_val) ys.append(y_val)
if any(np.abs(np.array(coords[:2]) - np.array(coords[-2:])) < 1e-5):
print("Closing polygon")
coords.extend(coords[:2])
if len(coords) < 6: if len(coords) < 6:
continue continue
@@ -1147,6 +1113,11 @@ class TrainingTab(QWidget):
+ abs((min(ys) if ys else 0.0) - y_min) + abs((min(ys) if ys else 0.0) - y_min)
+ abs((max(ys) if ys else 0.0) - y_max) + abs((max(ys) if ys else 0.0) - y_max)
) )
width = max(0.0, x_max - x_min)
height = max(0.0, y_max - y_min)
x_center = x_min + width / 2.0
y_center = y_min + height / 2.0
score = (x_center, y_center, width, height)
candidates.append((score, coords)) candidates.append((score, coords))
@@ -1164,13 +1135,10 @@ class TrainingTab(QWidget):
return 1.0 return 1.0
return value return value
def _prepare_dataset_for_training( def _prepare_dataset_for_training(self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None) -> Path:
self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
) -> Path:
dataset_info = dataset_info or ( dataset_info = dataset_info or (
self.selected_dataset self.selected_dataset
if self.selected_dataset if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
else self._parse_dataset_yaml(dataset_yaml) else self._parse_dataset_yaml(dataset_yaml)
) )
@@ -1189,14 +1157,10 @@ class TrainingTab(QWidget):
cache_root = self._get_rgb_cache_root(dataset_yaml) cache_root = self._get_rgb_cache_root(dataset_yaml)
rgb_yaml = cache_root / "data.yaml" rgb_yaml = cache_root / "data.yaml"
if rgb_yaml.exists(): if rgb_yaml.exists():
self._append_training_log( self._append_training_log(f"Detected grayscale dataset; reusing RGB cache at {cache_root}")
f"Detected grayscale dataset; reusing RGB cache at {cache_root}"
)
return rgb_yaml return rgb_yaml
self._append_training_log( self._append_training_log(f"Detected grayscale dataset; creating RGB cache at {cache_root}")
f"Detected grayscale dataset; creating RGB cache at {cache_root}"
)
self._build_rgb_dataset(cache_root, dataset_info) self._build_rgb_dataset(cache_root, dataset_info)
return rgb_yaml return rgb_yaml
@@ -1463,15 +1427,12 @@ class TrainingTab(QWidget):
dataset_path = Path(dataset_yaml).expanduser() dataset_path = Path(dataset_yaml).expanduser()
if not dataset_path.exists(): if not dataset_path.exists():
QMessageBox.warning( QMessageBox.warning(self, "Invalid Dataset", "Selected data.yaml file does not exist.")
self, "Invalid Dataset", "Selected data.yaml file does not exist."
)
return return
dataset_info = ( dataset_info = (
self.selected_dataset self.selected_dataset
if self.selected_dataset if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_path)
and self.selected_dataset.get("yaml_path") == str(dataset_path)
else self._parse_dataset_yaml(dataset_path) else self._parse_dataset_yaml(dataset_path)
) )
@@ -1480,16 +1441,12 @@ class TrainingTab(QWidget):
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info) dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
if dataset_to_use != dataset_path: if dataset_to_use != dataset_path:
self._append_training_log( self._append_training_log(f"Using RGB-converted dataset at {dataset_to_use.parent}")
f"Using RGB-converted dataset at {dataset_to_use.parent}"
)
params = self._collect_training_params() params = self._collect_training_params()
stage_plan = self._compose_stage_plan(params) stage_plan = self._compose_stage_plan(params)
params["stage_plan"] = stage_plan params["stage_plan"] = stage_plan
total_planned_epochs = ( total_planned_epochs = self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
)
params["total_planned_epochs"] = total_planned_epochs params["total_planned_epochs"] = total_planned_epochs
self._active_training_params = params self._active_training_params = params
self._training_cancelled = False self._training_cancelled = False
@@ -1498,9 +1455,7 @@ class TrainingTab(QWidget):
self._append_training_log("Two-stage fine-tuning schedule:") self._append_training_log("Two-stage fine-tuning schedule:")
self._log_stage_plan(stage_plan) self._log_stage_plan(stage_plan)
self._append_training_log( self._append_training_log(f"Starting training run '{params['run_name']}' using {params['base_model']}")
f"Starting training run '{params['run_name']}' using {params['base_model']}"
)
self.training_progress_bar.setVisible(True) self.training_progress_bar.setVisible(True)
self.training_progress_bar.setMaximum(max(1, total_planned_epochs)) self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
@@ -1528,9 +1483,7 @@ class TrainingTab(QWidget):
def _stop_training(self): def _stop_training(self):
if self.training_worker and self.training_worker.isRunning(): if self.training_worker and self.training_worker.isRunning():
self._training_cancelled = True self._training_cancelled = True
self._append_training_log( self._append_training_log("Stop requested. Waiting for the current epoch to finish...")
"Stop requested. Waiting for the current epoch to finish..."
)
self.training_worker.stop() self.training_worker.stop()
self.stop_training_button.setEnabled(False) self.stop_training_button.setEnabled(False)
@@ -1566,9 +1519,7 @@ class TrainingTab(QWidget):
if worker.isRunning(): if worker.isRunning():
if not worker.wait(wait_timeout_ms): if not worker.wait(wait_timeout_ms):
logger.warning( logger.warning("Training worker did not finish within %sms", wait_timeout_ms)
"Training worker did not finish within %sms", wait_timeout_ms
)
worker.deleteLater() worker.deleteLater()
@@ -1585,16 +1536,12 @@ class TrainingTab(QWidget):
self._set_training_state(False) self._set_training_state(False)
self.training_progress_bar.setVisible(False) self.training_progress_bar.setVisible(False)
def _on_training_progress( def _on_training_progress(self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]):
self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]
):
self.training_progress_bar.setMaximum(total_epochs) self.training_progress_bar.setMaximum(total_epochs)
self.training_progress_bar.setValue(current_epoch) self.training_progress_bar.setValue(current_epoch)
parts = [f"Epoch {current_epoch}/{total_epochs}"] parts = [f"Epoch {current_epoch}/{total_epochs}"]
if metrics: if metrics:
metric_text = ", ".join( metric_text = ", ".join(f"{key}: {value:.4f}" for key, value in metrics.items())
f"{key}: {value:.4f}" for key, value in metrics.items()
)
parts.append(metric_text) parts.append(metric_text)
self._append_training_log(" | ".join(parts)) self._append_training_log(" | ".join(parts))
@@ -1621,9 +1568,7 @@ class TrainingTab(QWidget):
f"Model trained but not registered: {exc}", f"Model trained but not registered: {exc}",
) )
else: else:
QMessageBox.information( QMessageBox.information(self, "Training Complete", "Training finished successfully.")
self, "Training Complete", "Training finished successfully."
)
def _on_training_error(self, message: str): def _on_training_error(self, message: str):
self._cleanup_training_worker() self._cleanup_training_worker()
@@ -1669,9 +1614,7 @@ class TrainingTab(QWidget):
metrics=results.get("metrics"), metrics=results.get("metrics"),
) )
self._append_training_log( self._append_training_log(f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}")
f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}"
)
self._active_training_params = None self._active_training_params = None
def _set_training_state(self, is_training: bool): def _set_training_state(self, is_training: bool):
@@ -1714,9 +1657,7 @@ class TrainingTab(QWidget):
def _browse_save_dir(self): def _browse_save_dir(self):
start_path = self.save_dir_edit.text().strip() or "data/models" start_path = self.save_dir_edit.text().strip() or "data/models"
directory = QFileDialog.getExistingDirectory( directory = QFileDialog.getExistingDirectory(self, "Select Save Directory", start_path)
self, "Select Save Directory", start_path
)
if directory: if directory:
self.save_dir_edit.setText(directory) self.save_dir_edit.setText(directory)

View File

@@ -18,7 +18,7 @@ from PySide6.QtGui import (
QPaintEvent, QPaintEvent,
QPolygonF, QPolygonF,
) )
from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect, QTimer
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from src.utils.image import Image, ImageLoadError from src.utils.image import Image, ImageLoadError
@@ -79,9 +79,7 @@ def rdp(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float,
return [start, end] return [start, end]
def simplify_polyline( def simplify_polyline(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float, float]]:
points: List[Tuple[float, float]], epsilon: float
) -> List[Tuple[float, float]]:
""" """
Simplify a polyline with RDP while preserving closure semantics. Simplify a polyline with RDP while preserving closure semantics.
@@ -145,6 +143,10 @@ class AnnotationCanvasWidget(QWidget):
self.zoom_step = 0.1 self.zoom_step = 0.1
self.zoom_wheel_step = 0.15 self.zoom_wheel_step = 0.15
# Auto-fit behavior (opt-in): when enabled, newly loaded images (and resizes)
# will scale to fill the available viewport while preserving aspect ratio.
self._auto_fit_to_view: bool = False
# Drawing / interaction state # Drawing / interaction state
self.is_drawing = False self.is_drawing = False
self.polyline_enabled = False self.polyline_enabled = False
@@ -175,6 +177,35 @@ class AnnotationCanvasWidget(QWidget):
self._setup_ui() self._setup_ui()
def set_auto_fit_to_view(self, enabled: bool) -> None:
    """Enable or disable automatic zoom-to-fit behavior.

    Args:
        enabled: Truthy to switch auto-fit on, falsy to switch it off. The
            value is coerced to ``bool`` before it is stored.
    """
    self._auto_fit_to_view = bool(enabled)

    # When auto-fit is switched on while a pixmap is already set, schedule a
    # fit through a zero-delay timer instead of calling fit_to_view() inline.
    if self._auto_fit_to_view and self.original_pixmap is not None:
        QTimer.singleShot(0, self.fit_to_view)
def fit_to_view(self, padding_px: int = 6) -> None:
    """Zoom the image so it fits the scroll area's viewport (aspect preserved).

    The scale is the smaller of the horizontal and vertical fit ratios,
    clamped to ``[zoom_min, zoom_max]``. The call is a no-op when no pixmap
    is set, when the viewport reports a non-positive width or height, or
    when the computed scale differs from the current one by less than 1e-4.

    Args:
        padding_px: Pixels subtracted from each viewport dimension before
            the fit ratios are computed.
    """
    pixmap = self.original_pixmap
    if pixmap is None:
        return

    viewport = self.scroll_area.viewport().size()
    view_w = int(viewport.width())
    view_h = int(viewport.height())
    # A viewport without a usable size would be clamped to 1px below and
    # collapse the zoom to zoom_min; leave the current zoom untouched instead.
    if view_w <= 0 or view_h <= 0:
        return

    pad = int(padding_px)
    # The max(1, ...) clamps keep every term strictly positive, so the
    # divisions below cannot hit zero.
    available_w = max(1, view_w - pad)
    available_h = max(1, view_h - pad)
    img_w = max(1, int(pixmap.width()))
    img_h = max(1, int(pixmap.height()))

    # The smaller ratio guarantees both dimensions fit.
    new_scale = min(available_w / img_w, available_h / img_h)
    new_scale = max(self.zoom_min, min(self.zoom_max, float(new_scale)))

    # Skip the redraw when the scale is effectively unchanged.
    if abs(new_scale - self.zoom_scale) < 1e-4:
        return

    self.zoom_scale = new_scale
    self._apply_zoom()
def _setup_ui(self): def _setup_ui(self):
"""Setup user interface.""" """Setup user interface."""
layout = QVBoxLayout() layout = QVBoxLayout()
@@ -187,9 +218,7 @@ class AnnotationCanvasWidget(QWidget):
self.canvas_label = QLabel("No image loaded") self.canvas_label = QLabel("No image loaded")
self.canvas_label.setAlignment(Qt.AlignCenter) self.canvas_label.setAlignment(Qt.AlignCenter)
self.canvas_label.setStyleSheet( self.canvas_label.setStyleSheet("QLabel { background-color: #2b2b2b; color: #888; }")
"QLabel { background-color: #2b2b2b; color: #888; }"
)
self.canvas_label.setScaledContents(False) self.canvas_label.setScaledContents(False)
self.canvas_label.setMouseTracking(True) self.canvas_label.setMouseTracking(True)
@@ -212,9 +241,18 @@ class AnnotationCanvasWidget(QWidget):
self.zoom_scale = 1.0 self.zoom_scale = 1.0
self.clear_annotations() self.clear_annotations()
self._display_image() self._display_image()
logger.debug(
f"Loaded image into annotation canvas: {image.width}x{image.height}" # Defer fit-to-view until the widget has a valid viewport size.
) if self._auto_fit_to_view:
QTimer.singleShot(0, self.fit_to_view)
logger.debug(f"Loaded image into annotation canvas: {image.width}x{image.height}")
def resizeEvent(self, event) -> None:
    """Optionally keep the image fitted when the widget is resized.

    The base-class handler runs first. A fit is then scheduled through a
    zero-delay timer, but only when auto-fit is enabled and a pixmap is set.
    """
    super().resizeEvent(event)

    if self._auto_fit_to_view and self.original_pixmap is not None:
        QTimer.singleShot(0, self.fit_to_view)
def clear(self): def clear(self):
"""Clear the displayed image and all annotations.""" """Clear the displayed image and all annotations."""
@@ -289,22 +327,14 @@ class AnnotationCanvasWidget(QWidget):
scaled_width, scaled_width,
scaled_height, scaled_height,
Qt.KeepAspectRatio, Qt.KeepAspectRatio,
( (Qt.SmoothTransformation if self.zoom_scale >= 1.0 else Qt.FastTransformation),
Qt.SmoothTransformation
if self.zoom_scale >= 1.0
else Qt.FastTransformation
),
) )
scaled_annotations = self.annotation_pixmap.scaled( scaled_annotations = self.annotation_pixmap.scaled(
scaled_width, scaled_width,
scaled_height, scaled_height,
Qt.KeepAspectRatio, Qt.KeepAspectRatio,
( (Qt.SmoothTransformation if self.zoom_scale >= 1.0 else Qt.FastTransformation),
Qt.SmoothTransformation
if self.zoom_scale >= 1.0
else Qt.FastTransformation
),
) )
# Composite image and annotations # Composite image and annotations
@@ -390,16 +420,11 @@ class AnnotationCanvasWidget(QWidget):
y = (pos.y() - offset_y) / self.zoom_scale y = (pos.y() - offset_y) / self.zoom_scale
# Check bounds # Check bounds
if ( if 0 <= x < self.original_pixmap.width() and 0 <= y < self.original_pixmap.height():
0 <= x < self.original_pixmap.width()
and 0 <= y < self.original_pixmap.height()
):
return (int(x), int(y)) return (int(x), int(y))
return None return None
def _find_polyline_at( def _find_polyline_at(self, img_x: float, img_y: float, threshold_px: float = 5.0) -> Optional[int]:
self, img_x: float, img_y: float, threshold_px: float = 5.0
) -> Optional[int]:
""" """
Find index of polyline whose geometry is within threshold_px of (img_x, img_y). Find index of polyline whose geometry is within threshold_px of (img_x, img_y).
Returns the index in self.polylines, or None if none is close enough. Returns the index in self.polylines, or None if none is close enough.
@@ -421,9 +446,7 @@ class AnnotationCanvasWidget(QWidget):
# Precise distance to all segments # Precise distance to all segments
for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]): for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]):
d = perpendicular_distance( d = perpendicular_distance((img_x, img_y), (float(x1), float(y1)), (float(x2), float(y2)))
(img_x, img_y), (float(x1), float(y1)), (float(x2), float(y2))
)
if d < best_dist: if d < best_dist:
best_dist = d best_dist = d
best_index = idx best_index = idx
@@ -624,11 +647,7 @@ class AnnotationCanvasWidget(QWidget):
def mouseMoveEvent(self, event: QMouseEvent): def mouseMoveEvent(self, event: QMouseEvent):
"""Handle mouse move events for drawing.""" """Handle mouse move events for drawing."""
if ( if not self.is_drawing or not self.polyline_enabled or self.annotation_pixmap is None:
not self.is_drawing
or not self.polyline_enabled
or self.annotation_pixmap is None
):
super().mouseMoveEvent(event) super().mouseMoveEvent(event)
return return
@@ -688,15 +707,10 @@ class AnnotationCanvasWidget(QWidget):
if len(simplified) >= 2: if len(simplified) >= 2:
# Store polyline and redraw all annotations # Store polyline and redraw all annotations
self._add_polyline( self._add_polyline(simplified, self.polyline_pen_color, self.polyline_pen_width)
simplified, self.polyline_pen_color, self.polyline_pen_width
)
# Convert to normalized coordinates for metadata + signal # Convert to normalized coordinates for metadata + signal
normalized_stroke = [ normalized_stroke = [self._image_to_normalized_coords(int(x), int(y)) for (x, y) in simplified]
self._image_to_normalized_coords(int(x), int(y))
for (x, y) in simplified
]
self.all_strokes.append( self.all_strokes.append(
{ {
"points": normalized_stroke, "points": normalized_stroke,
@@ -709,8 +723,7 @@ class AnnotationCanvasWidget(QWidget):
# Emit signal with normalized coordinates # Emit signal with normalized coordinates
self.annotation_drawn.emit(normalized_stroke) self.annotation_drawn.emit(normalized_stroke)
logger.debug( logger.debug(
f"Completed stroke with {len(simplified)} points " f"Completed stroke with {len(simplified)} points " f"(normalized len={len(normalized_stroke)})"
f"(normalized len={len(normalized_stroke)})"
) )
self.current_stroke = [] self.current_stroke = []
@@ -750,9 +763,7 @@ class AnnotationCanvasWidget(QWidget):
# Store polyline as [y_norm, x_norm] to match DB convention and # Store polyline as [y_norm, x_norm] to match DB convention and
# the expectations of draw_saved_polyline(). # the expectations of draw_saved_polyline().
normalized_polyline = [ normalized_polyline = [[y / img_height, x / img_width] for (x, y) in polyline]
[y / img_height, x / img_width] for (x, y) in polyline
]
logger.debug( logger.debug(
f"Polyline {idx}: {len(polyline)} points, " f"Polyline {idx}: {len(polyline)} points, "
@@ -772,7 +783,7 @@ class AnnotationCanvasWidget(QWidget):
self, self,
polyline: List[List[float]], polyline: List[List[float]],
color: str, color: str,
width: int = 3, width: int = 1,
annotation_id: Optional[int] = None, annotation_id: Optional[int] = None,
): ):
""" """
@@ -810,17 +821,13 @@ class AnnotationCanvasWidget(QWidget):
# Store and redraw using common pipeline # Store and redraw using common pipeline
pen_color = QColor(color) pen_color = QColor(color)
pen_color.setAlpha(128) # Add semi-transparency pen_color.setAlpha(255) # Add semi-transparency
self._add_polyline(img_coords, pen_color, width, annotation_id=annotation_id) self._add_polyline(img_coords, pen_color, width, annotation_id=annotation_id)
# Store in all_strokes for consistency (uses normalized coordinates) # Store in all_strokes for consistency (uses normalized coordinates)
self.all_strokes.append( self.all_strokes.append({"points": polyline, "color": color, "alpha": 255, "width": width})
{"points": polyline, "color": color, "alpha": 128, "width": width}
)
logger.debug( logger.debug(f"Drew saved polyline with {len(polyline)} points in color {color}")
f"Drew saved polyline with {len(polyline)} points in color {color}"
)
def draw_saved_bbox( def draw_saved_bbox(
self, self,
@@ -844,9 +851,7 @@ class AnnotationCanvasWidget(QWidget):
return return
if len(bbox) != 4: if len(bbox) != 4:
logger.warning( logger.warning(f"Invalid bounding box format: expected 4 values, got {len(bbox)}")
f"Invalid bounding box format: expected 4 values, got {len(bbox)}"
)
return return
# Convert normalized coordinates to image coordinates (for logging/debug) # Convert normalized coordinates to image coordinates (for logging/debug)
@@ -867,15 +872,11 @@ class AnnotationCanvasWidget(QWidget):
# in _redraw_annotations() together with all polylines. # in _redraw_annotations() together with all polylines.
pen_color = QColor(color) pen_color = QColor(color)
pen_color.setAlpha(128) # Add semi-transparency pen_color.setAlpha(128) # Add semi-transparency
self.bboxes.append( self.bboxes.append([float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)])
[float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)]
)
self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label}) self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label})
# Store in all_strokes for consistency # Store in all_strokes for consistency
self.all_strokes.append( self.all_strokes.append({"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label})
{"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label}
)
# Redraw overlay (polylines + all bounding boxes) # Redraw overlay (polylines + all bounding boxes)
self._redraw_annotations() self._redraw_annotations()

View File

@@ -96,9 +96,7 @@ class YOLOWrapper:
try: try:
logger.info(f"Starting training: {name}") logger.info(f"Starting training: {name}")
logger.info( logger.info(f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}")
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
)
# Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255. # Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255.
# Users can override by passing explicit kwargs. # Users can override by passing explicit kwargs.
@@ -149,9 +147,7 @@ class YOLOWrapper:
try: try:
logger.info(f"Starting validation on {split} split") logger.info(f"Starting validation on {split} split")
results = self.model.val( results = self.model.val(data=data_yaml, split=split, device=self.device, **kwargs)
data=data_yaml, split=split, device=self.device, **kwargs
)
logger.info("Validation completed successfully") logger.info("Validation completed successfully")
return self._format_validation_results(results) return self._format_validation_results(results)
@@ -190,11 +186,9 @@ class YOLOWrapper:
raise RuntimeError(f"Failed to load model from {self.model_path}") raise RuntimeError(f"Failed to load model from {self.model_path}")
prepared_source, cleanup_path = self._prepare_source(source) prepared_source, cleanup_path = self._prepare_source(source)
imgsz = 1088
try: try:
logger.info( logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
f"Running inference on {source} -> prepared_source {prepared_source}"
)
results = self.model.predict( results = self.model.predict(
source=source, source=source,
conf=conf, conf=conf,
@@ -203,6 +197,7 @@ class YOLOWrapper:
save_txt=save_txt, save_txt=save_txt,
save_conf=save_conf, save_conf=save_conf,
device=self.device, device=self.device,
imgsz=imgsz,
**kwargs, **kwargs,
) )
@@ -218,13 +213,9 @@ class YOLOWrapper:
try: try:
os.remove(cleanup_path) os.remove(cleanup_path)
except OSError as cleanup_error: except OSError as cleanup_error:
logger.warning( logger.warning(f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}")
f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
)
def export( def export(self, format: str = "onnx", output_path: Optional[str] = None, **kwargs) -> str:
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
) -> str:
""" """
Export model to different format. Export model to different format.
@@ -265,9 +256,7 @@ class YOLOWrapper:
tmp.close() tmp.close()
img_obj.save(tmp_path) img_obj.save(tmp_path)
cleanup_path = tmp_path cleanup_path = tmp_path
logger.info( logger.info(f"Converted image {source_path} to RGB for inference at {tmp_path}")
f"Converted image {source_path} to RGB for inference at {tmp_path}"
)
return tmp_path, cleanup_path return tmp_path, cleanup_path
except Exception as convert_error: except Exception as convert_error:
logger.warning( logger.warning(
@@ -280,9 +269,7 @@ class YOLOWrapper:
"""Format training results into dictionary.""" """Format training results into dictionary."""
try: try:
# Get the results dict # Get the results dict
results_dict = ( results_dict = results.results_dict if hasattr(results, "results_dict") else {}
results.results_dict if hasattr(results, "results_dict") else {}
)
formatted = { formatted = {
"success": True, "success": True,
@@ -315,9 +302,7 @@ class YOLOWrapper:
"mAP50-95": float(box_metrics.map), "mAP50-95": float(box_metrics.map),
"precision": float(box_metrics.mp), "precision": float(box_metrics.mp),
"recall": float(box_metrics.mr), "recall": float(box_metrics.mr),
"fitness": ( "fitness": (float(results.fitness) if hasattr(results, "fitness") else 0.0),
float(results.fitness) if hasattr(results, "fitness") else 0.0
),
} }
# Add per-class metrics if available # Add per-class metrics if available
@@ -327,11 +312,7 @@ class YOLOWrapper:
if idx < len(box_metrics.ap): if idx < len(box_metrics.ap):
class_metrics[name] = { class_metrics[name] = {
"ap": float(box_metrics.ap[idx]), "ap": float(box_metrics.ap[idx]),
"ap50": ( "ap50": (float(box_metrics.ap50[idx]) if hasattr(box_metrics, "ap50") else 0.0),
float(box_metrics.ap50[idx])
if hasattr(box_metrics, "ap50")
else 0.0
),
} }
formatted["class_metrics"] = class_metrics formatted["class_metrics"] = class_metrics
@@ -364,21 +345,15 @@ class YOLOWrapper:
"class_id": int(boxes.cls[i]), "class_id": int(boxes.cls[i]),
"class_name": result.names[int(boxes.cls[i])], "class_name": result.names[int(boxes.cls[i])],
"confidence": float(boxes.conf[i]), "confidence": float(boxes.conf[i]),
"bbox_normalized": [ "bbox_normalized": [float(v) for v in xyxyn], # [x_min, y_min, x_max, y_max]
float(v) for v in xyxyn "bbox_absolute": [float(v) for v in boxes.xyxy[i].cpu().numpy()], # Absolute pixels
], # [x_min, y_min, x_max, y_max]
"bbox_absolute": [
float(v) for v in boxes.xyxy[i].cpu().numpy()
], # Absolute pixels
} }
# Extract segmentation mask if available # Extract segmentation mask if available
if has_masks: if has_masks:
try: try:
# Get the mask for this detection # Get the mask for this detection
mask_data = result.masks.xy[ mask_data = result.masks.xy[i] # Polygon coordinates in absolute pixels
i
] # Polygon coordinates in absolute pixels
# Convert to normalized coordinates # Convert to normalized coordinates
if len(mask_data) > 0: if len(mask_data) > 0:
@@ -391,9 +366,7 @@ class YOLOWrapper:
else: else:
detection["segmentation_mask"] = None detection["segmentation_mask"] = None
except Exception as mask_error: except Exception as mask_error:
logger.warning( logger.warning(f"Error extracting mask for detection {i}: {mask_error}")
f"Error extracting mask for detection {i}: {mask_error}"
)
detection["segmentation_mask"] = None detection["segmentation_mask"] = None
else: else:
detection["segmentation_mask"] = None detection["segmentation_mask"] = None
@@ -407,9 +380,7 @@ class YOLOWrapper:
return [] return []
@staticmethod @staticmethod
def convert_bbox_format( def convert_bbox_format(bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy") -> List[float]:
bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
) -> List[float]:
""" """
Convert bounding box between formats. Convert bounding box between formats.

View File

@@ -37,7 +37,7 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
a1[a1 > p999] = p999 a1[a1 > p999] = p999
a1 /= a1.max() a1 /= a1.max()
if 0: if 1:
a2 = a1.copy() a2 = a1.copy()
a2 = a2**gamma a2 = a2**gamma
a2 /= a2.max() a2 /= a2.max()
@@ -47,9 +47,12 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
a3[a3 > p9999] = p9999 a3[a3 > p9999] = p9999
a3 /= a3.max() a3 /= a3.max()
return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0) # return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0) # return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# return np.stack([a1, a2, a3], axis=0) out = np.stack([a1, a2, a3], axis=0)
# print(any(np.isnan(out).flatten()))
return out
class ImageLoadError(Exception): class ImageLoadError(Exception):
@@ -122,7 +125,7 @@ class Image:
if self.path.suffix.lower() in [".tif", ".tiff"]: if self.path.suffix.lower() in [".tif", ".tiff"]:
self._data = imread(str(self.path)) self._data = imread(str(self.path))
else: else:
raise NotImplementedError("RGB is not implemented") # raise NotImplementedError("RGB is not implemented")
# Load with OpenCV (returns BGR format) # Load with OpenCV (returns BGR format)
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED) self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
@@ -246,27 +249,33 @@ class Image:
if self.channels == 1: if self.channels == 1:
img = get_pseudo_rgb(self.data) img = get_pseudo_rgb(self.data)
self._dtype = img.dtype self._dtype = img.dtype
return img return img, True
raise NotImplementedError
if self._channels == 3: elif self._channels == 3:
return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB) return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB), False
elif self._channels == 4: elif self._channels == 4:
return cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA) return cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA), False
else: else:
return self._data raise NotImplementedError
# else:
# return self._data
def get_qt_rgb(self) -> np.ascontiguousarray: def get_qt_rgb(self) -> np.ascontiguousarray:
# we keep data as (C, H, W) # we keep data as (C, H, W)
_img = self.get_rgb() _img, pseudo = self.get_rgb()
img = np.zeros((self.height, self.width, 4), dtype=np.float32) if pseudo:
img[..., 0] = _img[0] # R gradient img = np.zeros((self.height, self.width, 4), dtype=np.float32)
img[..., 1] = _img[1] # G gradient img[..., 0] = _img[0] # R gradient
img[..., 2] = _img[2] # B constant img[..., 1] = _img[1] # G gradient
img[..., 3] = 1.0 # A = 1.0 (opaque) img[..., 2] = _img[2] # B constant
img[..., 3] = 1.0 # A = 1.0 (opaque)
return np.ascontiguousarray(img) return np.ascontiguousarray(img)
else:
return np.ascontiguousarray(_img)
def get_grayscale(self) -> np.ndarray: def get_grayscale(self) -> np.ndarray:
""" """

View File

@@ -114,11 +114,12 @@ class Label:
return truth_val return truth_val
def to_string(self, bbox: list = None, polygon: list = None): def to_string(self, bbox: list = None, polygon: list = None):
coords = ""
if bbox is None: if bbox is None:
bbox = self.bbox bbox = self.bbox
# coords += " ".join([f"{x:.6f}" for x in self.bbox])
if polygon is None: if polygon is None:
polygon = self.polygon polygon = self.polygon
coords = " ".join([f"{x:.6f}" for x in self.bbox])
if self.polygon is not None: if self.polygon is not None:
coords += " " + " ".join([f"{x:.6f} {y:.6f}" for x, y in self.polygon]) coords += " " + " ".join([f"{x:.6f} {y:.6f}" for x, y in self.polygon])
return f"{self.class_id} {coords}" return f"{self.class_id} {coords}"
@@ -179,6 +180,13 @@ class ImageSplitter:
for i in range(patch_size[0]): for i in range(patch_size[0]):
for j in range(patch_size[1]): for j in range(patch_size[1]):
metadata = {
"image_path": str(self.image_path),
"label_path": str(self.label_path),
"tile_section": f"{i}, {j}",
"tile_size": f"{hstep}, {wstep}",
"patch_size": f"{patch_size[0]}, {patch_size[1]}",
}
tile_reference = f"i{i}j{j}" tile_reference = f"i{i}j{j}"
hrange = (i * hstep / h, (i + 1) * hstep / h) hrange = (i * hstep / h, (i + 1) * hstep / h)
wrange = (j * wstep / w, (j + 1) * wstep / w) wrange = (j * wstep / w, (j + 1) * wstep / w)
@@ -199,7 +207,7 @@ class ImageSplitter:
print(l.bbox) print(l.bbox)
# print(labels) # print(labels)
yield tile_reference, tile, labels yield tile_reference, tile, labels, metadata
def split_respective_to_label(self, padding: int = 67): def split_respective_to_label(self, padding: int = 67):
if self.labels is None: if self.labels is None:
@@ -208,6 +216,7 @@ class ImageSplitter:
for i, label in enumerate(self.labels): for i, label in enumerate(self.labels):
tile_reference = f"_lbl-{i+1:02d}" tile_reference = f"_lbl-{i+1:02d}"
# print(label.bbox) # print(label.bbox)
metadata = {"image_path": str(self.image_path), "label_path": str(self.label_path), "label_index": str(i)}
xc_norm, yc_norm, h_norm, w_norm = label.bbox # normalized coords xc_norm, yc_norm, h_norm, w_norm = label.bbox # normalized coords
xc, yc, h, w = [ xc, yc, h, w = [
@@ -246,17 +255,17 @@ class ImageSplitter:
# print("tile shape:", tile.shape) # print("tile shape:", tile.shape)
yolo_annotation = f"{label.class_id} {x_offset/nx} {y_offset/ny} {h_norm} {w_norm} " yolo_annotation = f"{label.class_id} " # {x_offset/nx} {y_offset/ny} {h_norm} {w_norm} "
print(yolo_annotation)
yolo_annotation += " ".join( yolo_annotation += " ".join(
[ [
f"{(x*self.image.shape[1]-(xc - x_offset))/nx:.6f} {(y*self.image.shape[0]-(yc-y_offset))/ny:.6f}" f"{(x*self.image.shape[1]-(xc - x_offset))/nx:.6f} {(y*self.image.shape[0]-(yc-y_offset))/ny:.6f}"
for x, y in label.polygon for x, y in label.polygon
] ]
) )
print(yolo_annotation)
new_label = Label(yolo_annotation=yolo_annotation) new_label = Label(yolo_annotation=yolo_annotation)
yield tile_reference, tile, [new_label] yield tile_reference, tile, [new_label], metadata
def main(args): def main(args):
@@ -278,9 +287,9 @@ def main(args):
else: else:
data = data.split_into_tiles(patch_size=args.patch_size) data = data.split_into_tiles(patch_size=args.patch_size)
for tile_reference, tile, labels in data: for tile_reference, tile, labels, metadata in data:
print() print()
print(tile_reference, tile.shape, labels) # len(labels) if labels else None) print(tile_reference, tile.shape, labels, metadata) # len(labels) if labels else None)
# { debug # { debug
debug = False debug = False
@@ -310,15 +319,21 @@ def main(args):
# } debug # } debug
if args.output: if args.output:
imwrite(args.output / "images" / f"{image_path.stem}_{tile_reference}.tif", tile) # imwrite(args.output / "images" / f"{image_path.stem}_{tile_reference}.tif", tile, metadata=metadata)
scale = 5 scale = 5
tile_zoomed = zoom(tile, zoom=scale) tile_zoomed = zoom(tile, zoom=scale)
imwrite(args.output / "images-zoomed" / f"{image_path.stem}_{tile_reference}.tif", tile_zoomed) metadata["scale"] = scale
imwrite(
args.output / "images" / f"{image_path.stem}_{tile_reference}.tif",
tile_zoomed,
metadata=metadata,
imagej=True,
)
if labels is not None: if labels is not None:
with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f: with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f:
for label in labels: for label in labels:
label.offset_label(tile.shape[1], tile.shape[0]) # label.offset_label(tile.shape[1], tile.shape[0])
f.write(label.to_string() + "\n") f.write(label.to_string() + "\n")

View File

@@ -72,8 +72,9 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
# logger.info(f"Loading with monkey-patched imread: {filename}") # logger.info(f"Loading with monkey-patched imread: {filename}")
arr = arr.astype(np.float32) arr = arr.astype(np.float32)
arr /= arr.max() arr /= arr.max()
arr *= 2**16 - 1 arr *= 2**8 - 1
arr = arr.astype(np.uint16) arr = arr.astype(np.uint8)
# print(arr.shape, arr.dtype, any(np.isnan(arr).flatten()), np.where(np.isnan(arr)), arr.min(), arr.max())
return np.ascontiguousarray(arr) return np.ascontiguousarray(arr)
# logger.info(f"Loading with original imread: {filename}") # logger.info(f"Loading with original imread: {filename}")
@@ -105,7 +106,7 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
def preprocess_batch_16bit(self, batch: dict) -> dict: # type: ignore[override] def preprocess_batch_16bit(self, batch: dict) -> dict: # type: ignore[override]
# Start from upstream behavior to keep device placement + multiscale identical, # Start from upstream behavior to keep device placement + multiscale identical,
# but replace the 255 division with dtype-aware scaling. # but replace the 255 division with dtype-aware scaling.
logger.info(f"Preprocessing batch with monkey-patched preprocess_batch") # logger.info(f"Preprocessing batch with monkey-patched preprocess_batch")
for k, v in batch.items(): for k, v in batch.items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda") batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")

View File

@@ -196,7 +196,9 @@ def main():
bbox = np.array(bbox) * np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) bbox = np.array(bbox) * np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
yc, xc, h, w = bbox yc, xc, h, w = bbox
print("bbox", bbox) print("bbox", bbox)
polyline = np.array(coords[4:]).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
# polyline = np.array(coords[4:]).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
polyline = np.array(coords).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
print("pl", coords[4:]) print("pl", coords[4:])
print("pl", polyline) print("pl", polyline)
@@ -207,12 +209,13 @@ def main():
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1])) plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
plt.imshow(out_rgb) plt.imshow(out_rgb)
plt.plot(polyline[:, 0], polyline[:, 1], "y", linewidth=2) plt.plot(polyline[:, 0], polyline[:, 1], "y", linewidth=2)
plt.plot( if 0:
[yc - h / 2, yc - h / 2, yc + h / 2, yc + h / 2, yc - h / 2], plt.plot(
[xc - w / 2, xc + w / 2, xc + w / 2, xc - w / 2, xc - w / 2], [yc - h / 2, yc - h / 2, yc + h / 2, yc + h / 2, yc - h / 2],
"r", [xc - w / 2, xc + w / 2, xc + w / 2, xc - w / 2, xc - w / 2],
linewidth=2, "r",
) linewidth=2,
)
# plt.axis("off") # plt.axis("off")
plt.title(f"{img_path.name} ({lbl_path.name})") plt.title(f"{img_path.name} ({lbl_path.name})")