"""
Training tab for the microscopy object detection application.

Handles model training with YOLO.
"""

import hashlib
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import yaml
from PIL import Image as PILImage
from PySide6.QtCore import Qt, QThread, Signal
from PySide6.QtWidgets import (
    QWidget,
    QVBoxLayout,
    QLabel,
    QGroupBox,
    QHBoxLayout,
    QPushButton,
    QLineEdit,
    QFileDialog,
    QComboBox,
    QFormLayout,
    QMessageBox,
    QTextEdit,
    QProgressBar,
    QSpinBox,
    QDoubleSpinBox,
)

from src.database.db_manager import DatabaseManager
from src.model.yolo_wrapper import YOLOWrapper
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger


logger = get_logger(__name__)

DEFAULT_IMAGE_EXTENSIONS = {
    ".jpg",
    ".jpeg",
    ".png",
    ".tif",
    ".tiff",
    ".bmp",
}


class TrainingWorker(QThread):
    """Background worker that runs YOLO training without blocking the UI."""

    progress = Signal(int, int, dict)  # current_epoch, total_epochs, metrics
    finished = Signal(dict)
    error = Signal(str)

    def __init__(
        self,
        data_yaml: str,
        base_model: str,
        epochs: int,
        batch: int,
        imgsz: int,
        patience: int,
        lr0: float,
        save_dir: str,
        run_name: str,
        parent: Optional[QThread] = None,
    ):
        super().__init__(parent)
        self.data_yaml = data_yaml
        self.base_model = base_model
        self.epochs = epochs
        self.batch = batch
        self.imgsz = imgsz
        self.patience = patience
        self.lr0 = lr0
        self.save_dir = save_dir
        self.run_name = run_name
        self._stop_requested = False

    def stop(self):
        """Request the training process to stop at the next epoch boundary."""
        self._stop_requested = True
        self.requestInterruption()

    def run(self):
        """Execute YOLO training and emit progress/finished signals."""
        wrapper = YOLOWrapper(self.base_model)

        def on_epoch_end(trainer):
            current_epoch = getattr(trainer, "epoch", 0) + 1
            metrics: Dict[str, float] = {}
            loss_items = getattr(trainer, "loss_items", None)
            # loss_items may be a tensor; avoid its ambiguous truth value and
            # tolerate empty or scalar shapes.
            if loss_items is not None:
                try:
                    metrics["loss"] = float(loss_items[-1])
                except (IndexError, TypeError, ValueError):
                    pass
            self.progress.emit(current_epoch, self.epochs, metrics)
            if self.isInterruptionRequested() or self._stop_requested:
                setattr(trainer, "stop_training", True)

        callbacks = {"on_fit_epoch_end": on_epoch_end}

        try:
            results = wrapper.train(
                data_yaml=self.data_yaml,
                epochs=self.epochs,
                imgsz=self.imgsz,
                batch=self.batch,
                patience=self.patience,
                save_dir=self.save_dir,
                name=self.run_name,
                lr0=self.lr0,
                callbacks=callbacks,
            )
            self.finished.emit(results)
        except Exception as exc:
            self.error.emit(str(exc))


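# Illustrative only: a minimal sketch of how this worker is typically driven
# (the TrainingTab below wires it up in exactly this way); paths and values
# here are placeholders, not defaults.
#
#     worker = TrainingWorker(
#         data_yaml="data/datasets/example/data.yaml",
#         base_model="yolov8s-seg.pt",
#         epochs=10, batch=16, imgsz=640, patience=50, lr0=0.01,
#         save_dir="data/models", run_name="custom_model_v1",
#     )
#     worker.progress.connect(on_progress)   # (epoch, total_epochs, metrics)
#     worker.finished.connect(on_finished)   # results dict from YOLOWrapper.train
#     worker.error.connect(on_error)
#     worker.start()

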
class TrainingTab(QWidget):
    """Tab for selecting datasets and running YOLO model training."""

    def __init__(
        self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
    ):
        super().__init__(parent)
        self.db_manager = db_manager
        self.config_manager = config_manager

        self.selected_dataset: Optional[Dict[str, Any]] = None
        self.allowed_extensions = {
            ext.lower() for ext in self.config_manager.get_allowed_extensions()
        } or DEFAULT_IMAGE_EXTENSIONS
        self._status_styles = {
            "default": "",
            "success": "color: #2e7d32; font-weight: 600;",
            "warning": "color: #b26a00; font-weight: 600;",
            "error": "color: #b00020; font-weight: 600;",
        }
        self.training_worker: Optional[TrainingWorker] = None
        self._active_training_params: Optional[Dict[str, Any]] = None
        self._training_cancelled = False

        self._setup_ui()

    def _setup_ui(self):
        """Set up the user interface."""
        layout = QVBoxLayout()

        layout.addWidget(self._create_dataset_group())
        layout.addWidget(self._create_training_controls_group())
        layout.addStretch()
        self.setLayout(layout)

        self._discover_datasets()
        self._load_saved_dataset()

    def _create_dataset_group(self) -> QGroupBox:
        """Build the dataset selection group with discovery, browse, and summary widgets."""
        group = QGroupBox("Dataset Selection")
        group_layout = QVBoxLayout()

        description = QLabel(
            "Select a YOLO-format data.yaml file to power model training. "
            "You can pick from discovered datasets or browse manually."
        )
        description.setWordWrap(True)
        group_layout.addWidget(description)

        combo_layout = QHBoxLayout()
        combo_label = QLabel("Discovered datasets:")
        combo_layout.addWidget(combo_label)

        self.dataset_combo = QComboBox()
        self.dataset_combo.currentIndexChanged.connect(self._on_dataset_combo_changed)
        combo_layout.addWidget(self.dataset_combo, 1)

        self.rescan_button = QPushButton("Rescan")
        self.rescan_button.clicked.connect(self._discover_datasets)
        combo_layout.addWidget(self.rescan_button)

        group_layout.addLayout(combo_layout)

        path_layout = QHBoxLayout()
        self.dataset_path_edit = QLineEdit()
        self.dataset_path_edit.setPlaceholderText("Select a data.yaml file to continue")
        self.dataset_path_edit.setReadOnly(True)
        path_layout.addWidget(self.dataset_path_edit)

        self.browse_button = QPushButton("Browse…")
        self.browse_button.clicked.connect(self._browse_dataset)
        path_layout.addWidget(self.browse_button)

        group_layout.addLayout(path_layout)

        action_layout = QHBoxLayout()
        action_layout.addStretch()
        self.generate_yaml_button = QPushButton("Generate data.yaml")
        self.generate_yaml_button.clicked.connect(self._generate_data_yaml)
        action_layout.addWidget(self.generate_yaml_button)
        group_layout.addLayout(action_layout)

        self.dataset_status_label = QLabel("No dataset selected.")
        self.dataset_status_label.setWordWrap(True)
        group_layout.addWidget(self.dataset_status_label)

        summary_group = QGroupBox("Dataset Summary")
        summary_layout = QFormLayout()

        self.dataset_root_label = QLabel("–")
        self.dataset_root_label.setWordWrap(True)
        self.dataset_root_label.setTextInteractionFlags(Qt.TextSelectableByMouse)
        summary_layout.addRow("Root:", self.dataset_root_label)

        self.train_count_label = QLabel("–")
        summary_layout.addRow("Train Images:", self.train_count_label)

        self.val_count_label = QLabel("–")
        summary_layout.addRow("Val Images:", self.val_count_label)

        self.test_count_label = QLabel("–")
        summary_layout.addRow("Test Images:", self.test_count_label)

        self.num_classes_label = QLabel("–")
        summary_layout.addRow("Classes:", self.num_classes_label)

        self.class_names_label = QLabel("–")
        self.class_names_label.setWordWrap(True)
        summary_layout.addRow("Class Names:", self.class_names_label)

        summary_group.setLayout(summary_layout)
        group_layout.addWidget(summary_group)

        group.setLayout(group_layout)
        return group

    def _create_training_controls_group(self) -> QGroupBox:
        """Build the training controls group: hyperparameters, buttons, progress, and log."""
        group = QGroupBox("Training Controls")
        group_layout = QVBoxLayout()

        form_layout = QFormLayout()

        self.model_name_edit = QLineEdit("custom_model")
        form_layout.addRow("Model Name:", self.model_name_edit)

        self.model_version_edit = QLineEdit("v1")
        form_layout.addRow("Version:", self.model_version_edit)

        default_base_model = self.config_manager.get(
            "models.default_base_model", "yolov8s-seg.pt"
        )
        base_model_layout = QHBoxLayout()
        self.base_model_edit = QLineEdit(default_base_model)
        base_model_layout.addWidget(self.base_model_edit)
        self.base_model_browse_button = QPushButton("Browse…")
        self.base_model_browse_button.clicked.connect(self._browse_base_model)
        base_model_layout.addWidget(self.base_model_browse_button)
        form_layout.addRow("Base Model (.pt):", base_model_layout)

        models_dir = self.config_manager.get("models.models_directory", "data/models")
        save_dir_layout = QHBoxLayout()
        self.save_dir_edit = QLineEdit(models_dir)
        save_dir_layout.addWidget(self.save_dir_edit)
        self.save_dir_browse_button = QPushButton("Browse…")
        self.save_dir_browse_button.clicked.connect(self._browse_save_dir)
        save_dir_layout.addWidget(self.save_dir_browse_button)
        form_layout.addRow("Save Directory:", save_dir_layout)

        training_defaults = self.config_manager.get_section("training")

        self.epochs_spin = QSpinBox()
        self.epochs_spin.setRange(1, 1000)
        self.epochs_spin.setValue(int(training_defaults.get("default_epochs", 10)))
        form_layout.addRow("Epochs:", self.epochs_spin)

        self.batch_spin = QSpinBox()
        self.batch_spin.setRange(1, 256)
        self.batch_spin.setValue(int(training_defaults.get("default_batch_size", 16)))
        form_layout.addRow("Batch Size:", self.batch_spin)

        self.imgsz_spin = QSpinBox()
        self.imgsz_spin.setRange(320, 2048)
        self.imgsz_spin.setSingleStep(32)
        self.imgsz_spin.setValue(int(training_defaults.get("default_imgsz", 640)))
        form_layout.addRow("Image Size:", self.imgsz_spin)

        self.patience_spin = QSpinBox()
        self.patience_spin.setRange(1, 200)
        self.patience_spin.setValue(int(training_defaults.get("default_patience", 50)))
        form_layout.addRow("Patience:", self.patience_spin)

        self.lr_spin = QDoubleSpinBox()
        self.lr_spin.setDecimals(5)
        self.lr_spin.setRange(0.00001, 0.1)
        self.lr_spin.setSingleStep(0.0005)
        self.lr_spin.setValue(float(training_defaults.get("default_lr0", 0.01)))
        form_layout.addRow("Learning Rate (lr0):", self.lr_spin)

        group_layout.addLayout(form_layout)

        button_layout = QHBoxLayout()
        self.start_training_button = QPushButton("Start Training")
        self.start_training_button.clicked.connect(self._start_training)
        button_layout.addWidget(self.start_training_button)

        self.stop_training_button = QPushButton("Stop")
        self.stop_training_button.setEnabled(False)
        self.stop_training_button.clicked.connect(self._stop_training)
        button_layout.addWidget(self.stop_training_button)
        button_layout.addStretch()
        group_layout.addLayout(button_layout)

        self.training_progress_bar = QProgressBar()
        self.training_progress_bar.setVisible(False)
        group_layout.addWidget(self.training_progress_bar)

        self.training_log = QTextEdit()
        self.training_log.setReadOnly(True)
        self.training_log.setMaximumHeight(200)
        group_layout.addWidget(self.training_log)

        group.setLayout(group_layout)
        return group

    def _get_dataset_search_roots(self) -> List[Path]:
        """Return existing directories that should be scanned for datasets."""
        roots: List[Path] = []
        default_root = Path("data/datasets").expanduser()
        if default_root.exists():
            roots.append(default_root.resolve())

        repo_root = self.config_manager.get_image_repository_path()
        if repo_root:
            repo_path = Path(repo_root).expanduser()
            if repo_path.exists():
                resolved = repo_path.resolve()
                if resolved not in roots:
                    roots.append(resolved)

        return roots

    def _discover_datasets(self):
        """Populate combo box with discovered data.yaml files."""
        discovered: List[Path] = []
        for root in self._get_dataset_search_roots():
            try:
                for yaml_path in root.rglob("*.yaml"):
                    if yaml_path.name.lower() not in {"data.yaml", "dataset.yaml"}:
                        continue
                    discovered.append(yaml_path.resolve())
            except Exception as exc:
                logger.warning(f"Unable to scan {root}: {exc}")

        unique_paths: List[Path] = []
        seen: set[str] = set()
        for path in sorted(discovered):
            normalized = str(path)
            if normalized not in seen:
                seen.add(normalized)
                unique_paths.append(path)

        current_selection = self.dataset_path_edit.text().strip()

        self.dataset_combo.blockSignals(True)
        self.dataset_combo.clear()
        self.dataset_combo.addItem("Select discovered dataset…", None)

        for path in unique_paths:
            display_name = f"{path.parent.name} ({path})"
            self.dataset_combo.addItem(display_name, str(path))

        self.dataset_combo.setEnabled(bool(unique_paths))
        self.dataset_combo.blockSignals(False)

        if current_selection:
            self._select_combo_entry_for_path(current_selection)

    def _select_combo_entry_for_path(self, target_path: str):
        """Select the combo entry whose path matches ``target_path``, if any."""
        try:
            normalized_target = str(Path(target_path).expanduser().resolve())
        except Exception:
            normalized_target = str(target_path)

        for idx in range(1, self.dataset_combo.count()):
            data = self.dataset_combo.itemData(idx)
            if not data:
                continue
            try:
                normalized_item = str(Path(data).expanduser().resolve())
            except Exception:
                normalized_item = str(data)

            if normalized_item == normalized_target:
                self.dataset_combo.setCurrentIndex(idx)
                return

        self.dataset_combo.setCurrentIndex(0)

    def _on_dataset_combo_changed(self, index: int):
        """Apply the dataset selected from the discovery combo box."""
        dataset_path = self.dataset_combo.itemData(index)
        if dataset_path:
            self._set_dataset_path(dataset_path)

    def _browse_dataset(self):
        """Open a file dialog to manually select data.yaml."""
        start_dir = self.config_manager.get(
            "training.last_dataset_dir", "data/datasets"
        )
        start_path = Path(start_dir).expanduser()
        if not start_path.exists():
            start_path = Path.cwd()

        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "Select YOLO data.yaml",
            str(start_path),
            "YAML Files (*.yaml *.yml)",
        )

        if file_path:
            self._set_dataset_path(file_path)

    def _generate_data_yaml(self):
        """Compose a data.yaml file using database metadata and dataset folders."""
        dataset_root = self._determine_dataset_root()
        if dataset_root is None:
            QMessageBox.warning(
                self,
                "Dataset Root Missing",
                "Unable to locate a dataset root. Please create data/datasets or browse to an existing dataset first.",
            )
            return

        self.generate_yaml_button.setEnabled(False)
        try:
            output_path = self.db_manager.compose_data_yaml(str(dataset_root))
        except ValueError as exc:
            logger.error(f"data.yaml generation failed: {exc}")
            self._display_dataset_error(str(exc))
            QMessageBox.critical(self, "data.yaml Generation Failed", str(exc))
            return
        except Exception as exc:
            logger.exception("Unexpected error while generating data.yaml")
            self._display_dataset_error(
                "Unexpected error while generating data.yaml. Check logs for details."
            )
            QMessageBox.critical(
                self,
                "data.yaml Generation Failed",
                "An unexpected error occurred. Please inspect the logs for details.",
            )
            return
        finally:
            self.generate_yaml_button.setEnabled(True)

        QMessageBox.information(
            self,
            "data.yaml Generated",
            f"data.yaml saved to:\n{output_path}",
        )
        self._set_dataset_path(output_path)
        self._display_dataset_success(f"Generated data.yaml at {output_path}")

    def _determine_dataset_root(self) -> Optional[Path]:
        """Infer the dataset root directory for generating data.yaml."""
        path_text = self.dataset_path_edit.text().strip()
        if path_text:
            candidate = Path(path_text).expanduser().parent
            if candidate.exists():
                return candidate

        last_dir = self.config_manager.get("training.last_dataset_dir", "")
        if last_dir:
            candidate = Path(last_dir).expanduser()
            if candidate.exists():
                return candidate

        default_root = Path("data/datasets").expanduser()
        if default_root.exists():
            return default_root

        return None

    def _set_dataset_path(self, path: str, persist: bool = True):
        """Validate and apply a dataset YAML path, optionally persisting it to config."""
        path_obj = Path(path).expanduser()
        if not path_obj.exists():
            self._display_dataset_error(f"File not found: {path_obj}")
            return

        if path_obj.suffix.lower() not in {".yaml", ".yml"}:
            self._display_dataset_error("Please select a YAML file.")
            return

        self.dataset_path_edit.setText(str(path_obj))

        try:
            self._update_dataset_info(path_obj)
        except ValueError as exc:
            logger.error(f"Dataset parsing failed: {exc}")
            self._display_dataset_error(str(exc))
            QMessageBox.warning(self, "Invalid Dataset", f"{exc}")
            return

        if persist:
            self.config_manager.set("training.last_dataset_yaml", str(path_obj))
            self.config_manager.set("training.last_dataset_dir", str(path_obj.parent))
            self.config_manager.save_config()

        self._select_combo_entry_for_path(str(path_obj))

    def _load_saved_dataset(self):
        """Restore the last-used dataset, or fall back to the first discovered one."""
        saved_path = self.config_manager.get("training.last_dataset_yaml", "")
        if saved_path and Path(saved_path).expanduser().exists():
            self._set_dataset_path(saved_path, persist=False)
        elif self.dataset_combo.isEnabled() and self.dataset_combo.count() > 1:
            self.dataset_combo.setCurrentIndex(1)

    def _update_dataset_info(self, yaml_path: Path):
        """Parse the dataset YAML and refresh the summary labels."""
        info = self._parse_dataset_yaml(yaml_path)
        self.selected_dataset = info

        self.dataset_root_label.setText(info["root"])  # type: ignore[arg-type]
        self.train_count_label.setText(
            self._format_split_info(info["splits"].get("train"))
        )
        self.val_count_label.setText(self._format_split_info(info["splits"].get("val")))
        self.test_count_label.setText(
            self._format_split_info(info["splits"].get("test"))
        )
        self.num_classes_label.setText(str(info["num_classes"]))
        class_names = ", ".join(info["class_names"]) or "–"
        self.class_names_label.setText(class_names)

        warnings = info.get("warnings", [])
        if warnings:
            warning_text = "Dataset loaded with warnings:\n- " + "\n- ".join(warnings)
            self._display_dataset_warning(warning_text)
        else:
            self._display_dataset_success("Dataset ready for training.")

    def _format_split_info(self, split: Optional[Dict[str, Any]]) -> str:
        """Format a split entry as ``"<image count> @ <path>"`` for display."""
        if not split:
            return "n/a"
        path = split.get("path") or "n/a"
        count = split.get("count", 0)
        return f"{count} @ {path}"

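    # For reference, a minimal data.yaml that _parse_dataset_yaml understands could
    # look like the sketch below (paths and class names are illustrative only):
    #
    #     path: relative/or/absolute/dataset/root   # optional; defaults to the YAML's folder
    #     train: train/images
    #     val: val/images
    #     test: test/images                         # optional
    #     nc: 2
    #     names:
    #       0: cell
    #       1: debris
    #
    # "names" may also be a plain list; missing names are synthesized as class_<idx>.
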
    def _parse_dataset_yaml(self, yaml_path: Path) -> Dict[str, Any]:
        """Read a YOLO data.yaml and return root, splits, class names, and warnings."""
        try:
            with open(yaml_path, "r", encoding="utf-8") as handle:
                data = yaml.safe_load(handle) or {}
        except yaml.YAMLError as exc:
            raise ValueError(f"Invalid YAML syntax in {yaml_path}: {exc}") from exc
        except OSError as exc:
            raise ValueError(f"Unable to read {yaml_path}: {exc}") from exc

        base_entry = data.get("path")
        if base_entry:
            base_path = Path(base_entry)
            if not base_path.is_absolute():
                base_path = (yaml_path.parent / base_path).resolve()
            else:
                base_path = base_path.resolve()
        else:
            base_path = yaml_path.parent.resolve()

        warnings: List[str] = []
        splits: Dict[str, Dict[str, Any]] = {}
        for split_name in ("train", "val", "test"):
            split_value = data.get(split_name)
            split_info = {"path": "", "count": 0}
            if split_value:
                split_path = Path(split_value)
                if not split_path.is_absolute():
                    split_path = (base_path / split_value).resolve()
                else:
                    split_path = split_path.resolve()
                split_info["path"] = str(split_path)

                if split_path.exists():
                    split_info["count"] = self._count_images(split_path)
                    if split_info["count"] == 0:
                        warnings.append(
                            f"No images found for {split_name} split at {split_path}"
                        )
                else:
                    warnings.append(
                        f"{split_name.capitalize()} path does not exist: {split_path}"
                    )
            else:
                if split_name in ("train", "val"):
                    warnings.append(
                        f"{split_name.capitalize()} split missing in data.yaml"
                    )
            splits[split_name] = split_info

        names_list = self._normalize_class_names(data.get("names"))

        nc_value = data.get("nc")
        if nc_value is not None:
            try:
                nc_value = int(nc_value)
            except (ValueError, TypeError):
                warnings.append("Invalid 'nc' value detected; using class name count.")
                nc_value = len(names_list)
        else:
            nc_value = len(names_list)

        if not names_list and nc_value:
            names_list = [f"class_{idx}" for idx in range(int(nc_value))]
        elif nc_value and len(names_list) not in (0, int(nc_value)):
            warnings.append(
                f"Number of class names ({len(names_list)}) does not match nc={nc_value}"
            )

        dataset_name = data.get("name") or base_path.name

        return {
            "yaml_path": str(yaml_path),
            "root": str(base_path),
            "name": dataset_name,
            "splits": splits,
            "class_names": names_list,
            "num_classes": int(nc_value) if nc_value else len(names_list),
            "warnings": warnings,
        }

    def _normalize_class_names(self, names_field: Any) -> List[str]:
        """Normalize the YAML ``names`` field (dict or list) to an ordered list of strings."""
        if isinstance(names_field, dict):

            def sort_key(item):
                key, _ = item
                try:
                    return int(key)
                except (ValueError, TypeError):
                    return key

            return [str(name) for _, name in sorted(names_field.items(), key=sort_key)]

        if isinstance(names_field, (list, tuple)):
            return [str(name) for name in names_field]

        return []

    def _count_images(self, directory: Path) -> int:
        """Recursively count files in ``directory`` with an allowed image extension."""
        count = 0
        try:
            for file_path in directory.rglob("*"):
                if file_path.is_file():
                    suffix = file_path.suffix.lower()
                    if not self.allowed_extensions or suffix in self.allowed_extensions:
                        count += 1
        except Exception as exc:
            logger.warning(f"Unable to count images in {directory}: {exc}")
        return count

    def _export_labels_from_database(self, dataset_info: Dict[str, Any]) -> None:
        """Write YOLO txt labels for dataset splits using database annotations."""

        splits = dataset_info.get("splits", {})
        if not splits:
            return

        class_index_map = self._build_class_index_map(dataset_info)
        if not class_index_map:
            self._append_training_log(
                "Skipping label export: dataset classes do not match database entries."
            )
            return

        dataset_root_str = dataset_info.get("root")
        dataset_yaml_path = dataset_info.get("yaml_path")
        dataset_yaml = (
            Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
        )
        dataset_root: Optional[Path]
        if dataset_root_str:
            dataset_root = Path(dataset_root_str).resolve()
        else:
            dataset_root = dataset_yaml.parent.resolve() if dataset_yaml else None

        split_messages: List[str] = []
        for split_name in ("train", "val", "test"):
            split_entry = splits.get(split_name) or {}
            images_dir_str = split_entry.get("path")
            if not images_dir_str:
                continue

            images_dir = Path(images_dir_str)
            if not images_dir.exists():
                continue

            stats = self._export_labels_for_split(
                split_name=split_name,
                images_dir=images_dir,
                dataset_root=dataset_root or images_dir.resolve(),
                class_index_map=class_index_map,
            )
            if not stats:
                continue

            message = (
                f"[{split_name}] Exported {stats['total_annotations']} annotations "
                f"across {stats['processed_images']} image(s)."
            )
            if stats["registered_images"]:
                message += f" {stats['registered_images']} image(s) had database-backed annotations."
            if stats["missing_records"]:
                message += f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
            split_messages.append(message)

        for msg in split_messages:
            self._append_training_log(msg)

        if dataset_yaml:
            self._clear_rgb_cache_for_dataset(dataset_yaml)

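    # Label files written below follow the standard YOLO text format, one object per
    # line, with all coordinates normalized to [0, 1] (illustrative values):
    #
    #     <class_idx> <x_center> <y_center> <width> <height>      # bounding box
    #     <class_idx> <x1> <y1> <x2> <y2> <x3> <y3> ...           # segmentation polygon
    #
    # Each image gets a .txt file with the same stem under the sibling "labels" folder;
    # images with no database record get an empty label file.
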
    def _export_labels_for_split(
        self,
        split_name: str,
        images_dir: Path,
        dataset_root: Path,
        class_index_map: Dict[int, int],
    ) -> Optional[Dict[str, int]]:
        """Export label files for one split and return per-split statistics."""
        labels_dir = self._infer_labels_dir(images_dir)
        labels_dir.mkdir(parents=True, exist_ok=True)

        processed_images = 0
        registered_images = 0
        missing_records = 0
        total_annotations = 0

        for image_file in images_dir.rglob("*"):
            if not image_file.is_file():
                continue
            suffix = image_file.suffix.lower()
            if self.allowed_extensions and suffix not in self.allowed_extensions:
                continue

            processed_images += 1
            label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(
                ".txt"
            )
            label_path.parent.mkdir(parents=True, exist_ok=True)

            found, annotation_entries = self._fetch_annotations_for_image(
                image_file, dataset_root, images_dir, class_index_map
            )
            if found:
                registered_images += 1
            else:
                missing_records += 1

            annotations_written = 0
            with open(label_path, "w", encoding="utf-8") as handle:
                for entry in annotation_entries:
                    polygon = entry.get("polygon") or []
                    if polygon:
                        coords = " ".join(f"{value:.6f}" for value in polygon)
                        handle.write(f"{entry['class_idx']} {coords}\n")
                        annotations_written += 1
                    elif entry.get("bbox"):
                        x_center, y_center, width, height = entry["bbox"]
                        handle.write(
                            f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
                        )
                        annotations_written += 1

            total_annotations += annotations_written

        cache_reset_root = labels_dir.parent
        self._invalidate_split_cache(cache_reset_root)

        if processed_images == 0:
            self._append_training_log(
                f"[{split_name}] No images found to export labels for."
            )
            return None

        return {
            "split": split_name,
            "processed_images": processed_images,
            "registered_images": registered_images,
            "missing_records": missing_records,
            "total_annotations": total_annotations,
        }

    def _build_class_index_map(self, dataset_info: Dict[str, Any]) -> Dict[int, int]:
        """Map database class IDs to YOLO class indices by matching class names."""
        class_names = dataset_info.get("class_names") or []
        name_to_index = {name: idx for idx, name in enumerate(class_names)}
        mapping: Dict[int, int] = {}
        try:
            db_classes = self.db_manager.get_object_classes()
        except Exception as exc:
            logger.error(f"Failed to read object classes from database: {exc}")
            return mapping

        for cls in db_classes:
            idx = name_to_index.get(cls.get("class_name"))
            if idx is not None:
                mapping[cls["id"]] = idx
        return mapping

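    # The lookup below assumes database image paths are stored relative to the
    # dataset root (or images folder) with forward slashes, and that annotation rows
    # provide normalized x_min/y_min/x_max/y_max plus an optional segmentation_mask
    # given as a list of [y, x] or [x, y] point pairs.
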
    def _fetch_annotations_for_image(
        self,
        image_path: Path,
        dataset_root: Path,
        images_dir: Path,
        class_index_map: Dict[int, int],
    ) -> Tuple[bool, List[Dict[str, Any]]]:
        """Look up an image in the database and convert its annotations to YOLO entries."""
        resolved_image = image_path.resolve()
        candidates: List[str] = []

        for base in (dataset_root, images_dir):
            try:
                relative = resolved_image.relative_to(base.resolve()).as_posix()
                candidates.append(relative)
            except ValueError:
                continue

        candidates.append(resolved_image.name)

        image_row: Optional[Dict[str, Any]] = None
        seen: set[str] = set()
        for candidate in candidates:
            normalized = candidate.replace("\\", "/")
            if normalized in seen:
                continue
            seen.add(normalized)
            image_row = self.db_manager.get_image_by_path(normalized)
            if image_row:
                break

        if not image_row:
            return False, []

        annotations = self.db_manager.get_annotations_for_image(image_row["id"]) or []
        yolo_entries: List[Dict[str, Any]] = []

        for ann in annotations:
            class_idx = class_index_map.get(ann.get("class_id"))
            if class_idx is None:
                continue

            x_min = self._clamp01(float(ann.get("x_min", 0.0)))
            y_min = self._clamp01(float(ann.get("y_min", 0.0)))
            x_max = self._clamp01(float(ann.get("x_max", 0.0)))
            y_max = self._clamp01(float(ann.get("y_max", 0.0)))

            width = max(0.0, x_max - x_min)
            height = max(0.0, y_max - y_min)
            bbox_tuple: Optional[Tuple[float, float, float, float]] = None
            if width > 0 and height > 0:
                x_center = x_min + width / 2.0
                y_center = y_min + height / 2.0
                bbox_tuple = (x_center, y_center, width, height)

            polygon = self._convert_segmentation_mask_to_polygon(
                ann.get("segmentation_mask"), (x_min, y_min, x_max, y_max)
            )

            if not bbox_tuple and not polygon:
                continue

            yolo_entries.append(
                {
                    "class_idx": class_idx,
                    "bbox": bbox_tuple,
                    "polygon": polygon,
                }
            )

        return True, yolo_entries

    def _convert_segmentation_mask_to_polygon(
        self,
        mask_data: Any,
        bbox: Optional[Tuple[float, float, float, float]] = None,
    ) -> List[float]:
        """Convert a stored point list to a flat YOLO polygon.

        Both (y, x) and (x, y) point orders are tried; the ordering whose extent
        best matches the bounding box is kept.
        """
        if not isinstance(mask_data, list):
            return []

        candidates: List[Tuple[float, List[float]]] = []
        for order in ("yx", "xy"):
            coords: List[float] = []
            xs: List[float] = []
            ys: List[float] = []
            for point in mask_data:
                if not isinstance(point, (list, tuple)) or len(point) != 2:
                    continue
                first = float(point[0])
                second = float(point[1])
                if order == "yx":
                    x_val = self._clamp01(second)
                    y_val = self._clamp01(first)
                else:
                    x_val = self._clamp01(first)
                    y_val = self._clamp01(second)
                coords.extend([x_val, y_val])
                xs.append(x_val)
                ys.append(y_val)

            if len(coords) < 6:
                continue

            score = 0.0
            if bbox:
                x_min, y_min, x_max, y_max = bbox
                score = (
                    abs((min(xs) if xs else 0.0) - x_min)
                    + abs((max(xs) if xs else 0.0) - x_max)
                    + abs((min(ys) if ys else 0.0) - y_min)
                    + abs((max(ys) if ys else 0.0) - y_max)
                )

            candidates.append((score, coords))

        if not candidates:
            return []

        candidates.sort(key=lambda item: item[0])
        return candidates[0][1]

    @staticmethod
    def _clamp01(value: float) -> float:
        """Clamp ``value`` to the closed interval [0, 1]."""
        if value < 0.0:
            return 0.0
        if value > 1.0:
            return 1.0
        return value

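    # Grayscale datasets are mirrored into an RGB cache before training. The cache is
    # laid out as a self-contained YOLO dataset, roughly:
    #
    #     data/datasets/_rgb_cache/<dataset>_<hash>/
    #         data.yaml
    #         train/images/   train/labels/
    #         val/images/     val/labels/
    #         test/images/    test/labels/   (only if the split exists)
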
    def _prepare_dataset_for_training(
        self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
    ) -> Path:
        """Return the data.yaml to train on, building an RGB copy of grayscale datasets."""
        dataset_info = dataset_info or (
            self.selected_dataset
            if self.selected_dataset
            and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
            else self._parse_dataset_yaml(dataset_yaml)
        )

        train_split = dataset_info.get("splits", {}).get("train") or {}
        images_path_str = train_split.get("path")
        if not images_path_str:
            return dataset_yaml

        images_path = Path(images_path_str)
        if not images_path.exists():
            return dataset_yaml

        if not self._dataset_requires_rgb_conversion(images_path):
            return dataset_yaml

        cache_root = self._get_rgb_cache_root(dataset_yaml)
        rgb_yaml = cache_root / "data.yaml"
        if rgb_yaml.exists():
            self._append_training_log(
                f"Detected grayscale dataset; reusing RGB cache at {cache_root}"
            )
            return rgb_yaml

        self._append_training_log(
            f"Detected grayscale dataset; creating RGB cache at {cache_root}"
        )
        self._build_rgb_dataset(cache_root, dataset_info)
        return rgb_yaml

    def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
        """Return a per-dataset cache directory keyed by the dataset's resolved location."""
        cache_base = Path("data/datasets/_rgb_cache")
        cache_base.mkdir(parents=True, exist_ok=True)
        key = hashlib.md5(str(dataset_yaml.parent.resolve()).encode()).hexdigest()[:8]
        return cache_base / f"{dataset_yaml.parent.name}_{key}"

    def _clear_rgb_cache_for_dataset(self, dataset_yaml: Path):
        """Delete any cached RGB copy of the given dataset."""
        cache_root = self._get_rgb_cache_root(dataset_yaml)
        if cache_root.exists():
            try:
                shutil.rmtree(cache_root)
                logger.debug(f"Removed RGB cache at {cache_root}")
            except OSError as exc:
                logger.warning(f"Failed to remove RGB cache {cache_root}: {exc}")

    def _dataset_requires_rgb_conversion(self, images_dir: Path) -> bool:
        """Return True if a sample image from the dataset is not already RGB."""
        sample_image = self._find_first_image(images_dir)
        if not sample_image:
            return False
        try:
            with PILImage.open(sample_image) as img:
                return img.mode.upper() != "RGB"
        except Exception as exc:
            logger.warning(f"Failed to inspect image {sample_image}: {exc}")
            return False

    def _find_first_image(self, directory: Path) -> Optional[Path]:
        """Return the first image file found under ``directory``, if any."""
        if not directory.exists():
            return None
        for path in directory.rglob("*"):
            if path.is_file() and path.suffix.lower() in self.allowed_extensions:
                return path
        return None

    def _build_rgb_dataset(self, cache_root: Path, dataset_info: Dict[str, Any]):
        """Create an RGB copy of every split plus a matching data.yaml under ``cache_root``."""
        if cache_root.exists():
            shutil.rmtree(cache_root)
        cache_root.mkdir(parents=True, exist_ok=True)

        splits = dataset_info.get("splits", {})
        for split_name in ("train", "val", "test"):
            split_entry = splits.get(split_name)
            if not split_entry:
                continue
            images_src = Path(split_entry.get("path", ""))
            if not images_src.exists():
                continue
            images_dst = cache_root / split_name / "images"
            self._convert_images_to_rgb(images_src, images_dst)

            labels_src = self._infer_labels_dir(images_src)
            if labels_src.exists():
                labels_dst = cache_root / split_name / "labels"
                self._copy_labels(labels_src, labels_dst)

        class_names = dataset_info.get("class_names") or []
        names_map = {idx: name for idx, name in enumerate(class_names)}
        num_classes = dataset_info.get("num_classes") or len(class_names)

        yaml_payload: Dict[str, Any] = {
            "path": cache_root.as_posix(),
            "names": names_map,
            "nc": num_classes,
        }

        for split_name in ("train", "val", "test"):
            images_dir = cache_root / split_name / "images"
            if images_dir.exists():
                yaml_payload[split_name] = f"{split_name}/images"

        with open(cache_root / "data.yaml", "w", encoding="utf-8") as handle:
            yaml.safe_dump(yaml_payload, handle, sort_keys=False)

    def _convert_images_to_rgb(self, src_dir: Path, dst_dir: Path):
        """Copy every image from ``src_dir`` to ``dst_dir``, converting it to RGB."""
        for src in src_dir.rglob("*"):
            if not src.is_file() or src.suffix.lower() not in self.allowed_extensions:
                continue
            relative = src.relative_to(src_dir)
            dst = dst_dir / relative
            dst.parent.mkdir(parents=True, exist_ok=True)
            try:
                with PILImage.open(src) as img:
                    rgb_img = img.convert("RGB")
                    rgb_img.save(dst)
            except Exception as exc:
                logger.warning(f"Failed to convert {src} to RGB: {exc}")

    def _copy_labels(self, labels_src: Path, labels_dst: Path):
        """Mirror YOLO label .txt files from ``labels_src`` into ``labels_dst``."""
        label_files = list(labels_src.rglob("*.txt"))
        for label_file in label_files:
            relative = label_file.relative_to(labels_src)
            dst = labels_dst / relative
            dst.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(label_file, dst)

    def _infer_labels_dir(self, images_dir: Path) -> Path:
        """Return the labels directory that sits next to an images directory."""
        return images_dir.parent / "labels"

    def _invalidate_split_cache(self, split_root: Path):
        """Remove stale YOLO cache files so regenerated labels are picked up."""
        for cache_name in ("labels.cache", "images.cache"):
            cache_path = split_root / cache_name
            if cache_path.exists():
                try:
                    cache_path.unlink()
                    logger.debug(f"Removed stale cache file: {cache_path}")
                except OSError as exc:
                    logger.warning(f"Failed to remove cache {cache_path}: {exc}")

    def _collect_training_params(self) -> Dict[str, Any]:
        """Gather training parameters from the UI, falling back to configured defaults."""
        model_name = self.model_name_edit.text().strip() or "custom_model"
        model_version = self.model_version_edit.text().strip() or "v1"
        base_model = self.base_model_edit.text().strip() or self.config_manager.get(
            "models.default_base_model", "yolov8s-seg.pt"
        )
        save_dir = self.save_dir_edit.text().strip() or self.config_manager.get(
            "models.models_directory", "data/models"
        )
        save_dir_path = Path(save_dir).expanduser()
        save_dir_path.mkdir(parents=True, exist_ok=True)
        run_name = f"{model_name}_{model_version}".replace(" ", "_")

        return {
            "model_name": model_name,
            "model_version": model_version,
            "base_model": base_model,
            "save_dir": save_dir_path.as_posix(),
            "run_name": run_name,
            "epochs": self.epochs_spin.value(),
            "batch": self.batch_spin.value(),
            "imgsz": self.imgsz_spin.value(),
            "patience": self.patience_spin.value(),
            "lr0": self.lr_spin.value(),
        }

    def _start_training(self):
        """Validate the selected dataset, export labels, and launch a TrainingWorker."""
        if self.training_worker and self.training_worker.isRunning():
            return
        # Ensure any previous worker objects are fully cleaned up before starting
        self._cleanup_training_worker()

        dataset_yaml = self.dataset_path_edit.text().strip()
        if not dataset_yaml:
            QMessageBox.warning(
                self,
                "Dataset Required",
                "Please select or generate a data.yaml file first.",
            )
            return

        dataset_path = Path(dataset_yaml).expanduser()
        if not dataset_path.exists():
            QMessageBox.warning(
                self, "Invalid Dataset", "Selected data.yaml file does not exist."
            )
            return

        try:
            dataset_info = (
                self.selected_dataset
                if self.selected_dataset
                and self.selected_dataset.get("yaml_path") == str(dataset_path)
                else self._parse_dataset_yaml(dataset_path)
            )
        except ValueError as exc:
            # The file may have changed since it was selected; fail gracefully.
            QMessageBox.warning(self, "Invalid Dataset", str(exc))
            return

        self.training_log.clear()
        self._export_labels_from_database(dataset_info)

        dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
        if dataset_to_use != dataset_path:
            self._append_training_log(
                f"Using RGB-converted dataset at {dataset_to_use.parent}"
            )

        params = self._collect_training_params()
        self._active_training_params = params
        self._training_cancelled = False

        self._append_training_log(
            f"Starting training run '{params['run_name']}' using {params['base_model']}"
        )

        self.training_progress_bar.setVisible(True)
        self.training_progress_bar.setMaximum(params["epochs"])
        self.training_progress_bar.setValue(0)
        self._set_training_state(True)

        self.training_worker = TrainingWorker(
            data_yaml=dataset_to_use.as_posix(),
            base_model=params["base_model"],
            epochs=params["epochs"],
            batch=params["batch"],
            imgsz=params["imgsz"],
            patience=params["patience"],
            lr0=params["lr0"],
            save_dir=params["save_dir"],
            run_name=params["run_name"],
        )
        self.training_worker.progress.connect(self._on_training_progress)
        self.training_worker.finished.connect(self._on_training_finished)
        self.training_worker.error.connect(self._on_training_error)
        self.training_worker.start()

    def _stop_training(self):
        """Ask the running worker to stop after the current epoch."""
        if self.training_worker and self.training_worker.isRunning():
            self._training_cancelled = True
            self._append_training_log(
                "Stop requested. Waiting for the current epoch to finish..."
            )
            self.training_worker.stop()
            self.stop_training_button.setEnabled(False)

    def _disconnect_training_worker_signals(self, worker: TrainingWorker):
        """Disconnect worker signals from this tab's slots, ignoring broken connections."""
        for signal, slot in (
            (worker.progress, self._on_training_progress),
            (worker.finished, self._on_training_finished),
            (worker.error, self._on_training_error),
        ):
            try:
                signal.disconnect(slot)
            except (TypeError, RuntimeError):
                pass

    def _cleanup_training_worker(
        self,
        worker: Optional[TrainingWorker] = None,
        *,
        request_stop: bool = False,
        wait_timeout_ms: int = 30000,
    ):
        """Detach, optionally stop, and dispose of a training worker thread."""
        worker = worker or self.training_worker
        if not worker:
            return

        if worker is self.training_worker:
            self.training_worker = None

        self._disconnect_training_worker_signals(worker)

        if request_stop and worker.isRunning():
            worker.stop()

        if worker.isRunning():
            if not worker.wait(wait_timeout_ms):
                logger.warning(
                    "Training worker did not finish within %sms", wait_timeout_ms
                )

        worker.deleteLater()

    def shutdown(self):
        """Stop any running training worker before the tab is destroyed."""
        if self.training_worker and self.training_worker.isRunning():
            logger.info("Shutting down training worker before exit")
            self._append_training_log("Stopping training before application exit…")
            self._cleanup_training_worker(request_stop=True, wait_timeout_ms=120000)
            self._training_cancelled = True
        else:
            self._cleanup_training_worker()

        self._set_training_state(False)
        self.training_progress_bar.setVisible(False)

    def _on_training_progress(
        self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]
    ):
        """Update the progress bar and log a line for the completed epoch."""
        self.training_progress_bar.setMaximum(total_epochs)
        self.training_progress_bar.setValue(current_epoch)
        parts = [f"Epoch {current_epoch}/{total_epochs}"]
        if metrics:
            metric_text = ", ".join(
                f"{key}: {value:.4f}" for key, value in metrics.items()
            )
            parts.append(metric_text)
        self._append_training_log(" | ".join(parts))

    def _on_training_finished(self, results: Dict[str, Any]):
        """Handle worker completion: clean up, then register or report the result."""
        self._cleanup_training_worker()
        self._set_training_state(False)
        self.training_progress_bar.setVisible(False)

        if self._training_cancelled:
            self._append_training_log("Training cancelled by user.")
            QMessageBox.information(self, "Training Cancelled", "Training was stopped.")
            self._training_cancelled = False
            self._active_training_params = None
            return

        self._append_training_log("Training completed successfully.")
        try:
            self._register_trained_model(results)
        except Exception as exc:
            logger.error(f"Failed to register trained model: {exc}")
            QMessageBox.warning(
                self,
                "Model Registration Failed",
                f"Model trained but not registered: {exc}",
            )
        else:
            QMessageBox.information(
                self, "Training Complete", "Training finished successfully."
            )

    def _on_training_error(self, message: str):
        """Handle a training failure reported by the worker."""
        self._cleanup_training_worker()
        self._set_training_state(False)
        self.training_progress_bar.setVisible(False)
        self._training_cancelled = False
        self._active_training_params = None
        self._append_training_log(f"ERROR: {message}")
        QMessageBox.critical(self, "Training Failed", message)

    def _register_trained_model(self, results: Dict[str, Any]):
        """Record the trained model's path, parameters, and metrics in the database."""
        if not self._active_training_params:
            return

        params = self._active_training_params
        model_path = results.get("best_model_path") or results.get("last_model_path")
        if not model_path:
            raise ValueError("Training results did not include a model path.")

        training_params = {
            "epochs": params["epochs"],
            "batch": params["batch"],
            "imgsz": params["imgsz"],
            "patience": params["patience"],
            "lr0": params["lr0"],
            "run_name": params["run_name"],
        }

        model_id = self.db_manager.add_model(
            model_name=params["model_name"],
            model_version=params["model_version"],
            model_path=model_path,
            base_model=params["base_model"],
            training_params=training_params,
            metrics=results.get("metrics"),
        )

        self._append_training_log(
            f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}"
        )
        self._active_training_params = None

    def _set_training_state(self, is_training: bool):
        """Enable or disable controls depending on whether training is running."""
        self.start_training_button.setEnabled(not is_training)
        self.stop_training_button.setEnabled(is_training)
        self.generate_yaml_button.setEnabled(not is_training)
        self.dataset_combo.setEnabled(not is_training)
        self.browse_button.setEnabled(not is_training)
        self.rescan_button.setEnabled(not is_training)
        self.model_name_edit.setEnabled(not is_training)
        self.model_version_edit.setEnabled(not is_training)
        self.base_model_edit.setEnabled(not is_training)
        self.base_model_browse_button.setEnabled(not is_training)
        self.save_dir_edit.setEnabled(not is_training)
        self.save_dir_browse_button.setEnabled(not is_training)
        self.epochs_spin.setEnabled(not is_training)
        self.batch_spin.setEnabled(not is_training)
        self.imgsz_spin.setEnabled(not is_training)
        self.patience_spin.setEnabled(not is_training)
        self.lr_spin.setEnabled(not is_training)

    def _append_training_log(self, message: str):
        """Append a timestamped line to the training log widget."""
        timestamp = datetime.now().strftime("%H:%M:%S")
        self.training_log.append(f"[{timestamp}] {message}")

    def _browse_base_model(self):
        """Open a file dialog to pick base model weights."""
        start_path = self.base_model_edit.text().strip() or "."
        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "Select Base Model Weights",
            start_path,
            "PyTorch weights (*.pt *.pth)",
        )
        if file_path:
            self.base_model_edit.setText(file_path)

    def _browse_save_dir(self):
        """Open a directory dialog to pick where training runs are saved."""
        start_path = self.save_dir_edit.text().strip() or "data/models"
        directory = QFileDialog.getExistingDirectory(
            self, "Select Save Directory", start_path
        )
        if directory:
            self.save_dir_edit.setText(directory)

    def _display_dataset_success(self, message: str):
        self.dataset_status_label.setStyleSheet(self._status_styles["success"])
        self.dataset_status_label.setText(message)

    def _display_dataset_warning(self, message: str):
        self.dataset_status_label.setStyleSheet(self._status_styles["warning"])
        self.dataset_status_label.setText(message)

    def _display_dataset_error(self, message: str):
        self.dataset_status_label.setStyleSheet(self._status_styles["error"])
        self.dataset_status_label.setText(message)

    def refresh(self):
        """Refresh dataset discovery and the selected dataset summary."""
        if self.training_worker and self.training_worker.isRunning():
            self._append_training_log("Refresh skipped while training is running.")
            return

        self._discover_datasets()
        current_path = self.dataset_path_edit.text().strip()
        if current_path:
            self._set_dataset_path(current_path, persist=False)
        else:
            self._load_saved_dataset()