Implementing 2 stage training

This commit is contained in:
2025-12-11 12:04:08 +02:00
parent 221c80aa8c
commit c0684a9c14
5 changed files with 315 additions and 5 deletions

View File

@@ -12,12 +12,26 @@ image_repository:
models:
default_base_model: yolov8s-seg.pt
models_directory: data/models
base_model_choices:
- yolov8s-seg.pt
- yolov11s-seg.pt
training:
default_epochs: 100
default_batch_size: 16
default_imgsz: 640
default_patience: 50
default_lr0: 0.01
two_stage:
enabled: false
stage1:
epochs: 20
lr0: 0.0005
patience: 10
freeze: 10
stage2:
epochs: 150
lr0: 0.0003
patience: 30
last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml
last_dataset_dir: /home/martin/code/object_detection/data/datasets
detection:

View File

@@ -117,8 +117,14 @@ class ResultsTab(QWidget):
self.show_bboxes_checkbox = QCheckBox("Show Bounding Boxes")
self.show_bboxes_checkbox.setChecked(True)
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
self.show_confidence_checkbox = QCheckBox("Show Confidence")
self.show_confidence_checkbox.setChecked(False)
self.show_confidence_checkbox.stateChanged.connect(
self._apply_detection_overlays
)
toggles_layout.addWidget(self.show_masks_checkbox)
toggles_layout.addWidget(self.show_bboxes_checkbox)
toggles_layout.addWidget(self.show_confidence_checkbox)
toggles_layout.addStretch()
preview_layout.addLayout(toggles_layout)
@@ -312,7 +318,12 @@ class ResultsTab(QWidget):
det.get("y_max"),
]
if all(v is not None for v in bbox):
self.preview_canvas.draw_saved_bbox(bbox, color)
label = None
if self.show_confidence_checkbox.isChecked():
confidence = det.get("confidence")
if confidence is not None:
label = f"{confidence:.2f}"
self.preview_canvas.draw_saved_bbox(bbox, color, label=label)
def _convert_mask(self, mask_points: List[List[float]]) -> List[List[float]]:
"""Convert stored [x, y] masks to [y, x] format for the canvas."""

View File

@@ -28,6 +28,7 @@ from PySide6.QtWidgets import (
QProgressBar,
QSpinBox,
QDoubleSpinBox,
QCheckBox,
)
from src.database.db_manager import DatabaseManager
@@ -249,13 +250,26 @@ class TrainingTab(QWidget):
default_base_model = self.config_manager.get(
"models.default_base_model", "yolov8s-seg.pt"
)
base_model_choices = self.config_manager.get("models.base_model_choices", [])
self.base_model_combo = QComboBox()
self.base_model_combo.addItem("Custom path…", "")
for choice in base_model_choices:
self.base_model_combo.addItem(choice, choice)
self.base_model_combo.currentIndexChanged.connect(
self._on_base_model_preset_changed
)
form_layout.addRow("Base Model Preset:", self.base_model_combo)
base_model_layout = QHBoxLayout()
self.base_model_edit = QLineEdit(default_base_model)
self.base_model_edit.editingFinished.connect(self._on_base_model_path_edited)
base_model_layout.addWidget(self.base_model_edit)
self.base_model_browse_button = QPushButton("Browse…")
self.base_model_browse_button.clicked.connect(self._browse_base_model)
base_model_layout.addWidget(self.base_model_browse_button)
form_layout.addRow("Base Model (.pt):", base_model_layout)
self._sync_base_model_preset_selection(default_base_model)
models_dir = self.config_manager.get("models.models_directory", "data/models")
save_dir_layout = QHBoxLayout()
@@ -298,6 +312,9 @@ class TrainingTab(QWidget):
group_layout.addLayout(form_layout)
self.two_stage_group = self._create_two_stage_group(training_defaults)
group_layout.addWidget(self.two_stage_group)
button_layout = QHBoxLayout()
self.start_training_button = QPushButton("Start Training")
self.start_training_button.clicked.connect(self._start_training)
@@ -322,6 +339,134 @@ class TrainingTab(QWidget):
group.setLayout(group_layout)
return group
def _create_two_stage_group(self, training_defaults: Dict[str, Any]) -> QGroupBox:
group = QGroupBox("Two-Stage Fine-Tuning")
group_layout = QVBoxLayout()
self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
two_stage_defaults = (
training_defaults.get("two_stage", {}) if training_defaults else {}
)
self.two_stage_checkbox.setChecked(
bool(two_stage_defaults.get("enabled", False))
)
self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
group_layout.addWidget(self.two_stage_checkbox)
self.two_stage_controls_widget = QWidget()
controls_layout = QVBoxLayout()
controls_layout.setContentsMargins(0, 0, 0, 0)
controls_layout.setSpacing(8)
stage1_group = QGroupBox("Stage 1 — Head-only stabilization")
stage1_form = QFormLayout()
stage1_defaults = two_stage_defaults.get("stage1", {})
self.stage1_epochs_spin = QSpinBox()
self.stage1_epochs_spin.setRange(1, 500)
self.stage1_epochs_spin.setValue(int(stage1_defaults.get("epochs", 20)))
stage1_form.addRow("Epochs:", self.stage1_epochs_spin)
self.stage1_lr_spin = QDoubleSpinBox()
self.stage1_lr_spin.setDecimals(5)
self.stage1_lr_spin.setRange(0.00001, 0.1)
self.stage1_lr_spin.setSingleStep(0.0005)
self.stage1_lr_spin.setValue(float(stage1_defaults.get("lr0", 0.0005)))
stage1_form.addRow("Learning Rate:", self.stage1_lr_spin)
self.stage1_patience_spin = QSpinBox()
self.stage1_patience_spin.setRange(1, 200)
self.stage1_patience_spin.setValue(int(stage1_defaults.get("patience", 10)))
stage1_form.addRow("Patience:", self.stage1_patience_spin)
self.stage1_freeze_spin = QSpinBox()
self.stage1_freeze_spin.setRange(0, 24)
self.stage1_freeze_spin.setValue(int(stage1_defaults.get("freeze", 10)))
stage1_form.addRow("Freeze layers:", self.stage1_freeze_spin)
stage1_group.setLayout(stage1_form)
controls_layout.addWidget(stage1_group)
stage2_group = QGroupBox("Stage 2 — Full fine-tuning")
stage2_form = QFormLayout()
stage2_defaults = two_stage_defaults.get("stage2", {})
self.stage2_epochs_spin = QSpinBox()
self.stage2_epochs_spin.setRange(1, 2000)
self.stage2_epochs_spin.setValue(int(stage2_defaults.get("epochs", 150)))
stage2_form.addRow("Epochs:", self.stage2_epochs_spin)
self.stage2_lr_spin = QDoubleSpinBox()
self.stage2_lr_spin.setDecimals(5)
self.stage2_lr_spin.setRange(0.00001, 0.1)
self.stage2_lr_spin.setSingleStep(0.0005)
self.stage2_lr_spin.setValue(float(stage2_defaults.get("lr0", 0.0003)))
stage2_form.addRow("Learning Rate:", self.stage2_lr_spin)
self.stage2_patience_spin = QSpinBox()
self.stage2_patience_spin.setRange(1, 200)
self.stage2_patience_spin.setValue(int(stage2_defaults.get("patience", 30)))
stage2_form.addRow("Patience:", self.stage2_patience_spin)
stage2_group.setLayout(stage2_form)
controls_layout.addWidget(stage2_group)
helper_label = QLabel(
"When enabled, staged hyperparameters override the global epochs/patience/lr."
)
helper_label.setWordWrap(True)
controls_layout.addWidget(helper_label)
self.two_stage_controls_widget.setLayout(controls_layout)
group_layout.addWidget(self.two_stage_controls_widget)
group.setLayout(group_layout)
self._on_two_stage_toggled(self.two_stage_checkbox.isChecked())
return group
def _on_two_stage_toggled(self, checked: bool):
self._refresh_two_stage_controls_enabled(checked)
def _refresh_two_stage_controls_enabled(self, checked: Optional[bool] = None):
if not hasattr(self, "two_stage_controls_widget"):
return
desired_state = checked
if desired_state is None:
desired_state = self.two_stage_checkbox.isChecked()
can_edit = self.two_stage_checkbox.isEnabled()
self.two_stage_controls_widget.setEnabled(bool(desired_state and can_edit))
def _on_base_model_preset_changed(self, index: int):
preset_value = self.base_model_combo.itemData(index)
if preset_value:
self.base_model_edit.setText(str(preset_value))
elif index == 0:
self.base_model_edit.setFocus()
def _on_base_model_path_edited(self):
self._sync_base_model_preset_selection(self.base_model_edit.text().strip())
def _sync_base_model_preset_selection(self, model_path: str):
if not hasattr(self, "base_model_combo"):
return
normalized = (model_path or "").strip()
target_index = 0
for idx in range(1, self.base_model_combo.count()):
preset_value = self.base_model_combo.itemData(idx)
if not preset_value:
continue
if normalized == preset_value:
target_index = idx
break
if normalized.endswith(f"/{preset_value}") or normalized.endswith(
f"\\{preset_value}"
):
target_index = idx
break
self.base_model_combo.blockSignals(True)
self.base_model_combo.setCurrentIndex(target_index)
self.base_model_combo.blockSignals(False)
def _get_dataset_search_roots(self) -> List[Path]:
roots: List[Path] = []
default_root = Path("data/datasets").expanduser()
@@ -346,6 +491,7 @@ class TrainingTab(QWidget):
for yaml_path in root.rglob("*.yaml"):
if yaml_path.name.lower() not in {"data.yaml", "dataset.yaml"}:
continue
discovered.append(yaml_path.resolve())
except Exception as exc:
logger.warning(f"Unable to scan {root}: {exc}")
@@ -964,6 +1110,66 @@ class TrainingTab(QWidget):
self._build_rgb_dataset(cache_root, dataset_info)
return rgb_yaml
def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
two_stage = params.get("two_stage") or {}
base_stage = {
"label": "Single Stage",
"model_path": params["base_model"],
"use_previous_best": False,
"params": {
"epochs": params["epochs"],
"batch": params["batch"],
"imgsz": params["imgsz"],
"patience": params["patience"],
"lr0": params["lr0"],
"freeze": 0,
"name": params["run_name"],
},
}
if not two_stage.get("enabled"):
return [base_stage]
stage_plan: List[Dict[str, Any]] = []
stage1 = two_stage.get("stage1", {})
stage2 = two_stage.get("stage2", {})
stage_plan.append(
{
"label": "Stage 1 — Head-only",
"model_path": params["base_model"],
"use_previous_best": False,
"params": {
"epochs": stage1.get("epochs", params["epochs"]),
"batch": params["batch"],
"imgsz": params["imgsz"],
"patience": stage1.get("patience", params["patience"]),
"lr0": stage1.get("lr0", params["lr0"]),
"freeze": stage1.get("freeze", 0),
"name": f"{params['run_name']}_head_ft",
},
}
)
stage_plan.append(
{
"label": "Stage 2 — Full",
"model_path": params["base_model"],
"use_previous_best": True,
"params": {
"epochs": stage2.get("epochs", params["epochs"]),
"batch": params["batch"],
"imgsz": params["imgsz"],
"patience": stage2.get("patience", params["patience"]),
"lr0": stage2.get("lr0", params["lr0"]),
"freeze": 0,
"name": f"{params['run_name']}_full_ft",
},
}
)
return stage_plan
def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
cache_base = Path("data/datasets/_rgb_cache")
cache_base.mkdir(parents=True, exist_ok=True)
@@ -1085,6 +1291,21 @@ class TrainingTab(QWidget):
save_dir_path.mkdir(parents=True, exist_ok=True)
run_name = f"{model_name}_{model_version}".replace(" ", "_")
two_stage_config = {
"enabled": self.two_stage_checkbox.isChecked(),
"stage1": {
"epochs": self.stage1_epochs_spin.value(),
"lr0": self.stage1_lr_spin.value(),
"patience": self.stage1_patience_spin.value(),
"freeze": self.stage1_freeze_spin.value(),
},
"stage2": {
"epochs": self.stage2_epochs_spin.value(),
"lr0": self.stage2_lr_spin.value(),
"patience": self.stage2_patience_spin.value(),
},
}
return {
"model_name": model_name,
"model_version": model_version,
@@ -1096,6 +1317,7 @@ class TrainingTab(QWidget):
"imgsz": self.imgsz_spin.value(),
"patience": self.patience_spin.value(),
"lr0": self.lr_spin.value(),
"two_stage": two_stage_config,
}
def _start_training(self):
@@ -1315,6 +1537,7 @@ class TrainingTab(QWidget):
self.rescan_button.setEnabled(not is_training)
self.model_name_edit.setEnabled(not is_training)
self.model_version_edit.setEnabled(not is_training)
self.base_model_combo.setEnabled(not is_training)
self.base_model_edit.setEnabled(not is_training)
self.base_model_browse_button.setEnabled(not is_training)
self.save_dir_edit.setEnabled(not is_training)
@@ -1324,6 +1547,8 @@ class TrainingTab(QWidget):
self.imgsz_spin.setEnabled(not is_training)
self.patience_spin.setEnabled(not is_training)
self.lr_spin.setEnabled(not is_training)
self.two_stage_checkbox.setEnabled(not is_training)
self._refresh_two_stage_controls_enabled()
def _append_training_log(self, message: str):
timestamp = datetime.now().strftime("%H:%M:%S")
@@ -1339,6 +1564,7 @@ class TrainingTab(QWidget):
)
if file_path:
self.base_model_edit.setText(file_path)
self._sync_base_model_preset_selection(file_path)
def _browse_save_dir(self):
start_path = self.save_dir_edit.text().strip() or "data/models"

View File

@@ -17,7 +17,7 @@ from PySide6.QtGui import (
QMouseEvent,
QPaintEvent,
)
from PySide6.QtCore import Qt, QEvent, Signal, QPoint
from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QRect
from typing import Any, Dict, List, Optional, Tuple
from src.utils.image import Image, ImageLoadError
@@ -529,6 +529,40 @@ class AnnotationCanvasWidget(QWidget):
painter.setPen(pen)
painter.drawRect(x_min, y_min, rect_width, rect_height)
label_text = meta.get("label")
if label_text:
painter.save()
font = painter.font()
font.setPointSizeF(max(10.0, width + 4))
painter.setFont(font)
metrics = painter.fontMetrics()
text_width = metrics.horizontalAdvance(label_text)
text_height = metrics.height()
padding = 4
bg_width = text_width + padding * 2
bg_height = text_height + padding * 2
canvas_width = self.original_pixmap.width()
canvas_height = self.original_pixmap.height()
bg_x = max(0, min(x_min, canvas_width - bg_width))
bg_y = y_min - bg_height
if bg_y < 0:
bg_y = min(y_min, canvas_height - bg_height)
bg_y = max(0, bg_y)
background_rect = QRect(bg_x, bg_y, bg_width, bg_height)
background_color = QColor(pen_color)
background_color.setAlpha(220)
painter.fillRect(background_rect, background_color)
text_color = QColor(0, 0, 0)
if background_color.lightness() < 128:
text_color = QColor(255, 255, 255)
painter.setPen(text_color)
painter.drawText(
background_rect.adjusted(padding, padding, -padding, -padding),
Qt.AlignLeft | Qt.AlignVCenter,
label_text,
)
painter.restore()
painter.end()
self._update_display()
@@ -787,7 +821,13 @@ class AnnotationCanvasWidget(QWidget):
f"Drew saved polyline with {len(polyline)} points in color {color}"
)
def draw_saved_bbox(self, bbox: List[float], color: str, width: int = 3):
def draw_saved_bbox(
self,
bbox: List[float],
color: str,
width: int = 3,
label: Optional[str] = None,
):
"""
Draw a bounding box from database coordinates onto the annotation canvas.
@@ -796,6 +836,7 @@ class AnnotationCanvasWidget(QWidget):
in normalized coordinates (0-1)
color: Color hex string (e.g., '#FF0000')
width: Line width in pixels
label: Optional text label to render near the bounding box
"""
if not self.annotation_pixmap or not self.original_pixmap:
logger.warning("Cannot draw bounding box: no image loaded")
@@ -828,11 +869,11 @@ class AnnotationCanvasWidget(QWidget):
self.bboxes.append(
[float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)]
)
self.bbox_meta.append({"color": pen_color, "width": int(width)})
self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label})
# Store in all_strokes for consistency
self.all_strokes.append(
{"bbox": bbox, "color": color, "alpha": 128, "width": width}
{"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label}
)
# Redraw overlay (polylines + all bounding boxes)

View File

@@ -58,6 +58,10 @@ class ConfigManager:
"models": {
"default_base_model": "yolov8s-seg.pt",
"models_directory": "data/models",
"base_model_choices": [
"yolov8s-seg.pt",
"yolov11s-seg.pt",
],
},
"training": {
"default_epochs": 100,
@@ -65,6 +69,20 @@ class ConfigManager:
"default_imgsz": 640,
"default_patience": 50,
"default_lr0": 0.01,
"two_stage": {
"enabled": False,
"stage1": {
"epochs": 20,
"lr0": 0.0005,
"patience": 10,
"freeze": 10,
},
"stage2": {
"epochs": 150,
"lr0": 0.0003,
"patience": 30,
},
},
},
"detection": {
"default_confidence": 0.25,