Implementing 2 stage training

This commit is contained in:
2025-12-11 12:04:08 +02:00
parent 221c80aa8c
commit c0684a9c14
5 changed files with 315 additions and 5 deletions

View File

@@ -12,12 +12,26 @@ image_repository:
models: models:
default_base_model: yolov8s-seg.pt default_base_model: yolov8s-seg.pt
models_directory: data/models models_directory: data/models
base_model_choices:
- yolov8s-seg.pt
- yolov11s-seg.pt
training: training:
default_epochs: 100 default_epochs: 100
default_batch_size: 16 default_batch_size: 16
default_imgsz: 640 default_imgsz: 640
default_patience: 50 default_patience: 50
default_lr0: 0.01 default_lr0: 0.01
two_stage:
enabled: false
stage1:
epochs: 20
lr0: 0.0005
patience: 10
freeze: 10
stage2:
epochs: 150
lr0: 0.0003
patience: 30
last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml
last_dataset_dir: /home/martin/code/object_detection/data/datasets last_dataset_dir: /home/martin/code/object_detection/data/datasets
detection: detection:

View File

@@ -117,8 +117,14 @@ class ResultsTab(QWidget):
self.show_bboxes_checkbox = QCheckBox("Show Bounding Boxes") self.show_bboxes_checkbox = QCheckBox("Show Bounding Boxes")
self.show_bboxes_checkbox.setChecked(True) self.show_bboxes_checkbox.setChecked(True)
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes) self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
self.show_confidence_checkbox = QCheckBox("Show Confidence")
self.show_confidence_checkbox.setChecked(False)
self.show_confidence_checkbox.stateChanged.connect(
self._apply_detection_overlays
)
toggles_layout.addWidget(self.show_masks_checkbox) toggles_layout.addWidget(self.show_masks_checkbox)
toggles_layout.addWidget(self.show_bboxes_checkbox) toggles_layout.addWidget(self.show_bboxes_checkbox)
toggles_layout.addWidget(self.show_confidence_checkbox)
toggles_layout.addStretch() toggles_layout.addStretch()
preview_layout.addLayout(toggles_layout) preview_layout.addLayout(toggles_layout)
@@ -312,7 +318,12 @@ class ResultsTab(QWidget):
det.get("y_max"), det.get("y_max"),
] ]
if all(v is not None for v in bbox): if all(v is not None for v in bbox):
self.preview_canvas.draw_saved_bbox(bbox, color) label = None
if self.show_confidence_checkbox.isChecked():
confidence = det.get("confidence")
if confidence is not None:
label = f"{confidence:.2f}"
self.preview_canvas.draw_saved_bbox(bbox, color, label=label)
def _convert_mask(self, mask_points: List[List[float]]) -> List[List[float]]: def _convert_mask(self, mask_points: List[List[float]]) -> List[List[float]]:
"""Convert stored [x, y] masks to [y, x] format for the canvas.""" """Convert stored [x, y] masks to [y, x] format for the canvas."""

View File

@@ -28,6 +28,7 @@ from PySide6.QtWidgets import (
QProgressBar, QProgressBar,
QSpinBox, QSpinBox,
QDoubleSpinBox, QDoubleSpinBox,
QCheckBox,
) )
from src.database.db_manager import DatabaseManager from src.database.db_manager import DatabaseManager
@@ -249,13 +250,26 @@ class TrainingTab(QWidget):
default_base_model = self.config_manager.get( default_base_model = self.config_manager.get(
"models.default_base_model", "yolov8s-seg.pt" "models.default_base_model", "yolov8s-seg.pt"
) )
base_model_choices = self.config_manager.get("models.base_model_choices", [])
self.base_model_combo = QComboBox()
self.base_model_combo.addItem("Custom path…", "")
for choice in base_model_choices:
self.base_model_combo.addItem(choice, choice)
self.base_model_combo.currentIndexChanged.connect(
self._on_base_model_preset_changed
)
form_layout.addRow("Base Model Preset:", self.base_model_combo)
base_model_layout = QHBoxLayout() base_model_layout = QHBoxLayout()
self.base_model_edit = QLineEdit(default_base_model) self.base_model_edit = QLineEdit(default_base_model)
self.base_model_edit.editingFinished.connect(self._on_base_model_path_edited)
base_model_layout.addWidget(self.base_model_edit) base_model_layout.addWidget(self.base_model_edit)
self.base_model_browse_button = QPushButton("Browse…") self.base_model_browse_button = QPushButton("Browse…")
self.base_model_browse_button.clicked.connect(self._browse_base_model) self.base_model_browse_button.clicked.connect(self._browse_base_model)
base_model_layout.addWidget(self.base_model_browse_button) base_model_layout.addWidget(self.base_model_browse_button)
form_layout.addRow("Base Model (.pt):", base_model_layout) form_layout.addRow("Base Model (.pt):", base_model_layout)
self._sync_base_model_preset_selection(default_base_model)
models_dir = self.config_manager.get("models.models_directory", "data/models") models_dir = self.config_manager.get("models.models_directory", "data/models")
save_dir_layout = QHBoxLayout() save_dir_layout = QHBoxLayout()
@@ -298,6 +312,9 @@ class TrainingTab(QWidget):
group_layout.addLayout(form_layout) group_layout.addLayout(form_layout)
self.two_stage_group = self._create_two_stage_group(training_defaults)
group_layout.addWidget(self.two_stage_group)
button_layout = QHBoxLayout() button_layout = QHBoxLayout()
self.start_training_button = QPushButton("Start Training") self.start_training_button = QPushButton("Start Training")
self.start_training_button.clicked.connect(self._start_training) self.start_training_button.clicked.connect(self._start_training)
@@ -322,6 +339,134 @@ class TrainingTab(QWidget):
group.setLayout(group_layout) group.setLayout(group_layout)
return group return group
def _create_two_stage_group(self, training_defaults: Dict[str, Any]) -> QGroupBox:
group = QGroupBox("Two-Stage Fine-Tuning")
group_layout = QVBoxLayout()
self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
two_stage_defaults = (
training_defaults.get("two_stage", {}) if training_defaults else {}
)
self.two_stage_checkbox.setChecked(
bool(two_stage_defaults.get("enabled", False))
)
self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
group_layout.addWidget(self.two_stage_checkbox)
self.two_stage_controls_widget = QWidget()
controls_layout = QVBoxLayout()
controls_layout.setContentsMargins(0, 0, 0, 0)
controls_layout.setSpacing(8)
stage1_group = QGroupBox("Stage 1 — Head-only stabilization")
stage1_form = QFormLayout()
stage1_defaults = two_stage_defaults.get("stage1", {})
self.stage1_epochs_spin = QSpinBox()
self.stage1_epochs_spin.setRange(1, 500)
self.stage1_epochs_spin.setValue(int(stage1_defaults.get("epochs", 20)))
stage1_form.addRow("Epochs:", self.stage1_epochs_spin)
self.stage1_lr_spin = QDoubleSpinBox()
self.stage1_lr_spin.setDecimals(5)
self.stage1_lr_spin.setRange(0.00001, 0.1)
self.stage1_lr_spin.setSingleStep(0.0005)
self.stage1_lr_spin.setValue(float(stage1_defaults.get("lr0", 0.0005)))
stage1_form.addRow("Learning Rate:", self.stage1_lr_spin)
self.stage1_patience_spin = QSpinBox()
self.stage1_patience_spin.setRange(1, 200)
self.stage1_patience_spin.setValue(int(stage1_defaults.get("patience", 10)))
stage1_form.addRow("Patience:", self.stage1_patience_spin)
self.stage1_freeze_spin = QSpinBox()
self.stage1_freeze_spin.setRange(0, 24)
self.stage1_freeze_spin.setValue(int(stage1_defaults.get("freeze", 10)))
stage1_form.addRow("Freeze layers:", self.stage1_freeze_spin)
stage1_group.setLayout(stage1_form)
controls_layout.addWidget(stage1_group)
stage2_group = QGroupBox("Stage 2 — Full fine-tuning")
stage2_form = QFormLayout()
stage2_defaults = two_stage_defaults.get("stage2", {})
self.stage2_epochs_spin = QSpinBox()
self.stage2_epochs_spin.setRange(1, 2000)
self.stage2_epochs_spin.setValue(int(stage2_defaults.get("epochs", 150)))
stage2_form.addRow("Epochs:", self.stage2_epochs_spin)
self.stage2_lr_spin = QDoubleSpinBox()
self.stage2_lr_spin.setDecimals(5)
self.stage2_lr_spin.setRange(0.00001, 0.1)
self.stage2_lr_spin.setSingleStep(0.0005)
self.stage2_lr_spin.setValue(float(stage2_defaults.get("lr0", 0.0003)))
stage2_form.addRow("Learning Rate:", self.stage2_lr_spin)
self.stage2_patience_spin = QSpinBox()
self.stage2_patience_spin.setRange(1, 200)
self.stage2_patience_spin.setValue(int(stage2_defaults.get("patience", 30)))
stage2_form.addRow("Patience:", self.stage2_patience_spin)
stage2_group.setLayout(stage2_form)
controls_layout.addWidget(stage2_group)
helper_label = QLabel(
"When enabled, staged hyperparameters override the global epochs/patience/lr."
)
helper_label.setWordWrap(True)
controls_layout.addWidget(helper_label)
self.two_stage_controls_widget.setLayout(controls_layout)
group_layout.addWidget(self.two_stage_controls_widget)
group.setLayout(group_layout)
self._on_two_stage_toggled(self.two_stage_checkbox.isChecked())
return group
def _on_two_stage_toggled(self, checked: bool):
self._refresh_two_stage_controls_enabled(checked)
def _refresh_two_stage_controls_enabled(self, checked: Optional[bool] = None):
if not hasattr(self, "two_stage_controls_widget"):
return
desired_state = checked
if desired_state is None:
desired_state = self.two_stage_checkbox.isChecked()
can_edit = self.two_stage_checkbox.isEnabled()
self.two_stage_controls_widget.setEnabled(bool(desired_state and can_edit))
def _on_base_model_preset_changed(self, index: int):
preset_value = self.base_model_combo.itemData(index)
if preset_value:
self.base_model_edit.setText(str(preset_value))
elif index == 0:
self.base_model_edit.setFocus()
def _on_base_model_path_edited(self):
self._sync_base_model_preset_selection(self.base_model_edit.text().strip())
def _sync_base_model_preset_selection(self, model_path: str):
if not hasattr(self, "base_model_combo"):
return
normalized = (model_path or "").strip()
target_index = 0
for idx in range(1, self.base_model_combo.count()):
preset_value = self.base_model_combo.itemData(idx)
if not preset_value:
continue
if normalized == preset_value:
target_index = idx
break
if normalized.endswith(f"/{preset_value}") or normalized.endswith(
f"\\{preset_value}"
):
target_index = idx
break
self.base_model_combo.blockSignals(True)
self.base_model_combo.setCurrentIndex(target_index)
self.base_model_combo.blockSignals(False)
def _get_dataset_search_roots(self) -> List[Path]: def _get_dataset_search_roots(self) -> List[Path]:
roots: List[Path] = [] roots: List[Path] = []
default_root = Path("data/datasets").expanduser() default_root = Path("data/datasets").expanduser()
@@ -346,6 +491,7 @@ class TrainingTab(QWidget):
for yaml_path in root.rglob("*.yaml"): for yaml_path in root.rglob("*.yaml"):
if yaml_path.name.lower() not in {"data.yaml", "dataset.yaml"}: if yaml_path.name.lower() not in {"data.yaml", "dataset.yaml"}:
continue continue
discovered.append(yaml_path.resolve()) discovered.append(yaml_path.resolve())
except Exception as exc: except Exception as exc:
logger.warning(f"Unable to scan {root}: {exc}") logger.warning(f"Unable to scan {root}: {exc}")
@@ -964,6 +1110,66 @@ class TrainingTab(QWidget):
self._build_rgb_dataset(cache_root, dataset_info) self._build_rgb_dataset(cache_root, dataset_info)
return rgb_yaml return rgb_yaml
def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
two_stage = params.get("two_stage") or {}
base_stage = {
"label": "Single Stage",
"model_path": params["base_model"],
"use_previous_best": False,
"params": {
"epochs": params["epochs"],
"batch": params["batch"],
"imgsz": params["imgsz"],
"patience": params["patience"],
"lr0": params["lr0"],
"freeze": 0,
"name": params["run_name"],
},
}
if not two_stage.get("enabled"):
return [base_stage]
stage_plan: List[Dict[str, Any]] = []
stage1 = two_stage.get("stage1", {})
stage2 = two_stage.get("stage2", {})
stage_plan.append(
{
"label": "Stage 1 — Head-only",
"model_path": params["base_model"],
"use_previous_best": False,
"params": {
"epochs": stage1.get("epochs", params["epochs"]),
"batch": params["batch"],
"imgsz": params["imgsz"],
"patience": stage1.get("patience", params["patience"]),
"lr0": stage1.get("lr0", params["lr0"]),
"freeze": stage1.get("freeze", 0),
"name": f"{params['run_name']}_head_ft",
},
}
)
stage_plan.append(
{
"label": "Stage 2 — Full",
"model_path": params["base_model"],
"use_previous_best": True,
"params": {
"epochs": stage2.get("epochs", params["epochs"]),
"batch": params["batch"],
"imgsz": params["imgsz"],
"patience": stage2.get("patience", params["patience"]),
"lr0": stage2.get("lr0", params["lr0"]),
"freeze": 0,
"name": f"{params['run_name']}_full_ft",
},
}
)
return stage_plan
def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path: def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
cache_base = Path("data/datasets/_rgb_cache") cache_base = Path("data/datasets/_rgb_cache")
cache_base.mkdir(parents=True, exist_ok=True) cache_base.mkdir(parents=True, exist_ok=True)
@@ -1085,6 +1291,21 @@ class TrainingTab(QWidget):
save_dir_path.mkdir(parents=True, exist_ok=True) save_dir_path.mkdir(parents=True, exist_ok=True)
run_name = f"{model_name}_{model_version}".replace(" ", "_") run_name = f"{model_name}_{model_version}".replace(" ", "_")
two_stage_config = {
"enabled": self.two_stage_checkbox.isChecked(),
"stage1": {
"epochs": self.stage1_epochs_spin.value(),
"lr0": self.stage1_lr_spin.value(),
"patience": self.stage1_patience_spin.value(),
"freeze": self.stage1_freeze_spin.value(),
},
"stage2": {
"epochs": self.stage2_epochs_spin.value(),
"lr0": self.stage2_lr_spin.value(),
"patience": self.stage2_patience_spin.value(),
},
}
return { return {
"model_name": model_name, "model_name": model_name,
"model_version": model_version, "model_version": model_version,
@@ -1096,6 +1317,7 @@ class TrainingTab(QWidget):
"imgsz": self.imgsz_spin.value(), "imgsz": self.imgsz_spin.value(),
"patience": self.patience_spin.value(), "patience": self.patience_spin.value(),
"lr0": self.lr_spin.value(), "lr0": self.lr_spin.value(),
"two_stage": two_stage_config,
} }
def _start_training(self): def _start_training(self):
@@ -1315,6 +1537,7 @@ class TrainingTab(QWidget):
self.rescan_button.setEnabled(not is_training) self.rescan_button.setEnabled(not is_training)
self.model_name_edit.setEnabled(not is_training) self.model_name_edit.setEnabled(not is_training)
self.model_version_edit.setEnabled(not is_training) self.model_version_edit.setEnabled(not is_training)
self.base_model_combo.setEnabled(not is_training)
self.base_model_edit.setEnabled(not is_training) self.base_model_edit.setEnabled(not is_training)
self.base_model_browse_button.setEnabled(not is_training) self.base_model_browse_button.setEnabled(not is_training)
self.save_dir_edit.setEnabled(not is_training) self.save_dir_edit.setEnabled(not is_training)
@@ -1324,6 +1547,8 @@ class TrainingTab(QWidget):
self.imgsz_spin.setEnabled(not is_training) self.imgsz_spin.setEnabled(not is_training)
self.patience_spin.setEnabled(not is_training) self.patience_spin.setEnabled(not is_training)
self.lr_spin.setEnabled(not is_training) self.lr_spin.setEnabled(not is_training)
self.two_stage_checkbox.setEnabled(not is_training)
self._refresh_two_stage_controls_enabled()
def _append_training_log(self, message: str): def _append_training_log(self, message: str):
timestamp = datetime.now().strftime("%H:%M:%S") timestamp = datetime.now().strftime("%H:%M:%S")
@@ -1339,6 +1564,7 @@ class TrainingTab(QWidget):
) )
if file_path: if file_path:
self.base_model_edit.setText(file_path) self.base_model_edit.setText(file_path)
self._sync_base_model_preset_selection(file_path)
def _browse_save_dir(self): def _browse_save_dir(self):
start_path = self.save_dir_edit.text().strip() or "data/models" start_path = self.save_dir_edit.text().strip() or "data/models"

View File

@@ -17,7 +17,7 @@ from PySide6.QtGui import (
QMouseEvent, QMouseEvent,
QPaintEvent, QPaintEvent,
) )
from PySide6.QtCore import Qt, QEvent, Signal, QPoint from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QRect
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from src.utils.image import Image, ImageLoadError from src.utils.image import Image, ImageLoadError
@@ -529,6 +529,40 @@ class AnnotationCanvasWidget(QWidget):
painter.setPen(pen) painter.setPen(pen)
painter.drawRect(x_min, y_min, rect_width, rect_height) painter.drawRect(x_min, y_min, rect_width, rect_height)
label_text = meta.get("label")
if label_text:
painter.save()
font = painter.font()
font.setPointSizeF(max(10.0, width + 4))
painter.setFont(font)
metrics = painter.fontMetrics()
text_width = metrics.horizontalAdvance(label_text)
text_height = metrics.height()
padding = 4
bg_width = text_width + padding * 2
bg_height = text_height + padding * 2
canvas_width = self.original_pixmap.width()
canvas_height = self.original_pixmap.height()
bg_x = max(0, min(x_min, canvas_width - bg_width))
bg_y = y_min - bg_height
if bg_y < 0:
bg_y = min(y_min, canvas_height - bg_height)
bg_y = max(0, bg_y)
background_rect = QRect(bg_x, bg_y, bg_width, bg_height)
background_color = QColor(pen_color)
background_color.setAlpha(220)
painter.fillRect(background_rect, background_color)
text_color = QColor(0, 0, 0)
if background_color.lightness() < 128:
text_color = QColor(255, 255, 255)
painter.setPen(text_color)
painter.drawText(
background_rect.adjusted(padding, padding, -padding, -padding),
Qt.AlignLeft | Qt.AlignVCenter,
label_text,
)
painter.restore()
painter.end() painter.end()
self._update_display() self._update_display()
@@ -787,7 +821,13 @@ class AnnotationCanvasWidget(QWidget):
f"Drew saved polyline with {len(polyline)} points in color {color}" f"Drew saved polyline with {len(polyline)} points in color {color}"
) )
def draw_saved_bbox(self, bbox: List[float], color: str, width: int = 3): def draw_saved_bbox(
self,
bbox: List[float],
color: str,
width: int = 3,
label: Optional[str] = None,
):
""" """
Draw a bounding box from database coordinates onto the annotation canvas. Draw a bounding box from database coordinates onto the annotation canvas.
@@ -796,6 +836,7 @@ class AnnotationCanvasWidget(QWidget):
in normalized coordinates (0-1) in normalized coordinates (0-1)
color: Color hex string (e.g., '#FF0000') color: Color hex string (e.g., '#FF0000')
width: Line width in pixels width: Line width in pixels
label: Optional text label to render near the bounding box
""" """
if not self.annotation_pixmap or not self.original_pixmap: if not self.annotation_pixmap or not self.original_pixmap:
logger.warning("Cannot draw bounding box: no image loaded") logger.warning("Cannot draw bounding box: no image loaded")
@@ -828,11 +869,11 @@ class AnnotationCanvasWidget(QWidget):
self.bboxes.append( self.bboxes.append(
[float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)] [float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)]
) )
self.bbox_meta.append({"color": pen_color, "width": int(width)}) self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label})
# Store in all_strokes for consistency # Store in all_strokes for consistency
self.all_strokes.append( self.all_strokes.append(
{"bbox": bbox, "color": color, "alpha": 128, "width": width} {"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label}
) )
# Redraw overlay (polylines + all bounding boxes) # Redraw overlay (polylines + all bounding boxes)

View File

@@ -58,6 +58,10 @@ class ConfigManager:
"models": { "models": {
"default_base_model": "yolov8s-seg.pt", "default_base_model": "yolov8s-seg.pt",
"models_directory": "data/models", "models_directory": "data/models",
"base_model_choices": [
"yolov8s-seg.pt",
"yolov11s-seg.pt",
],
}, },
"training": { "training": {
"default_epochs": 100, "default_epochs": 100,
@@ -65,6 +69,20 @@ class ConfigManager:
"default_imgsz": 640, "default_imgsz": 640,
"default_patience": 50, "default_patience": 50,
"default_lr0": 0.01, "default_lr0": 0.01,
"two_stage": {
"enabled": False,
"stage1": {
"epochs": 20,
"lr0": 0.0005,
"patience": 10,
"freeze": 10,
},
"stage2": {
"epochs": 150,
"lr0": 0.0003,
"patience": 30,
},
},
}, },
"detection": { "detection": {
"default_confidence": 0.25, "default_confidence": 0.25,