From c0684a9c14d9250763ff8dde5351aa1fbaf20fa1 Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Thu, 11 Dec 2025 12:04:08 +0200 Subject: [PATCH] Implementing 2 stage training --- config/app_config.yaml | 14 ++ src/gui/tabs/results_tab.py | 13 +- src/gui/tabs/training_tab.py | 226 ++++++++++++++++++++ src/gui/widgets/annotation_canvas_widget.py | 49 ++++- src/utils/config_manager.py | 18 ++ 5 files changed, 315 insertions(+), 5 deletions(-) diff --git a/config/app_config.yaml b/config/app_config.yaml index 7aa8e16..bc3d8bd 100644 --- a/config/app_config.yaml +++ b/config/app_config.yaml @@ -12,12 +12,26 @@ image_repository: models: default_base_model: yolov8s-seg.pt models_directory: data/models + base_model_choices: + - yolov8s-seg.pt + - yolov11s-seg.pt training: default_epochs: 100 default_batch_size: 16 default_imgsz: 640 default_patience: 50 default_lr0: 0.01 + two_stage: + enabled: false + stage1: + epochs: 20 + lr0: 0.0005 + patience: 10 + freeze: 10 + stage2: + epochs: 150 + lr0: 0.0003 + patience: 30 last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml last_dataset_dir: /home/martin/code/object_detection/data/datasets detection: diff --git a/src/gui/tabs/results_tab.py b/src/gui/tabs/results_tab.py index 530ff53..59c0c8e 100644 --- a/src/gui/tabs/results_tab.py +++ b/src/gui/tabs/results_tab.py @@ -117,8 +117,14 @@ class ResultsTab(QWidget): self.show_bboxes_checkbox = QCheckBox("Show Bounding Boxes") self.show_bboxes_checkbox.setChecked(True) self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes) + self.show_confidence_checkbox = QCheckBox("Show Confidence") + self.show_confidence_checkbox.setChecked(False) + self.show_confidence_checkbox.stateChanged.connect( + self._apply_detection_overlays + ) toggles_layout.addWidget(self.show_masks_checkbox) toggles_layout.addWidget(self.show_bboxes_checkbox) + toggles_layout.addWidget(self.show_confidence_checkbox) toggles_layout.addStretch() preview_layout.addLayout(toggles_layout) @@ -312,7 +318,12 @@ class ResultsTab(QWidget): det.get("y_max"), ] if all(v is not None for v in bbox): - self.preview_canvas.draw_saved_bbox(bbox, color) + label = None + if self.show_confidence_checkbox.isChecked(): + confidence = det.get("confidence") + if confidence is not None: + label = f"{confidence:.2f}" + self.preview_canvas.draw_saved_bbox(bbox, color, label=label) def _convert_mask(self, mask_points: List[List[float]]) -> List[List[float]]: """Convert stored [x, y] masks to [y, x] format for the canvas.""" diff --git a/src/gui/tabs/training_tab.py b/src/gui/tabs/training_tab.py index b6ba528..7c3bbf3 100644 --- a/src/gui/tabs/training_tab.py +++ b/src/gui/tabs/training_tab.py @@ -28,6 +28,7 @@ from PySide6.QtWidgets import ( QProgressBar, QSpinBox, QDoubleSpinBox, + QCheckBox, ) from src.database.db_manager import DatabaseManager @@ -249,13 +250,26 @@ class TrainingTab(QWidget): default_base_model = self.config_manager.get( "models.default_base_model", "yolov8s-seg.pt" ) + base_model_choices = self.config_manager.get("models.base_model_choices", []) + + self.base_model_combo = QComboBox() + self.base_model_combo.addItem("Custom path…", "") + for choice in base_model_choices: + self.base_model_combo.addItem(choice, choice) + self.base_model_combo.currentIndexChanged.connect( + self._on_base_model_preset_changed + ) + form_layout.addRow("Base Model Preset:", self.base_model_combo) + base_model_layout = QHBoxLayout() self.base_model_edit = QLineEdit(default_base_model) + self.base_model_edit.editingFinished.connect(self._on_base_model_path_edited) base_model_layout.addWidget(self.base_model_edit) self.base_model_browse_button = QPushButton("Browse…") self.base_model_browse_button.clicked.connect(self._browse_base_model) base_model_layout.addWidget(self.base_model_browse_button) form_layout.addRow("Base Model (.pt):", base_model_layout) + self._sync_base_model_preset_selection(default_base_model) models_dir = self.config_manager.get("models.models_directory", "data/models") save_dir_layout = QHBoxLayout() @@ -298,6 +312,9 @@ class TrainingTab(QWidget): group_layout.addLayout(form_layout) + self.two_stage_group = self._create_two_stage_group(training_defaults) + group_layout.addWidget(self.two_stage_group) + button_layout = QHBoxLayout() self.start_training_button = QPushButton("Start Training") self.start_training_button.clicked.connect(self._start_training) @@ -322,6 +339,134 @@ class TrainingTab(QWidget): group.setLayout(group_layout) return group + def _create_two_stage_group(self, training_defaults: Dict[str, Any]) -> QGroupBox: + group = QGroupBox("Two-Stage Fine-Tuning") + group_layout = QVBoxLayout() + + self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune") + two_stage_defaults = ( + training_defaults.get("two_stage", {}) if training_defaults else {} + ) + self.two_stage_checkbox.setChecked( + bool(two_stage_defaults.get("enabled", False)) + ) + self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled) + group_layout.addWidget(self.two_stage_checkbox) + + self.two_stage_controls_widget = QWidget() + controls_layout = QVBoxLayout() + controls_layout.setContentsMargins(0, 0, 0, 0) + controls_layout.setSpacing(8) + + stage1_group = QGroupBox("Stage 1 — Head-only stabilization") + stage1_form = QFormLayout() + stage1_defaults = two_stage_defaults.get("stage1", {}) + + self.stage1_epochs_spin = QSpinBox() + self.stage1_epochs_spin.setRange(1, 500) + self.stage1_epochs_spin.setValue(int(stage1_defaults.get("epochs", 20))) + stage1_form.addRow("Epochs:", self.stage1_epochs_spin) + + self.stage1_lr_spin = QDoubleSpinBox() + self.stage1_lr_spin.setDecimals(5) + self.stage1_lr_spin.setRange(0.00001, 0.1) + self.stage1_lr_spin.setSingleStep(0.0005) + self.stage1_lr_spin.setValue(float(stage1_defaults.get("lr0", 0.0005))) + stage1_form.addRow("Learning Rate:", self.stage1_lr_spin) + + self.stage1_patience_spin = QSpinBox() + self.stage1_patience_spin.setRange(1, 200) + self.stage1_patience_spin.setValue(int(stage1_defaults.get("patience", 10))) + stage1_form.addRow("Patience:", self.stage1_patience_spin) + + self.stage1_freeze_spin = QSpinBox() + self.stage1_freeze_spin.setRange(0, 24) + self.stage1_freeze_spin.setValue(int(stage1_defaults.get("freeze", 10))) + stage1_form.addRow("Freeze layers:", self.stage1_freeze_spin) + + stage1_group.setLayout(stage1_form) + controls_layout.addWidget(stage1_group) + + stage2_group = QGroupBox("Stage 2 — Full fine-tuning") + stage2_form = QFormLayout() + stage2_defaults = two_stage_defaults.get("stage2", {}) + + self.stage2_epochs_spin = QSpinBox() + self.stage2_epochs_spin.setRange(1, 2000) + self.stage2_epochs_spin.setValue(int(stage2_defaults.get("epochs", 150))) + stage2_form.addRow("Epochs:", self.stage2_epochs_spin) + + self.stage2_lr_spin = QDoubleSpinBox() + self.stage2_lr_spin.setDecimals(5) + self.stage2_lr_spin.setRange(0.00001, 0.1) + self.stage2_lr_spin.setSingleStep(0.0005) + self.stage2_lr_spin.setValue(float(stage2_defaults.get("lr0", 0.0003))) + stage2_form.addRow("Learning Rate:", self.stage2_lr_spin) + + self.stage2_patience_spin = QSpinBox() + self.stage2_patience_spin.setRange(1, 200) + self.stage2_patience_spin.setValue(int(stage2_defaults.get("patience", 30))) + stage2_form.addRow("Patience:", self.stage2_patience_spin) + + stage2_group.setLayout(stage2_form) + controls_layout.addWidget(stage2_group) + + helper_label = QLabel( + "When enabled, staged hyperparameters override the global epochs/patience/lr." + ) + helper_label.setWordWrap(True) + controls_layout.addWidget(helper_label) + + self.two_stage_controls_widget.setLayout(controls_layout) + group_layout.addWidget(self.two_stage_controls_widget) + + group.setLayout(group_layout) + self._on_two_stage_toggled(self.two_stage_checkbox.isChecked()) + return group + + def _on_two_stage_toggled(self, checked: bool): + self._refresh_two_stage_controls_enabled(checked) + + def _refresh_two_stage_controls_enabled(self, checked: Optional[bool] = None): + if not hasattr(self, "two_stage_controls_widget"): + return + desired_state = checked + if desired_state is None: + desired_state = self.two_stage_checkbox.isChecked() + can_edit = self.two_stage_checkbox.isEnabled() + self.two_stage_controls_widget.setEnabled(bool(desired_state and can_edit)) + + def _on_base_model_preset_changed(self, index: int): + preset_value = self.base_model_combo.itemData(index) + if preset_value: + self.base_model_edit.setText(str(preset_value)) + elif index == 0: + self.base_model_edit.setFocus() + + def _on_base_model_path_edited(self): + self._sync_base_model_preset_selection(self.base_model_edit.text().strip()) + + def _sync_base_model_preset_selection(self, model_path: str): + if not hasattr(self, "base_model_combo"): + return + normalized = (model_path or "").strip() + target_index = 0 + for idx in range(1, self.base_model_combo.count()): + preset_value = self.base_model_combo.itemData(idx) + if not preset_value: + continue + if normalized == preset_value: + target_index = idx + break + if normalized.endswith(f"/{preset_value}") or normalized.endswith( + f"\\{preset_value}" + ): + target_index = idx + break + self.base_model_combo.blockSignals(True) + self.base_model_combo.setCurrentIndex(target_index) + self.base_model_combo.blockSignals(False) + def _get_dataset_search_roots(self) -> List[Path]: roots: List[Path] = [] default_root = Path("data/datasets").expanduser() @@ -346,6 +491,7 @@ class TrainingTab(QWidget): for yaml_path in root.rglob("*.yaml"): if yaml_path.name.lower() not in {"data.yaml", "dataset.yaml"}: continue + discovered.append(yaml_path.resolve()) except Exception as exc: logger.warning(f"Unable to scan {root}: {exc}") @@ -964,6 +1110,66 @@ class TrainingTab(QWidget): self._build_rgb_dataset(cache_root, dataset_info) return rgb_yaml + def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]: + two_stage = params.get("two_stage") or {} + base_stage = { + "label": "Single Stage", + "model_path": params["base_model"], + "use_previous_best": False, + "params": { + "epochs": params["epochs"], + "batch": params["batch"], + "imgsz": params["imgsz"], + "patience": params["patience"], + "lr0": params["lr0"], + "freeze": 0, + "name": params["run_name"], + }, + } + + if not two_stage.get("enabled"): + return [base_stage] + + stage_plan: List[Dict[str, Any]] = [] + stage1 = two_stage.get("stage1", {}) + stage2 = two_stage.get("stage2", {}) + + stage_plan.append( + { + "label": "Stage 1 — Head-only", + "model_path": params["base_model"], + "use_previous_best": False, + "params": { + "epochs": stage1.get("epochs", params["epochs"]), + "batch": params["batch"], + "imgsz": params["imgsz"], + "patience": stage1.get("patience", params["patience"]), + "lr0": stage1.get("lr0", params["lr0"]), + "freeze": stage1.get("freeze", 0), + "name": f"{params['run_name']}_head_ft", + }, + } + ) + + stage_plan.append( + { + "label": "Stage 2 — Full", + "model_path": params["base_model"], + "use_previous_best": True, + "params": { + "epochs": stage2.get("epochs", params["epochs"]), + "batch": params["batch"], + "imgsz": params["imgsz"], + "patience": stage2.get("patience", params["patience"]), + "lr0": stage2.get("lr0", params["lr0"]), + "freeze": 0, + "name": f"{params['run_name']}_full_ft", + }, + } + ) + + return stage_plan + def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path: cache_base = Path("data/datasets/_rgb_cache") cache_base.mkdir(parents=True, exist_ok=True) @@ -1085,6 +1291,21 @@ class TrainingTab(QWidget): save_dir_path.mkdir(parents=True, exist_ok=True) run_name = f"{model_name}_{model_version}".replace(" ", "_") + two_stage_config = { + "enabled": self.two_stage_checkbox.isChecked(), + "stage1": { + "epochs": self.stage1_epochs_spin.value(), + "lr0": self.stage1_lr_spin.value(), + "patience": self.stage1_patience_spin.value(), + "freeze": self.stage1_freeze_spin.value(), + }, + "stage2": { + "epochs": self.stage2_epochs_spin.value(), + "lr0": self.stage2_lr_spin.value(), + "patience": self.stage2_patience_spin.value(), + }, + } + return { "model_name": model_name, "model_version": model_version, @@ -1096,6 +1317,7 @@ class TrainingTab(QWidget): "imgsz": self.imgsz_spin.value(), "patience": self.patience_spin.value(), "lr0": self.lr_spin.value(), + "two_stage": two_stage_config, } def _start_training(self): @@ -1315,6 +1537,7 @@ class TrainingTab(QWidget): self.rescan_button.setEnabled(not is_training) self.model_name_edit.setEnabled(not is_training) self.model_version_edit.setEnabled(not is_training) + self.base_model_combo.setEnabled(not is_training) self.base_model_edit.setEnabled(not is_training) self.base_model_browse_button.setEnabled(not is_training) self.save_dir_edit.setEnabled(not is_training) @@ -1324,6 +1547,8 @@ class TrainingTab(QWidget): self.imgsz_spin.setEnabled(not is_training) self.patience_spin.setEnabled(not is_training) self.lr_spin.setEnabled(not is_training) + self.two_stage_checkbox.setEnabled(not is_training) + self._refresh_two_stage_controls_enabled() def _append_training_log(self, message: str): timestamp = datetime.now().strftime("%H:%M:%S") @@ -1339,6 +1564,7 @@ class TrainingTab(QWidget): ) if file_path: self.base_model_edit.setText(file_path) + self._sync_base_model_preset_selection(file_path) def _browse_save_dir(self): start_path = self.save_dir_edit.text().strip() or "data/models" diff --git a/src/gui/widgets/annotation_canvas_widget.py b/src/gui/widgets/annotation_canvas_widget.py index e706a6f..baff64d 100644 --- a/src/gui/widgets/annotation_canvas_widget.py +++ b/src/gui/widgets/annotation_canvas_widget.py @@ -17,7 +17,7 @@ from PySide6.QtGui import ( QMouseEvent, QPaintEvent, ) -from PySide6.QtCore import Qt, QEvent, Signal, QPoint +from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QRect from typing import Any, Dict, List, Optional, Tuple from src.utils.image import Image, ImageLoadError @@ -529,6 +529,40 @@ class AnnotationCanvasWidget(QWidget): painter.setPen(pen) painter.drawRect(x_min, y_min, rect_width, rect_height) + label_text = meta.get("label") + if label_text: + painter.save() + font = painter.font() + font.setPointSizeF(max(10.0, width + 4)) + painter.setFont(font) + metrics = painter.fontMetrics() + text_width = metrics.horizontalAdvance(label_text) + text_height = metrics.height() + padding = 4 + bg_width = text_width + padding * 2 + bg_height = text_height + padding * 2 + canvas_width = self.original_pixmap.width() + canvas_height = self.original_pixmap.height() + bg_x = max(0, min(x_min, canvas_width - bg_width)) + bg_y = y_min - bg_height + if bg_y < 0: + bg_y = min(y_min, canvas_height - bg_height) + bg_y = max(0, bg_y) + background_rect = QRect(bg_x, bg_y, bg_width, bg_height) + background_color = QColor(pen_color) + background_color.setAlpha(220) + painter.fillRect(background_rect, background_color) + text_color = QColor(0, 0, 0) + if background_color.lightness() < 128: + text_color = QColor(255, 255, 255) + painter.setPen(text_color) + painter.drawText( + background_rect.adjusted(padding, padding, -padding, -padding), + Qt.AlignLeft | Qt.AlignVCenter, + label_text, + ) + painter.restore() + painter.end() self._update_display() @@ -787,7 +821,13 @@ class AnnotationCanvasWidget(QWidget): f"Drew saved polyline with {len(polyline)} points in color {color}" ) - def draw_saved_bbox(self, bbox: List[float], color: str, width: int = 3): + def draw_saved_bbox( + self, + bbox: List[float], + color: str, + width: int = 3, + label: Optional[str] = None, + ): """ Draw a bounding box from database coordinates onto the annotation canvas. @@ -796,6 +836,7 @@ class AnnotationCanvasWidget(QWidget): in normalized coordinates (0-1) color: Color hex string (e.g., '#FF0000') width: Line width in pixels + label: Optional text label to render near the bounding box """ if not self.annotation_pixmap or not self.original_pixmap: logger.warning("Cannot draw bounding box: no image loaded") @@ -828,11 +869,11 @@ class AnnotationCanvasWidget(QWidget): self.bboxes.append( [float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)] ) - self.bbox_meta.append({"color": pen_color, "width": int(width)}) + self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label}) # Store in all_strokes for consistency self.all_strokes.append( - {"bbox": bbox, "color": color, "alpha": 128, "width": width} + {"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label} ) # Redraw overlay (polylines + all bounding boxes) diff --git a/src/utils/config_manager.py b/src/utils/config_manager.py index a31516d..5b909ff 100644 --- a/src/utils/config_manager.py +++ b/src/utils/config_manager.py @@ -58,6 +58,10 @@ class ConfigManager: "models": { "default_base_model": "yolov8s-seg.pt", "models_directory": "data/models", + "base_model_choices": [ + "yolov8s-seg.pt", + "yolov11s-seg.pt", + ], }, "training": { "default_epochs": 100, @@ -65,6 +69,20 @@ class ConfigManager: "default_imgsz": 640, "default_patience": 50, "default_lr0": 0.01, + "two_stage": { + "enabled": False, + "stage1": { + "epochs": 20, + "lr0": 0.0005, + "patience": 10, + "freeze": 10, + }, + "stage2": { + "epochs": 150, + "lr0": 0.0003, + "patience": 30, + }, + }, }, "detection": { "default_confidence": 0.25,