Implementing 2 stage training
This commit is contained in:
@@ -12,12 +12,26 @@ image_repository:
|
|||||||
models:
|
models:
|
||||||
default_base_model: yolov8s-seg.pt
|
default_base_model: yolov8s-seg.pt
|
||||||
models_directory: data/models
|
models_directory: data/models
|
||||||
|
base_model_choices:
|
||||||
|
- yolov8s-seg.pt
|
||||||
|
- yolov11s-seg.pt
|
||||||
training:
|
training:
|
||||||
default_epochs: 100
|
default_epochs: 100
|
||||||
default_batch_size: 16
|
default_batch_size: 16
|
||||||
default_imgsz: 640
|
default_imgsz: 640
|
||||||
default_patience: 50
|
default_patience: 50
|
||||||
default_lr0: 0.01
|
default_lr0: 0.01
|
||||||
|
two_stage:
|
||||||
|
enabled: false
|
||||||
|
stage1:
|
||||||
|
epochs: 20
|
||||||
|
lr0: 0.0005
|
||||||
|
patience: 10
|
||||||
|
freeze: 10
|
||||||
|
stage2:
|
||||||
|
epochs: 150
|
||||||
|
lr0: 0.0003
|
||||||
|
patience: 30
|
||||||
last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml
|
last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml
|
||||||
last_dataset_dir: /home/martin/code/object_detection/data/datasets
|
last_dataset_dir: /home/martin/code/object_detection/data/datasets
|
||||||
detection:
|
detection:
|
||||||
|
|||||||
@@ -117,8 +117,14 @@ class ResultsTab(QWidget):
|
|||||||
self.show_bboxes_checkbox = QCheckBox("Show Bounding Boxes")
|
self.show_bboxes_checkbox = QCheckBox("Show Bounding Boxes")
|
||||||
self.show_bboxes_checkbox.setChecked(True)
|
self.show_bboxes_checkbox.setChecked(True)
|
||||||
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
|
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
|
||||||
|
self.show_confidence_checkbox = QCheckBox("Show Confidence")
|
||||||
|
self.show_confidence_checkbox.setChecked(False)
|
||||||
|
self.show_confidence_checkbox.stateChanged.connect(
|
||||||
|
self._apply_detection_overlays
|
||||||
|
)
|
||||||
toggles_layout.addWidget(self.show_masks_checkbox)
|
toggles_layout.addWidget(self.show_masks_checkbox)
|
||||||
toggles_layout.addWidget(self.show_bboxes_checkbox)
|
toggles_layout.addWidget(self.show_bboxes_checkbox)
|
||||||
|
toggles_layout.addWidget(self.show_confidence_checkbox)
|
||||||
toggles_layout.addStretch()
|
toggles_layout.addStretch()
|
||||||
preview_layout.addLayout(toggles_layout)
|
preview_layout.addLayout(toggles_layout)
|
||||||
|
|
||||||
@@ -312,7 +318,12 @@ class ResultsTab(QWidget):
|
|||||||
det.get("y_max"),
|
det.get("y_max"),
|
||||||
]
|
]
|
||||||
if all(v is not None for v in bbox):
|
if all(v is not None for v in bbox):
|
||||||
self.preview_canvas.draw_saved_bbox(bbox, color)
|
label = None
|
||||||
|
if self.show_confidence_checkbox.isChecked():
|
||||||
|
confidence = det.get("confidence")
|
||||||
|
if confidence is not None:
|
||||||
|
label = f"{confidence:.2f}"
|
||||||
|
self.preview_canvas.draw_saved_bbox(bbox, color, label=label)
|
||||||
|
|
||||||
def _convert_mask(self, mask_points: List[List[float]]) -> List[List[float]]:
|
def _convert_mask(self, mask_points: List[List[float]]) -> List[List[float]]:
|
||||||
"""Convert stored [x, y] masks to [y, x] format for the canvas."""
|
"""Convert stored [x, y] masks to [y, x] format for the canvas."""
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from PySide6.QtWidgets import (
|
|||||||
QProgressBar,
|
QProgressBar,
|
||||||
QSpinBox,
|
QSpinBox,
|
||||||
QDoubleSpinBox,
|
QDoubleSpinBox,
|
||||||
|
QCheckBox,
|
||||||
)
|
)
|
||||||
|
|
||||||
from src.database.db_manager import DatabaseManager
|
from src.database.db_manager import DatabaseManager
|
||||||
@@ -249,13 +250,26 @@ class TrainingTab(QWidget):
|
|||||||
default_base_model = self.config_manager.get(
|
default_base_model = self.config_manager.get(
|
||||||
"models.default_base_model", "yolov8s-seg.pt"
|
"models.default_base_model", "yolov8s-seg.pt"
|
||||||
)
|
)
|
||||||
|
base_model_choices = self.config_manager.get("models.base_model_choices", [])
|
||||||
|
|
||||||
|
self.base_model_combo = QComboBox()
|
||||||
|
self.base_model_combo.addItem("Custom path…", "")
|
||||||
|
for choice in base_model_choices:
|
||||||
|
self.base_model_combo.addItem(choice, choice)
|
||||||
|
self.base_model_combo.currentIndexChanged.connect(
|
||||||
|
self._on_base_model_preset_changed
|
||||||
|
)
|
||||||
|
form_layout.addRow("Base Model Preset:", self.base_model_combo)
|
||||||
|
|
||||||
base_model_layout = QHBoxLayout()
|
base_model_layout = QHBoxLayout()
|
||||||
self.base_model_edit = QLineEdit(default_base_model)
|
self.base_model_edit = QLineEdit(default_base_model)
|
||||||
|
self.base_model_edit.editingFinished.connect(self._on_base_model_path_edited)
|
||||||
base_model_layout.addWidget(self.base_model_edit)
|
base_model_layout.addWidget(self.base_model_edit)
|
||||||
self.base_model_browse_button = QPushButton("Browse…")
|
self.base_model_browse_button = QPushButton("Browse…")
|
||||||
self.base_model_browse_button.clicked.connect(self._browse_base_model)
|
self.base_model_browse_button.clicked.connect(self._browse_base_model)
|
||||||
base_model_layout.addWidget(self.base_model_browse_button)
|
base_model_layout.addWidget(self.base_model_browse_button)
|
||||||
form_layout.addRow("Base Model (.pt):", base_model_layout)
|
form_layout.addRow("Base Model (.pt):", base_model_layout)
|
||||||
|
self._sync_base_model_preset_selection(default_base_model)
|
||||||
|
|
||||||
models_dir = self.config_manager.get("models.models_directory", "data/models")
|
models_dir = self.config_manager.get("models.models_directory", "data/models")
|
||||||
save_dir_layout = QHBoxLayout()
|
save_dir_layout = QHBoxLayout()
|
||||||
@@ -298,6 +312,9 @@ class TrainingTab(QWidget):
|
|||||||
|
|
||||||
group_layout.addLayout(form_layout)
|
group_layout.addLayout(form_layout)
|
||||||
|
|
||||||
|
self.two_stage_group = self._create_two_stage_group(training_defaults)
|
||||||
|
group_layout.addWidget(self.two_stage_group)
|
||||||
|
|
||||||
button_layout = QHBoxLayout()
|
button_layout = QHBoxLayout()
|
||||||
self.start_training_button = QPushButton("Start Training")
|
self.start_training_button = QPushButton("Start Training")
|
||||||
self.start_training_button.clicked.connect(self._start_training)
|
self.start_training_button.clicked.connect(self._start_training)
|
||||||
@@ -322,6 +339,134 @@ class TrainingTab(QWidget):
|
|||||||
group.setLayout(group_layout)
|
group.setLayout(group_layout)
|
||||||
return group
|
return group
|
||||||
|
|
||||||
|
def _create_two_stage_group(self, training_defaults: Dict[str, Any]) -> QGroupBox:
    """
    Build the "Two-Stage Fine-Tuning" group box.

    Creates the enable checkbox and the stage 1 / stage 2 hyperparameter
    spin boxes, stores them on ``self`` (``two_stage_checkbox``,
    ``two_stage_controls_widget``, ``stage1_*_spin``, ``stage2_*_spin``)
    and seeds them from the ``two_stage`` entry of ``training_defaults``.

    Args:
        training_defaults: Training defaults mapping. May be empty or None.
            The ``two_stage`` entry and its ``stage1`` / ``stage2``
            sub-entries may each be missing or None; the built-in fallback
            values are used in that case.

    Returns:
        The populated group box.
    """
    # ``dict.get(key, {})`` only covers a missing key. A key that is present
    # but null (e.g. an empty section in a YAML file) comes back as None, so
    # normalise with ``or {}`` before chaining further lookups.
    two_stage_defaults = (training_defaults or {}).get("two_stage") or {}
    stage1_defaults = two_stage_defaults.get("stage1") or {}
    stage2_defaults = two_stage_defaults.get("stage2") or {}

    def make_int_spin(minimum: int, maximum: int, value: Any) -> QSpinBox:
        # Range is applied before the value so the value is clamped to it.
        spin = QSpinBox()
        spin.setRange(minimum, maximum)
        spin.setValue(int(value))
        return spin

    def make_lr_spin(value: Any) -> QDoubleSpinBox:
        # Both stages share the same learning-rate precision, range and step.
        spin = QDoubleSpinBox()
        spin.setDecimals(5)
        spin.setRange(0.00001, 0.1)
        spin.setSingleStep(0.0005)
        spin.setValue(float(value))
        return spin

    group = QGroupBox("Two-Stage Fine-Tuning")
    group_layout = QVBoxLayout()

    self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
    # The initial state is set before the signal is connected, so the slot
    # does not fire while the dependent controls are still missing.
    self.two_stage_checkbox.setChecked(bool(two_stage_defaults.get("enabled", False)))
    self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
    group_layout.addWidget(self.two_stage_checkbox)

    self.two_stage_controls_widget = QWidget()
    controls_layout = QVBoxLayout()
    controls_layout.setContentsMargins(0, 0, 0, 0)
    controls_layout.setSpacing(8)

    # Stage 1 controls.
    stage1_group = QGroupBox("Stage 1 — Head-only stabilization")
    stage1_form = QFormLayout()

    self.stage1_epochs_spin = make_int_spin(1, 500, stage1_defaults.get("epochs", 20))
    stage1_form.addRow("Epochs:", self.stage1_epochs_spin)

    self.stage1_lr_spin = make_lr_spin(stage1_defaults.get("lr0", 0.0005))
    stage1_form.addRow("Learning Rate:", self.stage1_lr_spin)

    self.stage1_patience_spin = make_int_spin(
        1, 200, stage1_defaults.get("patience", 10)
    )
    stage1_form.addRow("Patience:", self.stage1_patience_spin)

    self.stage1_freeze_spin = make_int_spin(0, 24, stage1_defaults.get("freeze", 10))
    stage1_form.addRow("Freeze layers:", self.stage1_freeze_spin)

    stage1_group.setLayout(stage1_form)
    controls_layout.addWidget(stage1_group)

    # Stage 2 controls (no freeze setting is exposed for this stage).
    stage2_group = QGroupBox("Stage 2 — Full fine-tuning")
    stage2_form = QFormLayout()

    self.stage2_epochs_spin = make_int_spin(1, 2000, stage2_defaults.get("epochs", 150))
    stage2_form.addRow("Epochs:", self.stage2_epochs_spin)

    self.stage2_lr_spin = make_lr_spin(stage2_defaults.get("lr0", 0.0003))
    stage2_form.addRow("Learning Rate:", self.stage2_lr_spin)

    self.stage2_patience_spin = make_int_spin(
        1, 200, stage2_defaults.get("patience", 30)
    )
    stage2_form.addRow("Patience:", self.stage2_patience_spin)

    stage2_group.setLayout(stage2_form)
    controls_layout.addWidget(stage2_group)

    helper_label = QLabel(
        "When enabled, staged hyperparameters override the global epochs/patience/lr."
    )
    helper_label.setWordWrap(True)
    controls_layout.addWidget(helper_label)

    self.two_stage_controls_widget.setLayout(controls_layout)
    group_layout.addWidget(self.two_stage_controls_widget)

    group.setLayout(group_layout)
    # Apply the initial enabled/disabled state of the stage controls.
    self._on_two_stage_toggled(self.two_stage_checkbox.isChecked())
    return group
|
||||||
|
|
||||||
|
def _on_two_stage_toggled(self, checked: bool):
    """Slot for the two-stage checkbox: forward the new state to the refresh helper."""
    self._refresh_two_stage_controls_enabled(checked=checked)
|
||||||
|
|
||||||
|
def _refresh_two_stage_controls_enabled(self, checked: Optional[bool] = None):
    """
    Enable or disable the container holding the per-stage controls.

    The container is enabled only when two-stage training is requested and
    the two-stage checkbox itself is currently enabled.

    Args:
        checked: Requested two-stage state. When None, the current checkbox
            state is used instead.
    """
    # Nothing to update until the controls container has been created.
    if not hasattr(self, "two_stage_controls_widget"):
        return

    checkbox = self.two_stage_checkbox
    wants_two_stage = checkbox.isChecked() if checked is None else checked
    checkbox_editable = checkbox.isEnabled()
    self.two_stage_controls_widget.setEnabled(
        bool(wants_two_stage and checkbox_editable)
    )
|
||||||
|
|
||||||
|
def _on_base_model_preset_changed(self, index: int):
    """
    React to a new selection in the base-model preset combo box.

    A preset with non-empty item data is copied into the base-model path
    field. Selecting the first entry (index 0, which carries no data) moves
    keyboard focus to the path field instead.

    Args:
        index: Index of the newly selected combo box entry.
    """
    selected = self.base_model_combo.itemData(index)
    if not selected:
        if index == 0:
            self.base_model_edit.setFocus()
        return
    self.base_model_edit.setText(str(selected))
|
||||||
|
|
||||||
|
def _on_base_model_path_edited(self):
    """Re-sync the preset combo box with the whitespace-stripped text of the path field."""
    typed_path = self.base_model_edit.text().strip()
    self._sync_base_model_preset_selection(typed_path)
|
||||||
|
|
||||||
|
def _sync_base_model_preset_selection(self, model_path: str):
    """
    Select the preset combo entry that corresponds to ``model_path``.

    An entry matches when the stripped path equals its item data or ends
    with that value preceded by ``/`` or ``\\``. If nothing matches, the
    first entry (index 0) is selected. The combo box signals are blocked
    while the index is changed.

    Args:
        model_path: Base-model path or file name to look up. A falsy value
            is treated as an empty string.
    """
    # The combo box may not have been created yet.
    if not hasattr(self, "base_model_combo"):
        return

    combo = self.base_model_combo
    candidate = (model_path or "").strip()

    match_index = 0
    # Index 0 is skipped: it is the fallback used when no preset matches.
    for position in range(1, combo.count()):
        preset = combo.itemData(position)
        if not preset:
            continue
        if candidate == preset or candidate.endswith(
            (f"/{preset}", f"\\{preset}")
        ):
            match_index = position
            break

    combo.blockSignals(True)
    combo.setCurrentIndex(match_index)
    combo.blockSignals(False)
|
||||||
|
|
||||||
def _get_dataset_search_roots(self) -> List[Path]:
|
def _get_dataset_search_roots(self) -> List[Path]:
|
||||||
roots: List[Path] = []
|
roots: List[Path] = []
|
||||||
default_root = Path("data/datasets").expanduser()
|
default_root = Path("data/datasets").expanduser()
|
||||||
@@ -346,6 +491,7 @@ class TrainingTab(QWidget):
|
|||||||
for yaml_path in root.rglob("*.yaml"):
|
for yaml_path in root.rglob("*.yaml"):
|
||||||
if yaml_path.name.lower() not in {"data.yaml", "dataset.yaml"}:
|
if yaml_path.name.lower() not in {"data.yaml", "dataset.yaml"}:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
discovered.append(yaml_path.resolve())
|
discovered.append(yaml_path.resolve())
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(f"Unable to scan {root}: {exc}")
|
logger.warning(f"Unable to scan {root}: {exc}")
|
||||||
@@ -964,6 +1110,66 @@ class TrainingTab(QWidget):
|
|||||||
self._build_rgb_dataset(cache_root, dataset_info)
|
self._build_rgb_dataset(cache_root, dataset_info)
|
||||||
return rgb_yaml
|
return rgb_yaml
|
||||||
|
|
||||||
|
def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Translate collected training parameters into an ordered list of stages.

    With two-stage training disabled (or not configured) the plan holds one
    "Single Stage" entry built from the global hyperparameters. With it
    enabled the plan holds a stage 1 entry followed by a stage 2 entry;
    each takes ``epochs``, ``patience`` and ``lr0`` from its own section and
    falls back to the global value for any key that is absent. Only stage 1
    uses a ``freeze`` value from its section; stage 2 is marked
    ``use_previous_best``.

    Args:
        params: Training parameters. Must contain ``base_model``,
            ``epochs``, ``batch``, ``imgsz``, ``patience``, ``lr0`` and
            ``run_name``. The optional ``two_stage`` mapping may contain
            ``enabled``, ``stage1`` and ``stage2``; any of these may be
            missing or None.

    Returns:
        A list of stage dictionaries with the keys ``label``,
        ``model_path``, ``use_previous_best`` and ``params``.

    Raises:
        KeyError: If one of the required keys is missing from ``params``.
    """

    def build_stage(
        label: str,
        run_name: str,
        overrides: Dict[str, Any],
        freeze: Any,
        use_previous_best: bool,
    ) -> Dict[str, Any]:
        # Single place that defines the stage layout, so all stages share
        # exactly the same keys.
        return {
            "label": label,
            "model_path": params["base_model"],
            "use_previous_best": use_previous_best,
            "params": {
                "epochs": overrides.get("epochs", params["epochs"]),
                "batch": params["batch"],
                "imgsz": params["imgsz"],
                "patience": overrides.get("patience", params["patience"]),
                "lr0": overrides.get("lr0", params["lr0"]),
                "freeze": freeze,
                "name": run_name,
            },
        }

    run_name = params["run_name"]
    two_stage = params.get("two_stage") or {}

    if not two_stage.get("enabled"):
        return [build_stage("Single Stage", run_name, {}, 0, False)]

    # ``or {}`` also covers a section that is present but None, which the
    # ``dict.get`` default alone would let through.
    stage1 = two_stage.get("stage1") or {}
    stage2 = two_stage.get("stage2") or {}

    return [
        build_stage(
            "Stage 1 — Head-only",
            f"{run_name}_head_ft",
            stage1,
            stage1.get("freeze", 0),
            False,
        ),
        build_stage("Stage 2 — Full", f"{run_name}_full_ft", stage2, 0, True),
    ]
|
||||||
|
|
||||||
def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
|
def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
|
||||||
cache_base = Path("data/datasets/_rgb_cache")
|
cache_base = Path("data/datasets/_rgb_cache")
|
||||||
cache_base.mkdir(parents=True, exist_ok=True)
|
cache_base.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -1085,6 +1291,21 @@ class TrainingTab(QWidget):
|
|||||||
save_dir_path.mkdir(parents=True, exist_ok=True)
|
save_dir_path.mkdir(parents=True, exist_ok=True)
|
||||||
run_name = f"{model_name}_{model_version}".replace(" ", "_")
|
run_name = f"{model_name}_{model_version}".replace(" ", "_")
|
||||||
|
|
||||||
|
two_stage_config = {
|
||||||
|
"enabled": self.two_stage_checkbox.isChecked(),
|
||||||
|
"stage1": {
|
||||||
|
"epochs": self.stage1_epochs_spin.value(),
|
||||||
|
"lr0": self.stage1_lr_spin.value(),
|
||||||
|
"patience": self.stage1_patience_spin.value(),
|
||||||
|
"freeze": self.stage1_freeze_spin.value(),
|
||||||
|
},
|
||||||
|
"stage2": {
|
||||||
|
"epochs": self.stage2_epochs_spin.value(),
|
||||||
|
"lr0": self.stage2_lr_spin.value(),
|
||||||
|
"patience": self.stage2_patience_spin.value(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model_name": model_name,
|
"model_name": model_name,
|
||||||
"model_version": model_version,
|
"model_version": model_version,
|
||||||
@@ -1096,6 +1317,7 @@ class TrainingTab(QWidget):
|
|||||||
"imgsz": self.imgsz_spin.value(),
|
"imgsz": self.imgsz_spin.value(),
|
||||||
"patience": self.patience_spin.value(),
|
"patience": self.patience_spin.value(),
|
||||||
"lr0": self.lr_spin.value(),
|
"lr0": self.lr_spin.value(),
|
||||||
|
"two_stage": two_stage_config,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _start_training(self):
|
def _start_training(self):
|
||||||
@@ -1315,6 +1537,7 @@ class TrainingTab(QWidget):
|
|||||||
self.rescan_button.setEnabled(not is_training)
|
self.rescan_button.setEnabled(not is_training)
|
||||||
self.model_name_edit.setEnabled(not is_training)
|
self.model_name_edit.setEnabled(not is_training)
|
||||||
self.model_version_edit.setEnabled(not is_training)
|
self.model_version_edit.setEnabled(not is_training)
|
||||||
|
self.base_model_combo.setEnabled(not is_training)
|
||||||
self.base_model_edit.setEnabled(not is_training)
|
self.base_model_edit.setEnabled(not is_training)
|
||||||
self.base_model_browse_button.setEnabled(not is_training)
|
self.base_model_browse_button.setEnabled(not is_training)
|
||||||
self.save_dir_edit.setEnabled(not is_training)
|
self.save_dir_edit.setEnabled(not is_training)
|
||||||
@@ -1324,6 +1547,8 @@ class TrainingTab(QWidget):
|
|||||||
self.imgsz_spin.setEnabled(not is_training)
|
self.imgsz_spin.setEnabled(not is_training)
|
||||||
self.patience_spin.setEnabled(not is_training)
|
self.patience_spin.setEnabled(not is_training)
|
||||||
self.lr_spin.setEnabled(not is_training)
|
self.lr_spin.setEnabled(not is_training)
|
||||||
|
self.two_stage_checkbox.setEnabled(not is_training)
|
||||||
|
self._refresh_two_stage_controls_enabled()
|
||||||
|
|
||||||
def _append_training_log(self, message: str):
|
def _append_training_log(self, message: str):
|
||||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||||
@@ -1339,6 +1564,7 @@ class TrainingTab(QWidget):
|
|||||||
)
|
)
|
||||||
if file_path:
|
if file_path:
|
||||||
self.base_model_edit.setText(file_path)
|
self.base_model_edit.setText(file_path)
|
||||||
|
self._sync_base_model_preset_selection(file_path)
|
||||||
|
|
||||||
def _browse_save_dir(self):
|
def _browse_save_dir(self):
|
||||||
start_path = self.save_dir_edit.text().strip() or "data/models"
|
start_path = self.save_dir_edit.text().strip() or "data/models"
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from PySide6.QtGui import (
|
|||||||
QMouseEvent,
|
QMouseEvent,
|
||||||
QPaintEvent,
|
QPaintEvent,
|
||||||
)
|
)
|
||||||
from PySide6.QtCore import Qt, QEvent, Signal, QPoint
|
from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QRect
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from src.utils.image import Image, ImageLoadError
|
from src.utils.image import Image, ImageLoadError
|
||||||
@@ -529,6 +529,40 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
painter.setPen(pen)
|
painter.setPen(pen)
|
||||||
painter.drawRect(x_min, y_min, rect_width, rect_height)
|
painter.drawRect(x_min, y_min, rect_width, rect_height)
|
||||||
|
|
||||||
|
label_text = meta.get("label")
|
||||||
|
if label_text:
|
||||||
|
painter.save()
|
||||||
|
font = painter.font()
|
||||||
|
font.setPointSizeF(max(10.0, width + 4))
|
||||||
|
painter.setFont(font)
|
||||||
|
metrics = painter.fontMetrics()
|
||||||
|
text_width = metrics.horizontalAdvance(label_text)
|
||||||
|
text_height = metrics.height()
|
||||||
|
padding = 4
|
||||||
|
bg_width = text_width + padding * 2
|
||||||
|
bg_height = text_height + padding * 2
|
||||||
|
canvas_width = self.original_pixmap.width()
|
||||||
|
canvas_height = self.original_pixmap.height()
|
||||||
|
bg_x = max(0, min(x_min, canvas_width - bg_width))
|
||||||
|
bg_y = y_min - bg_height
|
||||||
|
if bg_y < 0:
|
||||||
|
bg_y = min(y_min, canvas_height - bg_height)
|
||||||
|
bg_y = max(0, bg_y)
|
||||||
|
background_rect = QRect(bg_x, bg_y, bg_width, bg_height)
|
||||||
|
background_color = QColor(pen_color)
|
||||||
|
background_color.setAlpha(220)
|
||||||
|
painter.fillRect(background_rect, background_color)
|
||||||
|
text_color = QColor(0, 0, 0)
|
||||||
|
if background_color.lightness() < 128:
|
||||||
|
text_color = QColor(255, 255, 255)
|
||||||
|
painter.setPen(text_color)
|
||||||
|
painter.drawText(
|
||||||
|
background_rect.adjusted(padding, padding, -padding, -padding),
|
||||||
|
Qt.AlignLeft | Qt.AlignVCenter,
|
||||||
|
label_text,
|
||||||
|
)
|
||||||
|
painter.restore()
|
||||||
|
|
||||||
painter.end()
|
painter.end()
|
||||||
|
|
||||||
self._update_display()
|
self._update_display()
|
||||||
@@ -787,7 +821,13 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
f"Drew saved polyline with {len(polyline)} points in color {color}"
|
f"Drew saved polyline with {len(polyline)} points in color {color}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def draw_saved_bbox(self, bbox: List[float], color: str, width: int = 3):
|
def draw_saved_bbox(
|
||||||
|
self,
|
||||||
|
bbox: List[float],
|
||||||
|
color: str,
|
||||||
|
width: int = 3,
|
||||||
|
label: Optional[str] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Draw a bounding box from database coordinates onto the annotation canvas.
|
Draw a bounding box from database coordinates onto the annotation canvas.
|
||||||
|
|
||||||
@@ -796,6 +836,7 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
in normalized coordinates (0-1)
|
in normalized coordinates (0-1)
|
||||||
color: Color hex string (e.g., '#FF0000')
|
color: Color hex string (e.g., '#FF0000')
|
||||||
width: Line width in pixels
|
width: Line width in pixels
|
||||||
|
label: Optional text label to render near the bounding box
|
||||||
"""
|
"""
|
||||||
if not self.annotation_pixmap or not self.original_pixmap:
|
if not self.annotation_pixmap or not self.original_pixmap:
|
||||||
logger.warning("Cannot draw bounding box: no image loaded")
|
logger.warning("Cannot draw bounding box: no image loaded")
|
||||||
@@ -828,11 +869,11 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
self.bboxes.append(
|
self.bboxes.append(
|
||||||
[float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)]
|
[float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)]
|
||||||
)
|
)
|
||||||
self.bbox_meta.append({"color": pen_color, "width": int(width)})
|
self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label})
|
||||||
|
|
||||||
# Store in all_strokes for consistency
|
# Store in all_strokes for consistency
|
||||||
self.all_strokes.append(
|
self.all_strokes.append(
|
||||||
{"bbox": bbox, "color": color, "alpha": 128, "width": width}
|
{"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Redraw overlay (polylines + all bounding boxes)
|
# Redraw overlay (polylines + all bounding boxes)
|
||||||
|
|||||||
@@ -58,6 +58,10 @@ class ConfigManager:
|
|||||||
"models": {
|
"models": {
|
||||||
"default_base_model": "yolov8s-seg.pt",
|
"default_base_model": "yolov8s-seg.pt",
|
||||||
"models_directory": "data/models",
|
"models_directory": "data/models",
|
||||||
|
"base_model_choices": [
|
||||||
|
"yolov8s-seg.pt",
|
||||||
|
"yolov11s-seg.pt",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
"training": {
|
"training": {
|
||||||
"default_epochs": 100,
|
"default_epochs": 100,
|
||||||
@@ -65,6 +69,20 @@ class ConfigManager:
|
|||||||
"default_imgsz": 640,
|
"default_imgsz": 640,
|
||||||
"default_patience": 50,
|
"default_patience": 50,
|
||||||
"default_lr0": 0.01,
|
"default_lr0": 0.01,
|
||||||
|
"two_stage": {
|
||||||
|
"enabled": False,
|
||||||
|
"stage1": {
|
||||||
|
"epochs": 20,
|
||||||
|
"lr0": 0.0005,
|
||||||
|
"patience": 10,
|
||||||
|
"freeze": 10,
|
||||||
|
},
|
||||||
|
"stage2": {
|
||||||
|
"epochs": 150,
|
||||||
|
"lr0": 0.0003,
|
||||||
|
"patience": 30,
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"detection": {
|
"detection": {
|
||||||
"default_confidence": 0.25,
|
"default_confidence": 0.25,
|
||||||
|
|||||||
Reference in New Issue
Block a user