Implementing 2 stage training
This commit is contained in:
@@ -12,12 +12,26 @@ image_repository:
|
||||
models:
|
||||
default_base_model: yolov8s-seg.pt
|
||||
models_directory: data/models
|
||||
base_model_choices:
|
||||
- yolov8s-seg.pt
|
||||
- yolov11s-seg.pt
|
||||
training:
|
||||
default_epochs: 100
|
||||
default_batch_size: 16
|
||||
default_imgsz: 640
|
||||
default_patience: 50
|
||||
default_lr0: 0.01
|
||||
two_stage:
|
||||
enabled: false
|
||||
stage1:
|
||||
epochs: 20
|
||||
lr0: 0.0005
|
||||
patience: 10
|
||||
freeze: 10
|
||||
stage2:
|
||||
epochs: 150
|
||||
lr0: 0.0003
|
||||
patience: 30
|
||||
last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml
|
||||
last_dataset_dir: /home/martin/code/object_detection/data/datasets
|
||||
detection:
|
||||
|
||||
@@ -117,8 +117,14 @@ class ResultsTab(QWidget):
|
||||
self.show_bboxes_checkbox = QCheckBox("Show Bounding Boxes")
|
||||
self.show_bboxes_checkbox.setChecked(True)
|
||||
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
|
||||
self.show_confidence_checkbox = QCheckBox("Show Confidence")
|
||||
self.show_confidence_checkbox.setChecked(False)
|
||||
self.show_confidence_checkbox.stateChanged.connect(
|
||||
self._apply_detection_overlays
|
||||
)
|
||||
toggles_layout.addWidget(self.show_masks_checkbox)
|
||||
toggles_layout.addWidget(self.show_bboxes_checkbox)
|
||||
toggles_layout.addWidget(self.show_confidence_checkbox)
|
||||
toggles_layout.addStretch()
|
||||
preview_layout.addLayout(toggles_layout)
|
||||
|
||||
@@ -312,7 +318,12 @@ class ResultsTab(QWidget):
|
||||
det.get("y_max"),
|
||||
]
|
||||
if all(v is not None for v in bbox):
|
||||
self.preview_canvas.draw_saved_bbox(bbox, color)
|
||||
label = None
|
||||
if self.show_confidence_checkbox.isChecked():
|
||||
confidence = det.get("confidence")
|
||||
if confidence is not None:
|
||||
label = f"{confidence:.2f}"
|
||||
self.preview_canvas.draw_saved_bbox(bbox, color, label=label)
|
||||
|
||||
def _convert_mask(self, mask_points: List[List[float]]) -> List[List[float]]:
|
||||
"""Convert stored [x, y] masks to [y, x] format for the canvas."""
|
||||
|
||||
@@ -28,6 +28,7 @@ from PySide6.QtWidgets import (
|
||||
QProgressBar,
|
||||
QSpinBox,
|
||||
QDoubleSpinBox,
|
||||
QCheckBox,
|
||||
)
|
||||
|
||||
from src.database.db_manager import DatabaseManager
|
||||
@@ -249,13 +250,26 @@ class TrainingTab(QWidget):
|
||||
default_base_model = self.config_manager.get(
|
||||
"models.default_base_model", "yolov8s-seg.pt"
|
||||
)
|
||||
base_model_choices = self.config_manager.get("models.base_model_choices", [])
|
||||
|
||||
self.base_model_combo = QComboBox()
|
||||
self.base_model_combo.addItem("Custom path…", "")
|
||||
for choice in base_model_choices:
|
||||
self.base_model_combo.addItem(choice, choice)
|
||||
self.base_model_combo.currentIndexChanged.connect(
|
||||
self._on_base_model_preset_changed
|
||||
)
|
||||
form_layout.addRow("Base Model Preset:", self.base_model_combo)
|
||||
|
||||
base_model_layout = QHBoxLayout()
|
||||
self.base_model_edit = QLineEdit(default_base_model)
|
||||
self.base_model_edit.editingFinished.connect(self._on_base_model_path_edited)
|
||||
base_model_layout.addWidget(self.base_model_edit)
|
||||
self.base_model_browse_button = QPushButton("Browse…")
|
||||
self.base_model_browse_button.clicked.connect(self._browse_base_model)
|
||||
base_model_layout.addWidget(self.base_model_browse_button)
|
||||
form_layout.addRow("Base Model (.pt):", base_model_layout)
|
||||
self._sync_base_model_preset_selection(default_base_model)
|
||||
|
||||
models_dir = self.config_manager.get("models.models_directory", "data/models")
|
||||
save_dir_layout = QHBoxLayout()
|
||||
@@ -298,6 +312,9 @@ class TrainingTab(QWidget):
|
||||
|
||||
group_layout.addLayout(form_layout)
|
||||
|
||||
self.two_stage_group = self._create_two_stage_group(training_defaults)
|
||||
group_layout.addWidget(self.two_stage_group)
|
||||
|
||||
button_layout = QHBoxLayout()
|
||||
self.start_training_button = QPushButton("Start Training")
|
||||
self.start_training_button.clicked.connect(self._start_training)
|
||||
@@ -322,6 +339,134 @@ class TrainingTab(QWidget):
|
||||
group.setLayout(group_layout)
|
||||
return group
|
||||
|
||||
def _create_two_stage_group(self, training_defaults: Dict[str, Any]) -> QGroupBox:
|
||||
group = QGroupBox("Two-Stage Fine-Tuning")
|
||||
group_layout = QVBoxLayout()
|
||||
|
||||
self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
|
||||
two_stage_defaults = (
|
||||
training_defaults.get("two_stage", {}) if training_defaults else {}
|
||||
)
|
||||
self.two_stage_checkbox.setChecked(
|
||||
bool(two_stage_defaults.get("enabled", False))
|
||||
)
|
||||
self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
|
||||
group_layout.addWidget(self.two_stage_checkbox)
|
||||
|
||||
self.two_stage_controls_widget = QWidget()
|
||||
controls_layout = QVBoxLayout()
|
||||
controls_layout.setContentsMargins(0, 0, 0, 0)
|
||||
controls_layout.setSpacing(8)
|
||||
|
||||
stage1_group = QGroupBox("Stage 1 — Head-only stabilization")
|
||||
stage1_form = QFormLayout()
|
||||
stage1_defaults = two_stage_defaults.get("stage1", {})
|
||||
|
||||
self.stage1_epochs_spin = QSpinBox()
|
||||
self.stage1_epochs_spin.setRange(1, 500)
|
||||
self.stage1_epochs_spin.setValue(int(stage1_defaults.get("epochs", 20)))
|
||||
stage1_form.addRow("Epochs:", self.stage1_epochs_spin)
|
||||
|
||||
self.stage1_lr_spin = QDoubleSpinBox()
|
||||
self.stage1_lr_spin.setDecimals(5)
|
||||
self.stage1_lr_spin.setRange(0.00001, 0.1)
|
||||
self.stage1_lr_spin.setSingleStep(0.0005)
|
||||
self.stage1_lr_spin.setValue(float(stage1_defaults.get("lr0", 0.0005)))
|
||||
stage1_form.addRow("Learning Rate:", self.stage1_lr_spin)
|
||||
|
||||
self.stage1_patience_spin = QSpinBox()
|
||||
self.stage1_patience_spin.setRange(1, 200)
|
||||
self.stage1_patience_spin.setValue(int(stage1_defaults.get("patience", 10)))
|
||||
stage1_form.addRow("Patience:", self.stage1_patience_spin)
|
||||
|
||||
self.stage1_freeze_spin = QSpinBox()
|
||||
self.stage1_freeze_spin.setRange(0, 24)
|
||||
self.stage1_freeze_spin.setValue(int(stage1_defaults.get("freeze", 10)))
|
||||
stage1_form.addRow("Freeze layers:", self.stage1_freeze_spin)
|
||||
|
||||
stage1_group.setLayout(stage1_form)
|
||||
controls_layout.addWidget(stage1_group)
|
||||
|
||||
stage2_group = QGroupBox("Stage 2 — Full fine-tuning")
|
||||
stage2_form = QFormLayout()
|
||||
stage2_defaults = two_stage_defaults.get("stage2", {})
|
||||
|
||||
self.stage2_epochs_spin = QSpinBox()
|
||||
self.stage2_epochs_spin.setRange(1, 2000)
|
||||
self.stage2_epochs_spin.setValue(int(stage2_defaults.get("epochs", 150)))
|
||||
stage2_form.addRow("Epochs:", self.stage2_epochs_spin)
|
||||
|
||||
self.stage2_lr_spin = QDoubleSpinBox()
|
||||
self.stage2_lr_spin.setDecimals(5)
|
||||
self.stage2_lr_spin.setRange(0.00001, 0.1)
|
||||
self.stage2_lr_spin.setSingleStep(0.0005)
|
||||
self.stage2_lr_spin.setValue(float(stage2_defaults.get("lr0", 0.0003)))
|
||||
stage2_form.addRow("Learning Rate:", self.stage2_lr_spin)
|
||||
|
||||
self.stage2_patience_spin = QSpinBox()
|
||||
self.stage2_patience_spin.setRange(1, 200)
|
||||
self.stage2_patience_spin.setValue(int(stage2_defaults.get("patience", 30)))
|
||||
stage2_form.addRow("Patience:", self.stage2_patience_spin)
|
||||
|
||||
stage2_group.setLayout(stage2_form)
|
||||
controls_layout.addWidget(stage2_group)
|
||||
|
||||
helper_label = QLabel(
|
||||
"When enabled, staged hyperparameters override the global epochs/patience/lr."
|
||||
)
|
||||
helper_label.setWordWrap(True)
|
||||
controls_layout.addWidget(helper_label)
|
||||
|
||||
self.two_stage_controls_widget.setLayout(controls_layout)
|
||||
group_layout.addWidget(self.two_stage_controls_widget)
|
||||
|
||||
group.setLayout(group_layout)
|
||||
self._on_two_stage_toggled(self.two_stage_checkbox.isChecked())
|
||||
return group
|
||||
|
||||
def _on_two_stage_toggled(self, checked: bool):
|
||||
self._refresh_two_stage_controls_enabled(checked)
|
||||
|
||||
def _refresh_two_stage_controls_enabled(self, checked: Optional[bool] = None):
|
||||
if not hasattr(self, "two_stage_controls_widget"):
|
||||
return
|
||||
desired_state = checked
|
||||
if desired_state is None:
|
||||
desired_state = self.two_stage_checkbox.isChecked()
|
||||
can_edit = self.two_stage_checkbox.isEnabled()
|
||||
self.two_stage_controls_widget.setEnabled(bool(desired_state and can_edit))
|
||||
|
||||
def _on_base_model_preset_changed(self, index: int):
|
||||
preset_value = self.base_model_combo.itemData(index)
|
||||
if preset_value:
|
||||
self.base_model_edit.setText(str(preset_value))
|
||||
elif index == 0:
|
||||
self.base_model_edit.setFocus()
|
||||
|
||||
def _on_base_model_path_edited(self):
|
||||
self._sync_base_model_preset_selection(self.base_model_edit.text().strip())
|
||||
|
||||
def _sync_base_model_preset_selection(self, model_path: str):
|
||||
if not hasattr(self, "base_model_combo"):
|
||||
return
|
||||
normalized = (model_path or "").strip()
|
||||
target_index = 0
|
||||
for idx in range(1, self.base_model_combo.count()):
|
||||
preset_value = self.base_model_combo.itemData(idx)
|
||||
if not preset_value:
|
||||
continue
|
||||
if normalized == preset_value:
|
||||
target_index = idx
|
||||
break
|
||||
if normalized.endswith(f"/{preset_value}") or normalized.endswith(
|
||||
f"\\{preset_value}"
|
||||
):
|
||||
target_index = idx
|
||||
break
|
||||
self.base_model_combo.blockSignals(True)
|
||||
self.base_model_combo.setCurrentIndex(target_index)
|
||||
self.base_model_combo.blockSignals(False)
|
||||
|
||||
def _get_dataset_search_roots(self) -> List[Path]:
|
||||
roots: List[Path] = []
|
||||
default_root = Path("data/datasets").expanduser()
|
||||
@@ -346,6 +491,7 @@ class TrainingTab(QWidget):
|
||||
for yaml_path in root.rglob("*.yaml"):
|
||||
if yaml_path.name.lower() not in {"data.yaml", "dataset.yaml"}:
|
||||
continue
|
||||
|
||||
discovered.append(yaml_path.resolve())
|
||||
except Exception as exc:
|
||||
logger.warning(f"Unable to scan {root}: {exc}")
|
||||
@@ -964,6 +1110,66 @@ class TrainingTab(QWidget):
|
||||
self._build_rgb_dataset(cache_root, dataset_info)
|
||||
return rgb_yaml
|
||||
|
||||
def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
two_stage = params.get("two_stage") or {}
|
||||
base_stage = {
|
||||
"label": "Single Stage",
|
||||
"model_path": params["base_model"],
|
||||
"use_previous_best": False,
|
||||
"params": {
|
||||
"epochs": params["epochs"],
|
||||
"batch": params["batch"],
|
||||
"imgsz": params["imgsz"],
|
||||
"patience": params["patience"],
|
||||
"lr0": params["lr0"],
|
||||
"freeze": 0,
|
||||
"name": params["run_name"],
|
||||
},
|
||||
}
|
||||
|
||||
if not two_stage.get("enabled"):
|
||||
return [base_stage]
|
||||
|
||||
stage_plan: List[Dict[str, Any]] = []
|
||||
stage1 = two_stage.get("stage1", {})
|
||||
stage2 = two_stage.get("stage2", {})
|
||||
|
||||
stage_plan.append(
|
||||
{
|
||||
"label": "Stage 1 — Head-only",
|
||||
"model_path": params["base_model"],
|
||||
"use_previous_best": False,
|
||||
"params": {
|
||||
"epochs": stage1.get("epochs", params["epochs"]),
|
||||
"batch": params["batch"],
|
||||
"imgsz": params["imgsz"],
|
||||
"patience": stage1.get("patience", params["patience"]),
|
||||
"lr0": stage1.get("lr0", params["lr0"]),
|
||||
"freeze": stage1.get("freeze", 0),
|
||||
"name": f"{params['run_name']}_head_ft",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
stage_plan.append(
|
||||
{
|
||||
"label": "Stage 2 — Full",
|
||||
"model_path": params["base_model"],
|
||||
"use_previous_best": True,
|
||||
"params": {
|
||||
"epochs": stage2.get("epochs", params["epochs"]),
|
||||
"batch": params["batch"],
|
||||
"imgsz": params["imgsz"],
|
||||
"patience": stage2.get("patience", params["patience"]),
|
||||
"lr0": stage2.get("lr0", params["lr0"]),
|
||||
"freeze": 0,
|
||||
"name": f"{params['run_name']}_full_ft",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return stage_plan
|
||||
|
||||
def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
|
||||
cache_base = Path("data/datasets/_rgb_cache")
|
||||
cache_base.mkdir(parents=True, exist_ok=True)
|
||||
@@ -1085,6 +1291,21 @@ class TrainingTab(QWidget):
|
||||
save_dir_path.mkdir(parents=True, exist_ok=True)
|
||||
run_name = f"{model_name}_{model_version}".replace(" ", "_")
|
||||
|
||||
two_stage_config = {
|
||||
"enabled": self.two_stage_checkbox.isChecked(),
|
||||
"stage1": {
|
||||
"epochs": self.stage1_epochs_spin.value(),
|
||||
"lr0": self.stage1_lr_spin.value(),
|
||||
"patience": self.stage1_patience_spin.value(),
|
||||
"freeze": self.stage1_freeze_spin.value(),
|
||||
},
|
||||
"stage2": {
|
||||
"epochs": self.stage2_epochs_spin.value(),
|
||||
"lr0": self.stage2_lr_spin.value(),
|
||||
"patience": self.stage2_patience_spin.value(),
|
||||
},
|
||||
}
|
||||
|
||||
return {
|
||||
"model_name": model_name,
|
||||
"model_version": model_version,
|
||||
@@ -1096,6 +1317,7 @@ class TrainingTab(QWidget):
|
||||
"imgsz": self.imgsz_spin.value(),
|
||||
"patience": self.patience_spin.value(),
|
||||
"lr0": self.lr_spin.value(),
|
||||
"two_stage": two_stage_config,
|
||||
}
|
||||
|
||||
def _start_training(self):
|
||||
@@ -1315,6 +1537,7 @@ class TrainingTab(QWidget):
|
||||
self.rescan_button.setEnabled(not is_training)
|
||||
self.model_name_edit.setEnabled(not is_training)
|
||||
self.model_version_edit.setEnabled(not is_training)
|
||||
self.base_model_combo.setEnabled(not is_training)
|
||||
self.base_model_edit.setEnabled(not is_training)
|
||||
self.base_model_browse_button.setEnabled(not is_training)
|
||||
self.save_dir_edit.setEnabled(not is_training)
|
||||
@@ -1324,6 +1547,8 @@ class TrainingTab(QWidget):
|
||||
self.imgsz_spin.setEnabled(not is_training)
|
||||
self.patience_spin.setEnabled(not is_training)
|
||||
self.lr_spin.setEnabled(not is_training)
|
||||
self.two_stage_checkbox.setEnabled(not is_training)
|
||||
self._refresh_two_stage_controls_enabled()
|
||||
|
||||
def _append_training_log(self, message: str):
|
||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||
@@ -1339,6 +1564,7 @@ class TrainingTab(QWidget):
|
||||
)
|
||||
if file_path:
|
||||
self.base_model_edit.setText(file_path)
|
||||
self._sync_base_model_preset_selection(file_path)
|
||||
|
||||
def _browse_save_dir(self):
|
||||
start_path = self.save_dir_edit.text().strip() or "data/models"
|
||||
|
||||
@@ -17,7 +17,7 @@ from PySide6.QtGui import (
|
||||
QMouseEvent,
|
||||
QPaintEvent,
|
||||
)
|
||||
from PySide6.QtCore import Qt, QEvent, Signal, QPoint
|
||||
from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QRect
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.utils.image import Image, ImageLoadError
|
||||
@@ -529,6 +529,40 @@ class AnnotationCanvasWidget(QWidget):
|
||||
painter.setPen(pen)
|
||||
painter.drawRect(x_min, y_min, rect_width, rect_height)
|
||||
|
||||
label_text = meta.get("label")
|
||||
if label_text:
|
||||
painter.save()
|
||||
font = painter.font()
|
||||
font.setPointSizeF(max(10.0, width + 4))
|
||||
painter.setFont(font)
|
||||
metrics = painter.fontMetrics()
|
||||
text_width = metrics.horizontalAdvance(label_text)
|
||||
text_height = metrics.height()
|
||||
padding = 4
|
||||
bg_width = text_width + padding * 2
|
||||
bg_height = text_height + padding * 2
|
||||
canvas_width = self.original_pixmap.width()
|
||||
canvas_height = self.original_pixmap.height()
|
||||
bg_x = max(0, min(x_min, canvas_width - bg_width))
|
||||
bg_y = y_min - bg_height
|
||||
if bg_y < 0:
|
||||
bg_y = min(y_min, canvas_height - bg_height)
|
||||
bg_y = max(0, bg_y)
|
||||
background_rect = QRect(bg_x, bg_y, bg_width, bg_height)
|
||||
background_color = QColor(pen_color)
|
||||
background_color.setAlpha(220)
|
||||
painter.fillRect(background_rect, background_color)
|
||||
text_color = QColor(0, 0, 0)
|
||||
if background_color.lightness() < 128:
|
||||
text_color = QColor(255, 255, 255)
|
||||
painter.setPen(text_color)
|
||||
painter.drawText(
|
||||
background_rect.adjusted(padding, padding, -padding, -padding),
|
||||
Qt.AlignLeft | Qt.AlignVCenter,
|
||||
label_text,
|
||||
)
|
||||
painter.restore()
|
||||
|
||||
painter.end()
|
||||
|
||||
self._update_display()
|
||||
@@ -787,7 +821,13 @@ class AnnotationCanvasWidget(QWidget):
|
||||
f"Drew saved polyline with {len(polyline)} points in color {color}"
|
||||
)
|
||||
|
||||
def draw_saved_bbox(self, bbox: List[float], color: str, width: int = 3):
|
||||
def draw_saved_bbox(
|
||||
self,
|
||||
bbox: List[float],
|
||||
color: str,
|
||||
width: int = 3,
|
||||
label: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Draw a bounding box from database coordinates onto the annotation canvas.
|
||||
|
||||
@@ -796,6 +836,7 @@ class AnnotationCanvasWidget(QWidget):
|
||||
in normalized coordinates (0-1)
|
||||
color: Color hex string (e.g., '#FF0000')
|
||||
width: Line width in pixels
|
||||
label: Optional text label to render near the bounding box
|
||||
"""
|
||||
if not self.annotation_pixmap or not self.original_pixmap:
|
||||
logger.warning("Cannot draw bounding box: no image loaded")
|
||||
@@ -828,11 +869,11 @@ class AnnotationCanvasWidget(QWidget):
|
||||
self.bboxes.append(
|
||||
[float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)]
|
||||
)
|
||||
self.bbox_meta.append({"color": pen_color, "width": int(width)})
|
||||
self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label})
|
||||
|
||||
# Store in all_strokes for consistency
|
||||
self.all_strokes.append(
|
||||
{"bbox": bbox, "color": color, "alpha": 128, "width": width}
|
||||
{"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label}
|
||||
)
|
||||
|
||||
# Redraw overlay (polylines + all bounding boxes)
|
||||
|
||||
@@ -58,6 +58,10 @@ class ConfigManager:
|
||||
"models": {
|
||||
"default_base_model": "yolov8s-seg.pt",
|
||||
"models_directory": "data/models",
|
||||
"base_model_choices": [
|
||||
"yolov8s-seg.pt",
|
||||
"yolov11s-seg.pt",
|
||||
],
|
||||
},
|
||||
"training": {
|
||||
"default_epochs": 100,
|
||||
@@ -65,6 +69,20 @@ class ConfigManager:
|
||||
"default_imgsz": 640,
|
||||
"default_patience": 50,
|
||||
"default_lr0": 0.01,
|
||||
"two_stage": {
|
||||
"enabled": False,
|
||||
"stage1": {
|
||||
"epochs": 20,
|
||||
"lr0": 0.0005,
|
||||
"patience": 10,
|
||||
"freeze": 10,
|
||||
},
|
||||
"stage2": {
|
||||
"epochs": 150,
|
||||
"lr0": 0.0003,
|
||||
"patience": 30,
|
||||
},
|
||||
},
|
||||
},
|
||||
"detection": {
|
||||
"default_confidence": 0.25,
|
||||
|
||||
Reference in New Issue
Block a user