# object-segmentation/src/gui/tabs/training_tab.py
"""
Training tab for the microscopy object detection application.
Handles model training with YOLO.
"""
import hashlib
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import yaml
import numpy as np
from PySide6.QtCore import Qt, QThread, Signal
from PySide6.QtWidgets import (
QWidget,
QVBoxLayout,
QLabel,
QGroupBox,
QHBoxLayout,
QPushButton,
QLineEdit,
QFileDialog,
QComboBox,
QFormLayout,
QMessageBox,
QTextEdit,
QProgressBar,
QSpinBox,
QDoubleSpinBox,
QCheckBox,
QScrollArea,
)
from src.database.db_manager import DatabaseManager
from src.model.yolo_wrapper import YOLOWrapper
from src.utils.config_manager import ConfigManager
from src.utils.image import Image
from src.utils.logger import get_logger
logger = get_logger(__name__)
DEFAULT_IMAGE_EXTENSIONS = set(Image.SUPPORTED_EXTENSIONS)
class TrainingWorker(QThread):
"""Background worker that runs YOLO training without blocking the UI."""
progress = Signal(int, int, dict) # current_epoch, total_epochs, metrics
finished = Signal(dict)
error = Signal(str)
def __init__(
self,
data_yaml: str,
base_model: str,
epochs: int,
batch: int,
imgsz: int,
patience: int,
lr0: float,
save_dir: str,
run_name: str,
parent: Optional[QThread] = None,
stage_plan: Optional[List[Dict[str, Any]]] = None,
total_epochs: Optional[int] = None,
):
super().__init__(parent)
self.data_yaml = data_yaml
self.base_model = base_model
self.epochs = epochs
self.batch = batch
self.imgsz = imgsz
self.patience = patience
self.lr0 = lr0
self.save_dir = save_dir
self.run_name = run_name
self.stage_plan = stage_plan or [
{
"label": "Single Stage",
"model_path": base_model,
"use_previous_best": False,
"params": {
"epochs": epochs,
"batch": batch,
"imgsz": imgsz,
"patience": patience,
"lr0": lr0,
"freeze": 0,
"name": run_name,
},
}
]
computed_total = sum(max(0, int((stage.get("params") or {}).get("epochs", 0))) for stage in self.stage_plan)
self.total_epochs = total_epochs if total_epochs else computed_total or epochs
self._stop_requested = False
def stop(self):
"""Request the training process to stop at the next epoch boundary."""
self._stop_requested = True
self.requestInterruption()
def run(self):
"""Execute YOLO training over one or more stages and emit progress/finished signals."""
completed_epochs = 0
stage_history: List[Dict[str, Any]] = []
last_stage_results: Optional[Dict[str, Any]] = None
for stage_index, stage in enumerate(self.stage_plan, start=1):
if self._stop_requested or self.isInterruptionRequested():
break
stage_label = stage.get("label") or f"Stage {stage_index}"
stage_params = dict(stage.get("params") or {})
stage_epochs = int(stage_params.get("epochs", self.epochs))
if stage_epochs <= 0:
stage_epochs = 1
batch = int(stage_params.get("batch", self.batch))
imgsz = int(stage_params.get("imgsz", self.imgsz))
patience = int(stage_params.get("patience", self.patience))
lr0 = float(stage_params.get("lr0", self.lr0))
freeze = int(stage_params.get("freeze", 0))
run_name = stage_params.get("name") or f"{self.run_name}_stage{stage_index}"
weights_path = stage.get("model_path") or self.base_model
if stage.get("use_previous_best") and last_stage_results:
weights_path = (
last_stage_results.get("best_model_path")
or last_stage_results.get("last_model_path")
or weights_path
)
wrapper = YOLOWrapper(weights_path)
stage_offset = completed_epochs
def on_epoch_end(trainer, offset=stage_offset):
current_epoch = getattr(trainer, "epoch", 0) + 1
metrics: Dict[str, float] = {}
loss_items = getattr(trainer, "loss_items", None)
if loss_items:
metrics["loss"] = float(loss_items[-1])
absolute_epoch = min(
max(1, offset + current_epoch),
max(1, self.total_epochs),
)
self.progress.emit(absolute_epoch, self.total_epochs, metrics)
if self.isInterruptionRequested() or self._stop_requested:
setattr(trainer, "stop_training", True)
callbacks = {"on_fit_epoch_end": on_epoch_end}
try:
stage_result = wrapper.train(
data_yaml=self.data_yaml,
epochs=stage_epochs,
imgsz=imgsz,
batch=batch,
patience=patience,
save_dir=self.save_dir,
name=run_name,
lr0=lr0,
callbacks=callbacks,
freeze=freeze,
)
except Exception as exc:
self.error.emit(str(exc))
return
stage_history.append(
{
"label": stage_label,
"params": stage_params,
"weights_used": weights_path,
"results": stage_result,
}
)
last_stage_results = stage_result
completed_epochs += stage_epochs
final_payload: Dict[str, Any]
if last_stage_results:
final_payload = dict(last_stage_results)
else:
final_payload = {
"success": False,
"message": "Training stopped before any stage completed.",
}
final_payload["stage_results"] = stage_history
final_payload["total_epochs_completed"] = completed_epochs
final_payload["total_epochs_planned"] = self.total_epochs
final_payload["stages_completed"] = len(stage_history)
self.finished.emit(final_payload)
class TrainingTab(QWidget):
"""Training tab for model training."""
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
super().__init__(parent)
self.db_manager = db_manager
self.config_manager = config_manager
self.selected_dataset: Optional[Dict[str, Any]] = None
self.allowed_extensions = {
ext.lower() for ext in self.config_manager.get_allowed_extensions()
} or DEFAULT_IMAGE_EXTENSIONS
self._status_styles = {
"default": "",
"success": "color: #2e7d32; font-weight: 600;",
"warning": "color: #b26a00; font-weight: 600;",
"error": "color: #b00020; font-weight: 600;",
}
self.training_worker: Optional[TrainingWorker] = None
self._active_training_params: Optional[Dict[str, Any]] = None
self._training_cancelled = False
self._setup_ui()
def _setup_ui(self):
"""Setup user interface."""
# Create a container widget for all content
container = QWidget()
container_layout = QVBoxLayout(container)
container_layout.addWidget(self._create_dataset_group())
container_layout.addWidget(self._create_training_controls_group())
container_layout.addStretch()
# Create scroll area and set the container as its widget
scroll_area = QScrollArea()
scroll_area.setWidget(container)
scroll_area.setWidgetResizable(True)
# Set main layout with scroll area
main_layout = QVBoxLayout(self)
main_layout.setContentsMargins(0, 0, 0, 0)
main_layout.addWidget(scroll_area)
self._discover_datasets()
self._load_saved_dataset()
def _create_dataset_group(self) -> QGroupBox:
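        """Build the dataset selection group: discovery combo, manual browse, data.yaml generation, and a summary panel."""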
group = QGroupBox("Dataset Selection")
group_layout = QVBoxLayout()
description = QLabel(
"Select a YOLO-format data.yaml file to power model training. "
"You can pick from discovered datasets or browse manually."
)
description.setWordWrap(True)
group_layout.addWidget(description)
combo_layout = QHBoxLayout()
combo_label = QLabel("Discovered datasets:")
combo_layout.addWidget(combo_label)
self.dataset_combo = QComboBox()
self.dataset_combo.currentIndexChanged.connect(self._on_dataset_combo_changed)
combo_layout.addWidget(self.dataset_combo, 1)
self.rescan_button = QPushButton("Rescan")
self.rescan_button.clicked.connect(self._discover_datasets)
combo_layout.addWidget(self.rescan_button)
group_layout.addLayout(combo_layout)
path_layout = QHBoxLayout()
self.dataset_path_edit = QLineEdit()
self.dataset_path_edit.setPlaceholderText("Select a data.yaml file to continue")
self.dataset_path_edit.setReadOnly(True)
path_layout.addWidget(self.dataset_path_edit)
self.browse_button = QPushButton("Browse…")
self.browse_button.clicked.connect(self._browse_dataset)
path_layout.addWidget(self.browse_button)
group_layout.addLayout(path_layout)
action_layout = QHBoxLayout()
action_layout.addStretch()
self.generate_yaml_button = QPushButton("Generate data.yaml")
self.generate_yaml_button.clicked.connect(self._generate_data_yaml)
action_layout.addWidget(self.generate_yaml_button)
group_layout.addLayout(action_layout)
self.dataset_status_label = QLabel("No dataset selected.")
self.dataset_status_label.setWordWrap(True)
group_layout.addWidget(self.dataset_status_label)
summary_group = QGroupBox("Dataset Summary")
summary_layout = QFormLayout()
self.dataset_root_label = QLabel("")
self.dataset_root_label.setWordWrap(True)
self.dataset_root_label.setTextInteractionFlags(Qt.TextSelectableByMouse)
summary_layout.addRow("Root:", self.dataset_root_label)
self.train_count_label = QLabel("")
summary_layout.addRow("Train Images:", self.train_count_label)
self.val_count_label = QLabel("")
summary_layout.addRow("Val Images:", self.val_count_label)
self.test_count_label = QLabel("")
summary_layout.addRow("Test Images:", self.test_count_label)
self.num_classes_label = QLabel("")
summary_layout.addRow("Classes:", self.num_classes_label)
self.class_names_label = QLabel("")
self.class_names_label.setWordWrap(True)
summary_layout.addRow("Class Names:", self.class_names_label)
summary_group.setLayout(summary_layout)
group_layout.addWidget(summary_group)
group.setLayout(group_layout)
return group
def _create_training_controls_group(self) -> QGroupBox:
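        """Build the training controls: model naming, base weights, save directory, core hyperparameters, start/stop buttons, progress bar, and log view."""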
group = QGroupBox("Training Controls")
group_layout = QVBoxLayout()
form_layout = QFormLayout()
self.model_name_edit = QLineEdit("custom_model")
form_layout.addRow("Model Name:", self.model_name_edit)
self.model_version_edit = QLineEdit("v1")
form_layout.addRow("Version:", self.model_version_edit)
default_base_model = self.config_manager.get("models.default_base_model", "yolov8s-seg.pt")
base_model_choices = self.config_manager.get("models.base_model_choices", [])
self.base_model_combo = QComboBox()
self.base_model_combo.addItem("Custom path…", "")
for choice in base_model_choices:
self.base_model_combo.addItem(choice, choice)
self.base_model_combo.currentIndexChanged.connect(self._on_base_model_preset_changed)
form_layout.addRow("Base Model Preset:", self.base_model_combo)
base_model_layout = QHBoxLayout()
self.base_model_edit = QLineEdit(default_base_model)
self.base_model_edit.editingFinished.connect(self._on_base_model_path_edited)
base_model_layout.addWidget(self.base_model_edit)
self.base_model_browse_button = QPushButton("Browse…")
self.base_model_browse_button.clicked.connect(self._browse_base_model)
base_model_layout.addWidget(self.base_model_browse_button)
form_layout.addRow("Base Model (.pt):", base_model_layout)
self._sync_base_model_preset_selection(default_base_model)
models_dir = self.config_manager.get("models.models_directory", "data/models")
save_dir_layout = QHBoxLayout()
self.save_dir_edit = QLineEdit(models_dir)
save_dir_layout.addWidget(self.save_dir_edit)
self.save_dir_browse_button = QPushButton("Browse…")
self.save_dir_browse_button.clicked.connect(self._browse_save_dir)
save_dir_layout.addWidget(self.save_dir_browse_button)
form_layout.addRow("Save Directory:", save_dir_layout)
training_defaults = self.config_manager.get_section("training")
self.epochs_spin = QSpinBox()
self.epochs_spin.setRange(1, 1000)
self.epochs_spin.setValue(int(training_defaults.get("default_epochs", 10)))
form_layout.addRow("Epochs:", self.epochs_spin)
self.batch_spin = QSpinBox()
self.batch_spin.setRange(1, 256)
self.batch_spin.setValue(int(training_defaults.get("default_batch_size", 16)))
form_layout.addRow("Batch Size:", self.batch_spin)
self.imgsz_spin = QSpinBox()
self.imgsz_spin.setRange(320, 2048)
self.imgsz_spin.setSingleStep(32)
self.imgsz_spin.setValue(int(training_defaults.get("default_imgsz", 640)))
form_layout.addRow("Image Size:", self.imgsz_spin)
self.patience_spin = QSpinBox()
self.patience_spin.setRange(1, 200)
self.patience_spin.setValue(int(training_defaults.get("default_patience", 50)))
form_layout.addRow("Patience:", self.patience_spin)
self.lr_spin = QDoubleSpinBox()
self.lr_spin.setDecimals(5)
self.lr_spin.setRange(0.00001, 0.1)
self.lr_spin.setSingleStep(0.0005)
self.lr_spin.setValue(float(training_defaults.get("default_lr0", 0.01)))
form_layout.addRow("Learning Rate (lr0):", self.lr_spin)
group_layout.addLayout(form_layout)
self.two_stage_group = self._create_two_stage_group(training_defaults)
group_layout.addWidget(self.two_stage_group)
button_layout = QHBoxLayout()
self.start_training_button = QPushButton("Start Training")
self.start_training_button.clicked.connect(self._start_training)
button_layout.addWidget(self.start_training_button)
self.stop_training_button = QPushButton("Stop")
self.stop_training_button.setEnabled(False)
self.stop_training_button.clicked.connect(self._stop_training)
button_layout.addWidget(self.stop_training_button)
button_layout.addStretch()
group_layout.addLayout(button_layout)
self.training_progress_bar = QProgressBar()
self.training_progress_bar.setVisible(False)
group_layout.addWidget(self.training_progress_bar)
self.training_log = QTextEdit()
self.training_log.setReadOnly(True)
self.training_log.setMaximumHeight(200)
group_layout.addWidget(self.training_log)
group.setLayout(group_layout)
return group
def _create_two_stage_group(self, training_defaults: Dict[str, Any]) -> QGroupBox:
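        """Build the optional two-stage fine-tuning controls (Stage 1 head-only with frozen backbone layers, Stage 2 full fine-tune)."""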
group = QGroupBox("Two-Stage Fine-Tuning")
group_layout = QVBoxLayout()
self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
two_stage_defaults = training_defaults.get("two_stage", {}) if training_defaults else {}
self.two_stage_checkbox.setChecked(bool(two_stage_defaults.get("enabled", False)))
self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
group_layout.addWidget(self.two_stage_checkbox)
self.two_stage_controls_widget = QWidget()
controls_layout = QVBoxLayout()
controls_layout.setContentsMargins(0, 0, 0, 0)
controls_layout.setSpacing(8)
stage1_group = QGroupBox("Stage 1 — Head-only stabilization")
stage1_form = QFormLayout()
stage1_defaults = two_stage_defaults.get("stage1", {})
self.stage1_epochs_spin = QSpinBox()
self.stage1_epochs_spin.setRange(1, 500)
self.stage1_epochs_spin.setValue(int(stage1_defaults.get("epochs", 20)))
stage1_form.addRow("Epochs:", self.stage1_epochs_spin)
self.stage1_lr_spin = QDoubleSpinBox()
self.stage1_lr_spin.setDecimals(5)
self.stage1_lr_spin.setRange(0.00001, 0.1)
self.stage1_lr_spin.setSingleStep(0.0005)
self.stage1_lr_spin.setValue(float(stage1_defaults.get("lr0", 0.0005)))
stage1_form.addRow("Learning Rate:", self.stage1_lr_spin)
self.stage1_patience_spin = QSpinBox()
self.stage1_patience_spin.setRange(1, 200)
self.stage1_patience_spin.setValue(int(stage1_defaults.get("patience", 10)))
stage1_form.addRow("Patience:", self.stage1_patience_spin)
self.stage1_freeze_spin = QSpinBox()
self.stage1_freeze_spin.setRange(0, 24)
self.stage1_freeze_spin.setValue(int(stage1_defaults.get("freeze", 10)))
stage1_form.addRow("Freeze layers:", self.stage1_freeze_spin)
stage1_group.setLayout(stage1_form)
controls_layout.addWidget(stage1_group)
stage2_group = QGroupBox("Stage 2 — Full fine-tuning")
stage2_form = QFormLayout()
stage2_defaults = two_stage_defaults.get("stage2", {})
self.stage2_epochs_spin = QSpinBox()
self.stage2_epochs_spin.setRange(1, 2000)
self.stage2_epochs_spin.setValue(int(stage2_defaults.get("epochs", 150)))
stage2_form.addRow("Epochs:", self.stage2_epochs_spin)
self.stage2_lr_spin = QDoubleSpinBox()
self.stage2_lr_spin.setDecimals(5)
self.stage2_lr_spin.setRange(0.00001, 0.1)
self.stage2_lr_spin.setSingleStep(0.0005)
self.stage2_lr_spin.setValue(float(stage2_defaults.get("lr0", 0.0003)))
stage2_form.addRow("Learning Rate:", self.stage2_lr_spin)
self.stage2_patience_spin = QSpinBox()
self.stage2_patience_spin.setRange(1, 200)
self.stage2_patience_spin.setValue(int(stage2_defaults.get("patience", 30)))
stage2_form.addRow("Patience:", self.stage2_patience_spin)
stage2_group.setLayout(stage2_form)
controls_layout.addWidget(stage2_group)
helper_label = QLabel("When enabled, staged hyperparameters override the global epochs/patience/lr.")
helper_label.setWordWrap(True)
controls_layout.addWidget(helper_label)
self.two_stage_controls_widget.setLayout(controls_layout)
group_layout.addWidget(self.two_stage_controls_widget)
group.setLayout(group_layout)
self._on_two_stage_toggled(self.two_stage_checkbox.isChecked())
return group
def _on_two_stage_toggled(self, checked: bool):
self._refresh_two_stage_controls_enabled(checked)
def _refresh_two_stage_controls_enabled(self, checked: Optional[bool] = None):
if not hasattr(self, "two_stage_controls_widget"):
return
desired_state = checked
if desired_state is None:
desired_state = self.two_stage_checkbox.isChecked()
can_edit = self.two_stage_checkbox.isEnabled()
self.two_stage_controls_widget.setEnabled(bool(desired_state and can_edit))
def _on_base_model_preset_changed(self, index: int):
preset_value = self.base_model_combo.itemData(index)
if preset_value:
self.base_model_edit.setText(str(preset_value))
elif index == 0:
self.base_model_edit.setFocus()
def _on_base_model_path_edited(self):
self._sync_base_model_preset_selection(self.base_model_edit.text().strip())
def _sync_base_model_preset_selection(self, model_path: str):
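        """Point the preset combo at the entry matching model_path (exact value or trailing filename); fall back to the custom entry."""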
if not hasattr(self, "base_model_combo"):
return
normalized = (model_path or "").strip()
target_index = 0
for idx in range(1, self.base_model_combo.count()):
preset_value = self.base_model_combo.itemData(idx)
if not preset_value:
continue
if normalized == preset_value:
target_index = idx
break
if normalized.endswith(f"/{preset_value}") or normalized.endswith(f"\\{preset_value}"):
target_index = idx
break
self.base_model_combo.blockSignals(True)
self.base_model_combo.setCurrentIndex(target_index)
self.base_model_combo.blockSignals(False)
def _get_dataset_search_roots(self) -> List[Path]:
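        """Return existing directories to scan for datasets: data/datasets plus the configured image repository."""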
roots: List[Path] = []
default_root = Path("data/datasets").expanduser()
if default_root.exists():
roots.append(default_root.resolve())
repo_root = self.config_manager.get_image_repository_path()
if repo_root:
repo_path = Path(repo_root).expanduser()
if repo_path.exists():
resolved = repo_path.resolve()
if resolved not in roots:
roots.append(resolved)
return roots
def _discover_datasets(self):
"""Populate combo box with discovered data.yaml files."""
discovered: List[Path] = []
for root in self._get_dataset_search_roots():
try:
for yaml_path in root.rglob("*.yaml"):
if yaml_path.name.lower() not in {"data.yaml", "dataset.yaml"}:
continue
discovered.append(yaml_path.resolve())
except Exception as exc:
logger.warning(f"Unable to scan {root}: {exc}")
unique_paths: List[Path] = []
seen: set[str] = set()
for path in sorted(discovered):
normalized = str(path)
if normalized not in seen:
seen.add(normalized)
unique_paths.append(path)
current_selection = self.dataset_path_edit.text().strip()
self.dataset_combo.blockSignals(True)
self.dataset_combo.clear()
self.dataset_combo.addItem("Select discovered dataset…", None)
for path in unique_paths:
display_name = f"{path.parent.name} ({path})"
self.dataset_combo.addItem(display_name, str(path))
self.dataset_combo.setEnabled(bool(unique_paths))
self.dataset_combo.blockSignals(False)
if current_selection:
self._select_combo_entry_for_path(current_selection)
def _select_combo_entry_for_path(self, target_path: str):
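        """Select the combo entry whose resolved path matches target_path, or reset to the placeholder."""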
try:
normalized_target = str(Path(target_path).expanduser().resolve())
except Exception:
normalized_target = str(target_path)
for idx in range(1, self.dataset_combo.count()):
data = self.dataset_combo.itemData(idx)
if not data:
continue
try:
normalized_item = str(Path(data).expanduser().resolve())
except Exception:
normalized_item = str(data)
if normalized_item == normalized_target:
self.dataset_combo.setCurrentIndex(idx)
return
self.dataset_combo.setCurrentIndex(0)
def _on_dataset_combo_changed(self, index: int):
dataset_path = self.dataset_combo.itemData(index)
if dataset_path:
self._set_dataset_path(dataset_path)
def _browse_dataset(self):
"""Open a file dialog to manually select data.yaml."""
start_dir = self.config_manager.get("training.last_dataset_dir", "data/datasets")
start_path = Path(start_dir).expanduser()
if not start_path.exists():
start_path = Path.cwd()
file_path, _ = QFileDialog.getOpenFileName(
self,
"Select YOLO data.yaml",
str(start_path),
"YAML Files (*.yaml *.yml)",
)
if file_path:
self._set_dataset_path(file_path)
def _generate_data_yaml(self):
"""Compose a data.yaml file using database metadata and dataset folders."""
dataset_root = self._determine_dataset_root()
if dataset_root is None:
QMessageBox.warning(
self,
"Dataset Root Missing",
"Unable to locate a dataset root. Please create data/datasets or browse to an existing dataset first.",
)
return
self.generate_yaml_button.setEnabled(False)
try:
output_path = self.db_manager.compose_data_yaml(str(dataset_root))
except ValueError as exc:
logger.error(f"data.yaml generation failed: {exc}")
self._display_dataset_error(str(exc))
QMessageBox.critical(self, "data.yaml Generation Failed", str(exc))
return
except Exception as exc:
logger.exception("Unexpected error while generating data.yaml")
self._display_dataset_error("Unexpected error while generating data.yaml. Check logs for details.")
QMessageBox.critical(
self,
"data.yaml Generation Failed",
"An unexpected error occurred. Please inspect the logs for details.",
)
return
finally:
self.generate_yaml_button.setEnabled(True)
QMessageBox.information(
self,
"data.yaml Generated",
f"data.yaml saved to:\n{output_path}",
)
self._set_dataset_path(output_path)
self._display_dataset_success(f"Generated data.yaml at {output_path}")
def _determine_dataset_root(self) -> Optional[Path]:
"""Infer the dataset root directory for generating data.yaml."""
path_text = self.dataset_path_edit.text().strip()
if path_text:
candidate = Path(path_text).expanduser().parent
if candidate.exists():
return candidate
last_dir = self.config_manager.get("training.last_dataset_dir", "")
if last_dir:
candidate = Path(last_dir).expanduser()
if candidate.exists():
return candidate
default_root = Path("data/datasets").expanduser()
if default_root.exists():
return default_root
return None
def _set_dataset_path(self, path: str, persist: bool = True):
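        """Validate a data.yaml selection, refresh the summary, and optionally persist the choice to the config."""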
path_obj = Path(path).expanduser()
if not path_obj.exists():
self._display_dataset_error(f"File not found: {path_obj}")
return
if path_obj.suffix.lower() not in {".yaml", ".yml"}:
self._display_dataset_error("Please select a YAML file.")
return
self.dataset_path_edit.setText(str(path_obj))
try:
self._update_dataset_info(path_obj)
except ValueError as exc:
logger.error(f"Dataset parsing failed: {exc}")
self._display_dataset_error(str(exc))
QMessageBox.warning(self, "Invalid Dataset", f"{exc}")
return
if persist:
self.config_manager.set("training.last_dataset_yaml", str(path_obj))
self.config_manager.set("training.last_dataset_dir", str(path_obj.parent))
self.config_manager.save_config()
self._select_combo_entry_for_path(str(path_obj))
def _load_saved_dataset(self):
saved_path = self.config_manager.get("training.last_dataset_yaml", "")
if saved_path and Path(saved_path).expanduser().exists():
self._set_dataset_path(saved_path, persist=False)
elif self.dataset_combo.isEnabled() and self.dataset_combo.count() > 1:
self.dataset_combo.setCurrentIndex(1)
def _update_dataset_info(self, yaml_path: Path):
info = self._parse_dataset_yaml(yaml_path)
self.selected_dataset = info
self.dataset_root_label.setText(info["root"]) # type: ignore[arg-type]
self.train_count_label.setText(self._format_split_info(info["splits"].get("train")))
self.val_count_label.setText(self._format_split_info(info["splits"].get("val")))
self.test_count_label.setText(self._format_split_info(info["splits"].get("test")))
self.num_classes_label.setText(str(info["num_classes"]))
class_names = ", ".join(info["class_names"]) or ""
self.class_names_label.setText(class_names)
warnings = info.get("warnings", [])
if warnings:
warning_text = "Dataset loaded with warnings:\n- " + "\n- ".join(warnings)
self._display_dataset_warning(warning_text)
else:
self._display_dataset_success("Dataset ready for training.")
def _format_split_info(self, split: Optional[Dict[str, Any]]) -> str:
if not split:
return "n/a"
path = split.get("path") or "n/a"
count = split.get("count", 0)
return f"{count} @ {path}"
def _parse_dataset_yaml(self, yaml_path: Path) -> Dict[str, Any]:
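        """Parse a YOLO data.yaml into a dataset-info dict.

        A typical file looks like this (illustrative values):

            path: /data/datasets/my_set   # optional; defaults to the yaml's folder
            train: train/images
            val: val/images
            test: test/images             # optional
            nc: 2
            names: [cell, debris]         # list or index-keyed mapping

        Returns resolved split paths with image counts, normalized class
        names, and non-fatal warnings; raises ValueError on unreadable or
        syntactically invalid YAML.
        """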
try:
with open(yaml_path, "r", encoding="utf-8") as handle:
data = yaml.safe_load(handle) or {}
except yaml.YAMLError as exc:
raise ValueError(f"Invalid YAML syntax in {yaml_path}: {exc}") from exc
except OSError as exc:
raise ValueError(f"Unable to read {yaml_path}: {exc}") from exc
base_entry = data.get("path")
if base_entry:
base_path = Path(base_entry)
if not base_path.is_absolute():
base_path = (yaml_path.parent / base_path).resolve()
else:
base_path = base_path.resolve()
else:
base_path = yaml_path.parent.resolve()
warnings: List[str] = []
splits: Dict[str, Dict[str, Any]] = {}
for split_name in ("train", "val", "test"):
split_value = data.get(split_name)
split_info = {"path": "", "count": 0}
if split_value:
split_path = Path(split_value)
if not split_path.is_absolute():
split_path = (base_path / split_value).resolve()
else:
split_path = split_path.resolve()
split_info["path"] = str(split_path)
if split_path.exists():
split_info["count"] = self._count_images(split_path)
if split_info["count"] == 0:
warnings.append(f"No images found for {split_name} split at {split_path}")
else:
warnings.append(f"{split_name.capitalize()} path does not exist: {split_path}")
else:
if split_name in ("train", "val"):
warnings.append(f"{split_name.capitalize()} split missing in data.yaml")
splits[split_name] = split_info
names_list = self._normalize_class_names(data.get("names"))
nc_value = data.get("nc")
if nc_value is not None:
try:
nc_value = int(nc_value)
except (ValueError, TypeError):
warnings.append("Invalid 'nc' value detected; using class name count.")
nc_value = len(names_list)
else:
nc_value = len(names_list)
if not names_list and nc_value:
names_list = [f"class_{idx}" for idx in range(int(nc_value))]
elif nc_value and len(names_list) not in (0, int(nc_value)):
warnings.append(f"Number of class names ({len(names_list)}) does not match nc={nc_value}")
dataset_name = data.get("name") or base_path.name
return {
"yaml_path": str(yaml_path),
"root": str(base_path),
"name": dataset_name,
"splits": splits,
"class_names": names_list,
"num_classes": int(nc_value) if nc_value else len(names_list),
"warnings": warnings,
}
def _normalize_class_names(self, names_field: Any) -> List[str]:
if isinstance(names_field, dict):
def sort_key(item):
key, _ = item
try:
return int(key)
except (ValueError, TypeError):
return key
return [str(name) for _, name in sorted(names_field.items(), key=sort_key)]
if isinstance(names_field, (list, tuple)):
return [str(name) for name in names_field]
return []
def _count_images(self, directory: Path) -> int:
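        """Recursively count files under directory whose suffix is an allowed image extension."""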
count = 0
try:
for file_path in directory.rglob("*"):
if file_path.is_file():
suffix = file_path.suffix.lower()
if not self.allowed_extensions or suffix in self.allowed_extensions:
count += 1
except Exception as exc:
logger.warning(f"Unable to count images in {directory}: {exc}")
return count
def _export_labels_from_database(self, dataset_info: Dict[str, Any]) -> None:
"""Write YOLO txt labels for dataset splits using database annotations."""
splits = dataset_info.get("splits", {})
if not splits:
return
class_index_map = self._build_class_index_map(dataset_info)
if not class_index_map:
self._append_training_log("Skipping label export: dataset classes do not match database entries.")
return
dataset_root_str = dataset_info.get("root")
dataset_yaml_path = dataset_info.get("yaml_path")
dataset_yaml = Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
dataset_root: Optional[Path]
if dataset_root_str:
dataset_root = Path(dataset_root_str).resolve()
else:
dataset_root = dataset_yaml.parent.resolve() if dataset_yaml else None
split_messages: List[str] = []
for split_name in ("train", "val", "test"):
split_entry = splits.get(split_name) or {}
images_dir_str = split_entry.get("path")
if not images_dir_str:
continue
images_dir = Path(images_dir_str)
if not images_dir.exists():
continue
stats = self._export_labels_for_split(
split_name=split_name,
images_dir=images_dir,
dataset_root=dataset_root or images_dir.resolve(),
class_index_map=class_index_map,
)
if not stats:
continue
message = (
f"[{split_name}] Exported {stats['total_annotations']} annotations "
f"across {stats['processed_images']} image(s)."
)
if stats["registered_images"]:
message += f" {stats['registered_images']} image(s) had database-backed annotations."
if stats["missing_records"]:
message += (
f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
)
split_messages.append(message)
for msg in split_messages:
self._append_training_log(msg)
if dataset_yaml:
self._clear_rgb_cache_for_dataset(dataset_yaml)
def _export_labels_for_split(
self,
split_name: str,
images_dir: Path,
dataset_root: Path,
class_index_map: Dict[int, int],
) -> Optional[Dict[str, int]]:
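        """Write one YOLO label file per image under images_dir.

        Polygons are preferred over bboxes; images without database records
        get empty label files. Returns per-split counters, or None when the
        split contains no images.
        """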
labels_dir = self._infer_labels_dir(images_dir)
labels_dir.mkdir(parents=True, exist_ok=True)
processed_images = 0
registered_images = 0
missing_records = 0
total_annotations = 0
for image_file in images_dir.rglob("*"):
if not image_file.is_file():
continue
suffix = image_file.suffix.lower()
if self.allowed_extensions and suffix not in self.allowed_extensions:
continue
processed_images += 1
label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(".txt")
label_path.parent.mkdir(parents=True, exist_ok=True)
found, annotation_entries = self._fetch_annotations_for_image(
image_file, dataset_root, images_dir, class_index_map
)
if found:
registered_images += 1
else:
missing_records += 1
annotations_written = 0
with open(label_path, "w", encoding="utf-8") as handle:
for entry in annotation_entries:
polygon = entry.get("polygon") or []
if polygon:
                        logger.debug(f"Polygon label for {image_file.name}: {len(polygon) // 2} points")
coords = " ".join(f"{value:.6f}" for value in polygon)
handle.write(f"{entry['class_idx']} {coords}\n")
annotations_written += 1
elif entry.get("bbox"):
x_center, y_center, width, height = entry["bbox"]
handle.write(f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")
annotations_written += 1
total_annotations += annotations_written
cache_reset_root = labels_dir.parent
self._invalidate_split_cache(cache_reset_root)
if processed_images == 0:
self._append_training_log(f"[{split_name}] No images found to export labels for.")
return None
return {
"split": split_name,
"processed_images": processed_images,
"registered_images": registered_images,
"missing_records": missing_records,
"total_annotations": total_annotations,
}
def _build_class_index_map(self, dataset_info: Dict[str, Any]) -> Dict[int, int]:
class_names = dataset_info.get("class_names") or []
name_to_index = {name: idx for idx, name in enumerate(class_names)}
mapping: Dict[int, int] = {}
try:
db_classes = self.db_manager.get_object_classes()
except Exception as exc:
logger.error(f"Failed to read object classes from database: {exc}")
return mapping
for cls in db_classes:
idx = name_to_index.get(cls.get("class_name"))
if idx is not None:
mapping[cls["id"]] = idx
return mapping
def _fetch_annotations_for_image(
self,
image_path: Path,
dataset_root: Path,
images_dir: Path,
class_index_map: Dict[int, int],
) -> Tuple[bool, List[Dict[str, Any]]]:
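        """Locate the image in the database and convert its annotations to YOLO entries.

        Lookup tries paths relative to the dataset root and the split
        directory, then the bare filename. Returns (found, entries).
        """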
resolved_image = image_path.resolve()
candidates: List[str] = []
for base in (dataset_root, images_dir):
try:
relative = resolved_image.relative_to(base.resolve()).as_posix()
candidates.append(relative)
except ValueError:
continue
candidates.append(resolved_image.name)
image_row: Optional[Dict[str, Any]] = None
seen: set[str] = set()
for candidate in candidates:
normalized = candidate.replace("\\", "/")
if normalized in seen:
continue
seen.add(normalized)
image_row = self.db_manager.get_image_by_path(normalized)
if image_row:
break
if not image_row:
return False, []
annotations = self.db_manager.get_annotations_for_image(image_row["id"]) or []
yolo_entries: List[Dict[str, Any]] = []
for ann in annotations:
class_idx = class_index_map.get(ann.get("class_id"))
if class_idx is None:
continue
x_min = self._clamp01(float(ann.get("x_min", 0.0)))
y_min = self._clamp01(float(ann.get("y_min", 0.0)))
x_max = self._clamp01(float(ann.get("x_max", 0.0)))
y_max = self._clamp01(float(ann.get("y_max", 0.0)))
width = max(0.0, x_max - x_min)
height = max(0.0, y_max - y_min)
bbox_tuple: Optional[Tuple[float, float, float, float]] = None
if width > 0 and height > 0:
x_center = x_min + width / 2.0
y_center = y_min + height / 2.0
bbox_tuple = (x_center, y_center, width, height)
polygon = self._convert_segmentation_mask_to_polygon(
ann.get("segmentation_mask"), (x_min, y_min, x_max, y_max)
)
if not bbox_tuple and not polygon:
continue
yolo_entries.append(
{
"class_idx": class_idx,
"bbox": bbox_tuple,
"polygon": polygon,
}
)
return True, yolo_entries
def _convert_segmentation_mask_to_polygon(
self,
mask_data: Any,
bbox: Optional[Tuple[float, float, float, float]] = None,
) -> List[float]:
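        """Flatten a stored segmentation mask into normalized YOLO polygon coords.

        The stored point order may be (y, x) or (x, y); both interpretations
        are built, and when a bbox is available the candidate whose extent
        best matches it (lowest score) is kept.
        """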
if not isinstance(mask_data, list):
return []
candidates: List[Tuple[float, List[float]]] = []
for order in ("yx", "xy"):
coords: List[float] = []
xs: List[float] = []
ys: List[float] = []
for point in mask_data:
if not isinstance(point, (list, tuple)) or len(point) != 2:
continue
first = float(point[0])
second = float(point[1])
if order == "yx":
x_val = self._clamp01(second)
y_val = self._clamp01(first)
else:
x_val = self._clamp01(first)
y_val = self._clamp01(second)
coords.extend([x_val, y_val])
xs.append(x_val)
ys.append(y_val)
            # Close the polygon when the first and last vertices do not coincide.
            if coords and not np.allclose(coords[:2], coords[-2:], atol=1e-5):
                coords.extend(coords[:2])
if len(coords) < 6:
continue
score = 0.0
if bbox:
x_min, y_min, x_max, y_max = bbox
score = (
abs((min(xs) if xs else 0.0) - x_min)
+ abs((max(xs) if xs else 0.0) - x_max)
+ abs((min(ys) if ys else 0.0) - y_min)
+ abs((max(ys) if ys else 0.0) - y_max)
)
candidates.append((score, coords))
if not candidates:
return []
candidates.sort(key=lambda item: item[0])
return candidates[0][1]
@staticmethod
def _clamp01(value: float) -> float:
if value < 0.0:
return 0.0
if value > 1.0:
return 1.0
return value
def _prepare_dataset_for_training(self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None) -> Path:
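        """Return the data.yaml to train on, substituting a cached RGB conversion when the dataset's images are not RGB."""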
dataset_info = dataset_info or (
self.selected_dataset
if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
else self._parse_dataset_yaml(dataset_yaml)
)
train_split = dataset_info.get("splits", {}).get("train") or {}
images_path_str = train_split.get("path")
if not images_path_str:
return dataset_yaml
images_path = Path(images_path_str)
if not images_path.exists():
return dataset_yaml
if not self._dataset_requires_rgb_conversion(images_path):
return dataset_yaml
cache_root = self._get_rgb_cache_root(dataset_yaml)
rgb_yaml = cache_root / "data.yaml"
if rgb_yaml.exists():
self._append_training_log(f"Detected grayscale dataset; reusing RGB cache at {cache_root}")
return rgb_yaml
self._append_training_log(f"Detected grayscale dataset; creating RGB cache at {cache_root}")
self._build_rgb_dataset(cache_root, dataset_info)
return rgb_yaml
def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
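        """Translate collected UI parameters into a stage plan: a single stage, or head-only plus full fine-tune stages when two-stage training is enabled."""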
two_stage = params.get("two_stage") or {}
base_stage = {
"label": "Single Stage",
"model_path": params["base_model"],
"use_previous_best": False,
"params": {
"epochs": params["epochs"],
"batch": params["batch"],
"imgsz": params["imgsz"],
"patience": params["patience"],
"lr0": params["lr0"],
"freeze": 0,
"name": params["run_name"],
},
}
if not two_stage.get("enabled"):
return [base_stage]
stage_plan: List[Dict[str, Any]] = []
stage1 = two_stage.get("stage1", {})
stage2 = two_stage.get("stage2", {})
stage_plan.append(
{
"label": "Stage 1 — Head-only",
"model_path": params["base_model"],
"use_previous_best": False,
"params": {
"epochs": stage1.get("epochs", params["epochs"]),
"batch": params["batch"],
"imgsz": params["imgsz"],
"patience": stage1.get("patience", params["patience"]),
"lr0": stage1.get("lr0", params["lr0"]),
"freeze": stage1.get("freeze", 0),
"name": f"{params['run_name']}_head_ft",
},
}
)
stage_plan.append(
{
"label": "Stage 2 — Full",
"model_path": params["base_model"],
"use_previous_best": True,
"params": {
"epochs": stage2.get("epochs", params["epochs"]),
"batch": params["batch"],
"imgsz": params["imgsz"],
"patience": stage2.get("patience", params["patience"]),
"lr0": stage2.get("lr0", params["lr0"]),
"freeze": stage2.get("freeze", 0),
"name": f"{params['run_name']}_full_ft",
},
}
)
return stage_plan
def _calculate_total_stage_epochs(self, stage_plan: List[Dict[str, Any]]) -> int:
total = 0
for stage in stage_plan:
params = stage.get("params") or {}
try:
stage_epochs = int(params.get("epochs", 0))
except (TypeError, ValueError):
stage_epochs = 0
if stage_epochs > 0:
total += stage_epochs
return total
def _log_stage_plan(self, stage_plan: List[Dict[str, Any]]):
for index, stage in enumerate(stage_plan, start=1):
stage_label = stage.get("label") or f"Stage {index}"
params = stage.get("params") or {}
epochs = params.get("epochs", "?")
lr0 = params.get("lr0", "?")
patience = params.get("patience", "?")
freeze = params.get("freeze", 0)
self._append_training_log(
f"{stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
)
def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
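        """Return the per-dataset RGB cache directory under data/datasets/_rgb_cache, named after the dataset folder plus a short hash of its resolved path."""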
cache_base = Path("data/datasets/_rgb_cache")
cache_base.mkdir(parents=True, exist_ok=True)
key = hashlib.md5(str(dataset_yaml.parent.resolve()).encode()).hexdigest()[:8]
return cache_base / f"{dataset_yaml.parent.name}_{key}"
def _clear_rgb_cache_for_dataset(self, dataset_yaml: Path):
cache_root = self._get_rgb_cache_root(dataset_yaml)
if cache_root.exists():
try:
shutil.rmtree(cache_root)
logger.debug(f"Removed RGB cache at {cache_root}")
except OSError as exc:
logger.warning(f"Failed to remove RGB cache {cache_root}: {exc}")
def _dataset_requires_rgb_conversion(self, images_dir: Path) -> bool:
sample_image = self._find_first_image(images_dir)
if not sample_image:
return False
# Do not force an RGB cache for TIFF datasets.
# We handle grayscale/16-bit TIFFs via runtime Ultralytics patches that:
# - load TIFFs with `tifffile`
# - replicate grayscale to 3 channels without quantization
# - normalize uint16 correctly during training
if sample_image.suffix.lower() in {".tif", ".tiff"}:
return False
try:
img = Image(sample_image)
return img.pil_image.mode.upper() != "RGB"
except Exception as exc:
logger.warning(f"Failed to inspect image {sample_image}: {exc}")
return False
def _find_first_image(self, directory: Path) -> Optional[Path]:
if not directory.exists():
return None
for path in directory.rglob("*"):
if path.is_file() and path.suffix.lower() in self.allowed_extensions:
return path
return None
def _build_rgb_dataset(self, cache_root: Path, dataset_info: Dict[str, Any]):
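        """Rebuild cache_root from scratch: convert each split's images to RGB, copy labels alongside, and write a matching data.yaml."""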
if cache_root.exists():
shutil.rmtree(cache_root)
cache_root.mkdir(parents=True, exist_ok=True)
splits = dataset_info.get("splits", {})
for split_name in ("train", "val", "test"):
split_entry = splits.get(split_name)
if not split_entry:
continue
images_src = Path(split_entry.get("path", ""))
if not images_src.exists():
continue
images_dst = cache_root / split_name / "images"
self._convert_images_to_rgb(images_src, images_dst)
labels_src = self._infer_labels_dir(images_src)
if labels_src.exists():
labels_dst = cache_root / split_name / "labels"
self._copy_labels(labels_src, labels_dst)
class_names = dataset_info.get("class_names") or []
names_map = {idx: name for idx, name in enumerate(class_names)}
num_classes = dataset_info.get("num_classes") or len(class_names)
yaml_payload: Dict[str, Any] = {
"path": cache_root.as_posix(),
"names": names_map,
"nc": num_classes,
}
for split_name in ("train", "val", "test"):
images_dir = cache_root / split_name / "images"
if images_dir.exists():
yaml_payload[split_name] = f"{split_name}/images"
with open(cache_root / "data.yaml", "w", encoding="utf-8") as handle:
yaml.safe_dump(yaml_payload, handle, sort_keys=False)
def _convert_images_to_rgb(self, src_dir: Path, dst_dir: Path):
for src in src_dir.rglob("*"):
if not src.is_file() or src.suffix.lower() not in self.allowed_extensions:
continue
relative = src.relative_to(src_dir)
dst = dst_dir / relative
dst.parent.mkdir(parents=True, exist_ok=True)
try:
img_obj = Image(src)
pil_img = img_obj.pil_image
if len(pil_img.getbands()) == 1:
rgb_img = img_obj.convert_grayscale_to_rgb_preserve_range()
else:
rgb_img = pil_img.convert("RGB")
rgb_img.save(dst)
except Exception as exc:
logger.warning(f"Failed to convert {src} to RGB: {exc}")
def _copy_labels(self, labels_src: Path, labels_dst: Path):
label_files = list(labels_src.rglob("*.txt"))
for label_file in label_files:
relative = label_file.relative_to(labels_src)
dst = labels_dst / relative
dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(label_file, dst)
def _infer_labels_dir(self, images_dir: Path) -> Path:
return images_dir.parent / "labels"
def _invalidate_split_cache(self, split_root: Path):
for cache_name in ("labels.cache", "images.cache"):
cache_path = split_root / cache_name
if cache_path.exists():
try:
cache_path.unlink()
logger.debug(f"Removed stale cache file: {cache_path}")
except OSError as exc:
logger.warning(f"Failed to remove cache {cache_path}: {exc}")
def _collect_training_params(self) -> Dict[str, Any]:
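        """Gather the current UI state (names, paths, hyperparameters, two-stage config) into the parameter dict used to launch training."""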
model_name = self.model_name_edit.text().strip() or "custom_model"
model_version = self.model_version_edit.text().strip() or "v1"
base_model = self.base_model_edit.text().strip() or self.config_manager.get(
"models.default_base_model", "yolov8s-seg.pt"
)
save_dir = self.save_dir_edit.text().strip() or self.config_manager.get(
"models.models_directory", "data/models"
)
save_dir_path = Path(save_dir).expanduser()
save_dir_path.mkdir(parents=True, exist_ok=True)
run_name = f"{model_name}_{model_version}".replace(" ", "_")
two_stage_config = {
"enabled": self.two_stage_checkbox.isChecked(),
"stage1": {
"epochs": self.stage1_epochs_spin.value(),
"lr0": self.stage1_lr_spin.value(),
"patience": self.stage1_patience_spin.value(),
"freeze": self.stage1_freeze_spin.value(),
},
"stage2": {
"epochs": self.stage2_epochs_spin.value(),
"lr0": self.stage2_lr_spin.value(),
"patience": self.stage2_patience_spin.value(),
},
}
return {
"model_name": model_name,
"model_version": model_version,
"base_model": base_model,
"save_dir": save_dir_path.as_posix(),
"run_name": run_name,
"epochs": self.epochs_spin.value(),
"batch": self.batch_spin.value(),
"imgsz": self.imgsz_spin.value(),
"patience": self.patience_spin.value(),
"lr0": self.lr_spin.value(),
"two_stage": two_stage_config,
}
def _start_training(self):
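        """Validate the selected dataset, export labels from the database, prepare an RGB cache if needed, then launch the background training worker."""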
if self.training_worker and self.training_worker.isRunning():
return
# Ensure any previous worker objects are fully cleaned up before starting
self._cleanup_training_worker()
dataset_yaml = self.dataset_path_edit.text().strip()
if not dataset_yaml:
QMessageBox.warning(
self,
"Dataset Required",
"Please select or generate a data.yaml file first.",
)
return
dataset_path = Path(dataset_yaml).expanduser()
if not dataset_path.exists():
QMessageBox.warning(self, "Invalid Dataset", "Selected data.yaml file does not exist.")
return
dataset_info = (
self.selected_dataset
if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_path)
else self._parse_dataset_yaml(dataset_path)
)
self.training_log.clear()
self._export_labels_from_database(dataset_info)
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
if dataset_to_use != dataset_path:
self._append_training_log(f"Using RGB-converted dataset at {dataset_to_use.parent}")
params = self._collect_training_params()
stage_plan = self._compose_stage_plan(params)
params["stage_plan"] = stage_plan
total_planned_epochs = self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
params["total_planned_epochs"] = total_planned_epochs
self._active_training_params = params
self._training_cancelled = False
if len(stage_plan) > 1:
self._append_training_log("Two-stage fine-tuning schedule:")
self._log_stage_plan(stage_plan)
self._append_training_log(f"Starting training run '{params['run_name']}' using {params['base_model']}")
self.training_progress_bar.setVisible(True)
self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
self.training_progress_bar.setValue(0)
self._set_training_state(True)
self.training_worker = TrainingWorker(
data_yaml=dataset_to_use.as_posix(),
base_model=params["base_model"],
epochs=params["epochs"],
batch=params["batch"],
imgsz=params["imgsz"],
patience=params["patience"],
lr0=params["lr0"],
save_dir=params["save_dir"],
run_name=params["run_name"],
stage_plan=stage_plan,
total_epochs=total_planned_epochs,
)
self.training_worker.progress.connect(self._on_training_progress)
self.training_worker.finished.connect(self._on_training_finished)
self.training_worker.error.connect(self._on_training_error)
self.training_worker.start()
def _stop_training(self):
if self.training_worker and self.training_worker.isRunning():
self._training_cancelled = True
self._append_training_log("Stop requested. Waiting for the current epoch to finish...")
self.training_worker.stop()
self.stop_training_button.setEnabled(False)
def _disconnect_training_worker_signals(self, worker: TrainingWorker):
for signal, slot in (
(worker.progress, self._on_training_progress),
(worker.finished, self._on_training_finished),
(worker.error, self._on_training_error),
):
try:
signal.disconnect(slot)
except (TypeError, RuntimeError):
pass
def _cleanup_training_worker(
self,
worker: Optional[TrainingWorker] = None,
*,
request_stop: bool = False,
wait_timeout_ms: int = 30000,
):
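        """Detach worker signals, optionally request a stop, wait up to wait_timeout_ms for the thread to finish, then schedule deletion."""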
worker = worker or self.training_worker
if not worker:
return
if worker is self.training_worker:
self.training_worker = None
self._disconnect_training_worker_signals(worker)
if request_stop and worker.isRunning():
worker.stop()
if worker.isRunning():
if not worker.wait(wait_timeout_ms):
logger.warning("Training worker did not finish within %sms", wait_timeout_ms)
worker.deleteLater()
def shutdown(self):
"""Stop any running training worker before the tab is destroyed."""
if self.training_worker and self.training_worker.isRunning():
logger.info("Shutting down training worker before exit")
self._append_training_log("Stopping training before application exit…")
self._cleanup_training_worker(request_stop=True, wait_timeout_ms=120000)
self._training_cancelled = True
else:
self._cleanup_training_worker()
self._set_training_state(False)
self.training_progress_bar.setVisible(False)
def _on_training_progress(self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]):
self.training_progress_bar.setMaximum(total_epochs)
self.training_progress_bar.setValue(current_epoch)
parts = [f"Epoch {current_epoch}/{total_epochs}"]
if metrics:
metric_text = ", ".join(f"{key}: {value:.4f}" for key, value in metrics.items())
parts.append(metric_text)
self._append_training_log(" | ".join(parts))
def _on_training_finished(self, results: Dict[str, Any]):
self._cleanup_training_worker()
self._set_training_state(False)
self.training_progress_bar.setVisible(False)
if self._training_cancelled:
self._append_training_log("Training cancelled by user.")
QMessageBox.information(self, "Training Cancelled", "Training was stopped.")
self._training_cancelled = False
self._active_training_params = None
return
self._append_training_log("Training completed successfully.")
try:
self._register_trained_model(results)
except Exception as exc:
logger.error(f"Failed to register trained model: {exc}")
QMessageBox.warning(
self,
"Model Registration Failed",
f"Model trained but not registered: {exc}",
)
else:
QMessageBox.information(self, "Training Complete", "Training finished successfully.")
def _on_training_error(self, message: str):
self._cleanup_training_worker()
self._set_training_state(False)
self.training_progress_bar.setVisible(False)
self._training_cancelled = False
self._active_training_params = None
self._append_training_log(f"ERROR: {message}")
QMessageBox.critical(self, "Training Failed", message)
def _register_trained_model(self, results: Dict[str, Any]):
if not self._active_training_params:
return
params = self._active_training_params
model_path = results.get("best_model_path") or results.get("last_model_path")
if not model_path:
raise ValueError("Training results did not include a model path.")
effective_epochs = params.get("total_planned_epochs", params["epochs"])
training_params = {
"epochs": effective_epochs,
"batch": params["batch"],
"imgsz": params["imgsz"],
"patience": params["patience"],
"lr0": params["lr0"],
"run_name": params["run_name"],
"two_stage": params.get("two_stage"),
}
if params.get("stage_plan"):
training_params["stage_plan"] = params["stage_plan"]
if results.get("stage_results"):
training_params["stage_results"] = results["stage_results"]
if results.get("total_epochs_completed") is not None:
training_params["epochs_completed"] = results["total_epochs_completed"]
model_id = self.db_manager.add_model(
model_name=params["model_name"],
model_version=params["model_version"],
model_path=model_path,
base_model=params["base_model"],
training_params=training_params,
metrics=results.get("metrics"),
)
self._append_training_log(f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}")
self._active_training_params = None
def _set_training_state(self, is_training: bool):
self.start_training_button.setEnabled(not is_training)
self.stop_training_button.setEnabled(is_training)
self.generate_yaml_button.setEnabled(not is_training)
self.dataset_combo.setEnabled(not is_training)
self.browse_button.setEnabled(not is_training)
self.rescan_button.setEnabled(not is_training)
self.model_name_edit.setEnabled(not is_training)
self.model_version_edit.setEnabled(not is_training)
self.base_model_combo.setEnabled(not is_training)
self.base_model_edit.setEnabled(not is_training)
self.base_model_browse_button.setEnabled(not is_training)
self.save_dir_edit.setEnabled(not is_training)
self.save_dir_browse_button.setEnabled(not is_training)
self.epochs_spin.setEnabled(not is_training)
self.batch_spin.setEnabled(not is_training)
self.imgsz_spin.setEnabled(not is_training)
self.patience_spin.setEnabled(not is_training)
self.lr_spin.setEnabled(not is_training)
self.two_stage_checkbox.setEnabled(not is_training)
self._refresh_two_stage_controls_enabled()
def _append_training_log(self, message: str):
timestamp = datetime.now().strftime("%H:%M:%S")
self.training_log.append(f"[{timestamp}] {message}")
def _browse_base_model(self):
start_path = self.base_model_edit.text().strip() or "."
file_path, _ = QFileDialog.getOpenFileName(
self,
"Select Base Model Weights",
start_path,
"PyTorch weights (*.pt *.pth)",
)
if file_path:
self.base_model_edit.setText(file_path)
self._sync_base_model_preset_selection(file_path)
def _browse_save_dir(self):
start_path = self.save_dir_edit.text().strip() or "data/models"
directory = QFileDialog.getExistingDirectory(self, "Select Save Directory", start_path)
if directory:
self.save_dir_edit.setText(directory)
def _display_dataset_success(self, message: str):
self.dataset_status_label.setStyleSheet(self._status_styles["success"])
self.dataset_status_label.setText(message)
def _display_dataset_warning(self, message: str):
self.dataset_status_label.setStyleSheet(self._status_styles["warning"])
self.dataset_status_label.setText(message)
def _display_dataset_error(self, message: str):
self.dataset_status_label.setStyleSheet(self._status_styles["error"])
self.dataset_status_label.setText(message)
def refresh(self):
"""Refresh the tab."""
if self.training_worker and self.training_worker.isRunning():
self._append_training_log("Refresh skipped while training is running.")
return
self._discover_datasets()
current_path = self.dataset_path_edit.text().strip()
if current_path:
self._set_dataset_path(current_path, persist=False)
else:
self._load_saved_dataset()