From 5d196c3a4a66d8ab3e012ae9518534ac4648a840 Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Wed, 10 Dec 2025 15:46:26 +0200 Subject: [PATCH] Update training --- config/app_config.yaml | 2 + src/database/db_manager.py | 188 +++++ src/gui/main_window.py | 4 +- src/gui/tabs/training_tab.py | 1356 +++++++++++++++++++++++++++++++++- src/model/yolo_wrapper.py | 2 + 5 files changed, 1534 insertions(+), 18 deletions(-) diff --git a/config/app_config.yaml b/config/app_config.yaml index be8d2d1..7aa8e16 100644 --- a/config/app_config.yaml +++ b/config/app_config.yaml @@ -18,6 +18,8 @@ training: default_imgsz: 640 default_patience: 50 default_lr0: 0.01 + last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml + last_dataset_dir: /home/martin/code/object_detection/data/datasets detection: default_confidence: 0.25 default_iou: 0.45 diff --git a/src/database/db_manager.py b/src/database/db_manager.py index 53d5695..0100798 100644 --- a/src/database/db_manager.py +++ b/src/database/db_manager.py @@ -10,6 +10,13 @@ from typing import List, Dict, Optional, Tuple, Any, Union from pathlib import Path import csv import hashlib +import yaml + +from src.utils.logger import get_logger + +IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp") + +logger = get_logger(__name__) class DatabaseManager: @@ -861,6 +868,187 @@ class DatabaseManager: finally: conn.close() + # ==================== Dataset Utilities ==================== + + def compose_data_yaml( + self, + dataset_root: str, + output_path: Optional[str] = None, + splits: Optional[Dict[str, str]] = None, + ) -> str: + """ + Compose a YOLO data.yaml file based on dataset folders and database metadata. + + Args: + dataset_root: Base directory containing the dataset structure. + output_path: Optional output path; defaults to /data.yaml. + splits: Optional mapping overriding train/val/test image directories (relative + to dataset_root or absolute paths). + + Returns: + Path to the generated YAML file. + """ + dataset_root_path = Path(dataset_root).expanduser() + if not dataset_root_path.exists(): + raise ValueError(f"Dataset root does not exist: {dataset_root_path}") + dataset_root_path = dataset_root_path.resolve() + + split_map: Dict[str, str] = {key: "" for key in ("train", "val", "test")} + if splits: + for key, value in splits.items(): + if key in split_map and value: + split_map[key] = value + + inferred = self._infer_split_dirs(dataset_root_path) + for key in split_map: + if not split_map[key]: + split_map[key] = inferred.get(key, "") + + for required in ("train", "val"): + if not split_map[required]: + raise ValueError( + "Unable to determine %s image directory under %s. Provide it " + "explicitly via the 'splits' argument." 
+ % (required, dataset_root_path) + ) + + yaml_splits: Dict[str, str] = {} + for key, value in split_map.items(): + if not value: + continue + yaml_splits[key] = self._normalize_split_value(value, dataset_root_path) + + class_names = self._fetch_annotation_class_names() + if not class_names: + class_names = [cls["class_name"] for cls in self.get_object_classes()] + if not class_names: + raise ValueError("No object classes available to populate data.yaml") + + names_map = {idx: name for idx, name in enumerate(class_names)} + payload: Dict[str, Any] = { + "path": dataset_root_path.as_posix(), + "train": yaml_splits["train"], + "val": yaml_splits["val"], + "names": names_map, + "nc": len(class_names), + } + if yaml_splits.get("test"): + payload["test"] = yaml_splits["test"] + + output_path_obj = ( + Path(output_path).expanduser() + if output_path + else dataset_root_path / "data.yaml" + ) + output_path_obj.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path_obj, "w", encoding="utf-8") as handle: + yaml.safe_dump(payload, handle, sort_keys=False) + + logger.info(f"Generated data.yaml at {output_path_obj}") + return output_path_obj.as_posix() + + def _fetch_annotation_class_names(self) -> List[str]: + """Return class names referenced by annotations (ordered by class ID).""" + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute( + """ + SELECT DISTINCT c.id, c.class_name + FROM annotations a + JOIN object_classes c ON a.class_id = c.id + ORDER BY c.id + """ + ) + rows = cursor.fetchall() + return [row["class_name"] for row in rows] + finally: + conn.close() + + def _infer_split_dirs(self, dataset_root: Path) -> Dict[str, str]: + """Infer train/val/test image directories relative to dataset_root.""" + patterns = { + "train": [ + "train/images", + "training/images", + "images/train", + "images/training", + "train", + "training", + ], + "val": [ + "val/images", + "validation/images", + "images/val", + "images/validation", + "val", + "validation", + ], + "test": [ + "test/images", + "testing/images", + "images/test", + "images/testing", + "test", + "testing", + ], + } + + inferred: Dict[str, str] = {key: "" for key in patterns} + for split_name, options in patterns.items(): + for relative in options: + candidate = (dataset_root / relative).resolve() + if ( + candidate.exists() + and candidate.is_dir() + and self._directory_has_images(candidate) + ): + try: + inferred[split_name] = candidate.relative_to( + dataset_root + ).as_posix() + except ValueError: + inferred[split_name] = candidate.as_posix() + break + return inferred + + def _normalize_split_value(self, split_value: str, dataset_root: Path) -> str: + """Validate and normalize a split directory to a YAML-friendly string.""" + split_path = Path(split_value).expanduser() + if not split_path.is_absolute(): + split_path = (dataset_root / split_path).resolve() + else: + split_path = split_path.resolve() + + if not split_path.exists() or not split_path.is_dir(): + raise ValueError(f"Split directory not found: {split_path}") + + if not self._directory_has_images(split_path): + raise ValueError(f"No images found under {split_path}") + + try: + return split_path.relative_to(dataset_root).as_posix() + except ValueError: + return split_path.as_posix() + + @staticmethod + def _directory_has_images(directory: Path, max_checks: int = 2000) -> bool: + """Return True if directory tree contains at least one image file.""" + checked = 0 + try: + for file_path in directory.rglob("*"): + if not file_path.is_file(): + 
continue + if file_path.suffix.lower() in IMAGE_EXTENSIONS: + return True + checked += 1 + if checked >= max_checks: + break + except Exception: + return False + return False + @staticmethod def calculate_checksum(file_path: str) -> str: """Calculate MD5 checksum of a file.""" diff --git a/src/gui/main_window.py b/src/gui/main_window.py index 99eefa9..3a4df57 100644 --- a/src/gui/main_window.py +++ b/src/gui/main_window.py @@ -297,7 +297,9 @@ class MainWindow(QMainWindow): # Save window state before closing self._save_window_state() - # Save annotation tab state if it exists + # Persist tab state and stop background work before exit + if hasattr(self, "training_tab"): + self.training_tab.shutdown() if hasattr(self, "annotation_tab"): self.annotation_tab.save_state() diff --git a/src/gui/tabs/training_tab.py b/src/gui/tabs/training_tab.py index 8c3e778..b6ba528 100644 --- a/src/gui/tabs/training_tab.py +++ b/src/gui/tabs/training_tab.py @@ -3,15 +3,120 @@ Training tab for the microscopy object detection application. Handles model training with YOLO. """ -from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox +import hashlib +import shutil +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import yaml +from PIL import Image as PILImage +from PySide6.QtCore import Qt, QThread, Signal +from PySide6.QtWidgets import ( + QWidget, + QVBoxLayout, + QLabel, + QGroupBox, + QHBoxLayout, + QPushButton, + QLineEdit, + QFileDialog, + QComboBox, + QFormLayout, + QMessageBox, + QTextEdit, + QProgressBar, + QSpinBox, + QDoubleSpinBox, +) from src.database.db_manager import DatabaseManager +from src.model.yolo_wrapper import YOLOWrapper from src.utils.config_manager import ConfigManager from src.utils.logger import get_logger logger = get_logger(__name__) +DEFAULT_IMAGE_EXTENSIONS = { + ".jpg", + ".jpeg", + ".png", + ".tif", + ".tiff", + ".bmp", +} + + +class TrainingWorker(QThread): + """Background worker that runs YOLO training without blocking the UI.""" + + progress = Signal(int, int, dict) # current_epoch, total_epochs, metrics + finished = Signal(dict) + error = Signal(str) + + def __init__( + self, + data_yaml: str, + base_model: str, + epochs: int, + batch: int, + imgsz: int, + patience: int, + lr0: float, + save_dir: str, + run_name: str, + parent: Optional[QThread] = None, + ): + super().__init__(parent) + self.data_yaml = data_yaml + self.base_model = base_model + self.epochs = epochs + self.batch = batch + self.imgsz = imgsz + self.patience = patience + self.lr0 = lr0 + self.save_dir = save_dir + self.run_name = run_name + self._stop_requested = False + + def stop(self): + """Request the training process to stop at the next epoch boundary.""" + self._stop_requested = True + self.requestInterruption() + + def run(self): + """Execute YOLO training and emit progress/finished signals.""" + wrapper = YOLOWrapper(self.base_model) + + def on_epoch_end(trainer): + current_epoch = getattr(trainer, "epoch", 0) + 1 + metrics: Dict[str, float] = {} + loss_items = getattr(trainer, "loss_items", None) + if loss_items: + metrics["loss"] = float(loss_items[-1]) + self.progress.emit(current_epoch, self.epochs, metrics) + if self.isInterruptionRequested() or self._stop_requested: + setattr(trainer, "stop_training", True) + + callbacks = {"on_fit_epoch_end": on_epoch_end} + + try: + results = wrapper.train( + data_yaml=self.data_yaml, + epochs=self.epochs, + imgsz=self.imgsz, + batch=self.batch, + patience=self.patience, + 
save_dir=self.save_dir, + name=self.run_name, + lr0=self.lr0, + callbacks=callbacks, + ) + self.finished.emit(results) + except Exception as exc: + self.error.emit(str(exc)) + class TrainingTab(QWidget): """Training tab for model training.""" @@ -23,30 +128,1247 @@ class TrainingTab(QWidget): self.db_manager = db_manager self.config_manager = config_manager + self.selected_dataset: Optional[Dict[str, Any]] = None + self.allowed_extensions = { + ext.lower() for ext in self.config_manager.get_allowed_extensions() + } or DEFAULT_IMAGE_EXTENSIONS + self._status_styles = { + "default": "", + "success": "color: #2e7d32; font-weight: 600;", + "warning": "color: #b26a00; font-weight: 600;", + "error": "color: #b00020; font-weight: 600;", + } + self.training_worker: Optional[TrainingWorker] = None + self._active_training_params: Optional[Dict[str, Any]] = None + self._training_cancelled = False + self._setup_ui() def _setup_ui(self): """Setup user interface.""" layout = QVBoxLayout() - # Placeholder - group = QGroupBox("Training") - group_layout = QVBoxLayout() - label = QLabel( - "Training functionality will be implemented here.\n\n" - "Features:\n" - "- Dataset selection\n" - "- Training parameter configuration\n" - "- Real-time training progress\n" - "- Loss and metric visualization" - ) - group_layout.addWidget(label) - group.setLayout(group_layout) - - layout.addWidget(group) + layout.addWidget(self._create_dataset_group()) + layout.addWidget(self._create_training_controls_group()) layout.addStretch() self.setLayout(layout) + self._discover_datasets() + self._load_saved_dataset() + + def _create_dataset_group(self) -> QGroupBox: + group = QGroupBox("Dataset Selection") + group_layout = QVBoxLayout() + + description = QLabel( + "Select a YOLO-format data.yaml file to power model training. " + "You can pick from discovered datasets or browse manually." 
+ ) + description.setWordWrap(True) + group_layout.addWidget(description) + + combo_layout = QHBoxLayout() + combo_label = QLabel("Discovered datasets:") + combo_layout.addWidget(combo_label) + + self.dataset_combo = QComboBox() + self.dataset_combo.currentIndexChanged.connect(self._on_dataset_combo_changed) + combo_layout.addWidget(self.dataset_combo, 1) + + self.rescan_button = QPushButton("Rescan") + self.rescan_button.clicked.connect(self._discover_datasets) + combo_layout.addWidget(self.rescan_button) + + group_layout.addLayout(combo_layout) + + path_layout = QHBoxLayout() + self.dataset_path_edit = QLineEdit() + self.dataset_path_edit.setPlaceholderText("Select a data.yaml file to continue") + self.dataset_path_edit.setReadOnly(True) + path_layout.addWidget(self.dataset_path_edit) + + self.browse_button = QPushButton("Browse…") + self.browse_button.clicked.connect(self._browse_dataset) + path_layout.addWidget(self.browse_button) + + group_layout.addLayout(path_layout) + + action_layout = QHBoxLayout() + action_layout.addStretch() + self.generate_yaml_button = QPushButton("Generate data.yaml") + self.generate_yaml_button.clicked.connect(self._generate_data_yaml) + action_layout.addWidget(self.generate_yaml_button) + group_layout.addLayout(action_layout) + + self.dataset_status_label = QLabel("No dataset selected.") + self.dataset_status_label.setWordWrap(True) + group_layout.addWidget(self.dataset_status_label) + + summary_group = QGroupBox("Dataset Summary") + summary_layout = QFormLayout() + + self.dataset_root_label = QLabel("–") + self.dataset_root_label.setWordWrap(True) + self.dataset_root_label.setTextInteractionFlags(Qt.TextSelectableByMouse) + summary_layout.addRow("Root:", self.dataset_root_label) + + self.train_count_label = QLabel("–") + summary_layout.addRow("Train Images:", self.train_count_label) + + self.val_count_label = QLabel("–") + summary_layout.addRow("Val Images:", self.val_count_label) + + self.test_count_label = QLabel("–") + summary_layout.addRow("Test Images:", self.test_count_label) + + self.num_classes_label = QLabel("–") + summary_layout.addRow("Classes:", self.num_classes_label) + + self.class_names_label = QLabel("–") + self.class_names_label.setWordWrap(True) + summary_layout.addRow("Class Names:", self.class_names_label) + + summary_group.setLayout(summary_layout) + group_layout.addWidget(summary_group) + + group.setLayout(group_layout) + return group + + def _create_training_controls_group(self) -> QGroupBox: + group = QGroupBox("Training Controls") + group_layout = QVBoxLayout() + + form_layout = QFormLayout() + + self.model_name_edit = QLineEdit("custom_model") + form_layout.addRow("Model Name:", self.model_name_edit) + + self.model_version_edit = QLineEdit("v1") + form_layout.addRow("Version:", self.model_version_edit) + + default_base_model = self.config_manager.get( + "models.default_base_model", "yolov8s-seg.pt" + ) + base_model_layout = QHBoxLayout() + self.base_model_edit = QLineEdit(default_base_model) + base_model_layout.addWidget(self.base_model_edit) + self.base_model_browse_button = QPushButton("Browse…") + self.base_model_browse_button.clicked.connect(self._browse_base_model) + base_model_layout.addWidget(self.base_model_browse_button) + form_layout.addRow("Base Model (.pt):", base_model_layout) + + models_dir = self.config_manager.get("models.models_directory", "data/models") + save_dir_layout = QHBoxLayout() + self.save_dir_edit = QLineEdit(models_dir) + save_dir_layout.addWidget(self.save_dir_edit) + self.save_dir_browse_button = 
QPushButton("Browse…") + self.save_dir_browse_button.clicked.connect(self._browse_save_dir) + save_dir_layout.addWidget(self.save_dir_browse_button) + form_layout.addRow("Save Directory:", save_dir_layout) + + training_defaults = self.config_manager.get_section("training") + + self.epochs_spin = QSpinBox() + self.epochs_spin.setRange(1, 1000) + self.epochs_spin.setValue(int(training_defaults.get("default_epochs", 10))) + form_layout.addRow("Epochs:", self.epochs_spin) + + self.batch_spin = QSpinBox() + self.batch_spin.setRange(1, 256) + self.batch_spin.setValue(int(training_defaults.get("default_batch_size", 16))) + form_layout.addRow("Batch Size:", self.batch_spin) + + self.imgsz_spin = QSpinBox() + self.imgsz_spin.setRange(320, 2048) + self.imgsz_spin.setSingleStep(32) + self.imgsz_spin.setValue(int(training_defaults.get("default_imgsz", 640))) + form_layout.addRow("Image Size:", self.imgsz_spin) + + self.patience_spin = QSpinBox() + self.patience_spin.setRange(1, 200) + self.patience_spin.setValue(int(training_defaults.get("default_patience", 50))) + form_layout.addRow("Patience:", self.patience_spin) + + self.lr_spin = QDoubleSpinBox() + self.lr_spin.setDecimals(5) + self.lr_spin.setRange(0.00001, 0.1) + self.lr_spin.setSingleStep(0.0005) + self.lr_spin.setValue(float(training_defaults.get("default_lr0", 0.01))) + form_layout.addRow("Learning Rate (lr0):", self.lr_spin) + + group_layout.addLayout(form_layout) + + button_layout = QHBoxLayout() + self.start_training_button = QPushButton("Start Training") + self.start_training_button.clicked.connect(self._start_training) + button_layout.addWidget(self.start_training_button) + + self.stop_training_button = QPushButton("Stop") + self.stop_training_button.setEnabled(False) + self.stop_training_button.clicked.connect(self._stop_training) + button_layout.addWidget(self.stop_training_button) + button_layout.addStretch() + group_layout.addLayout(button_layout) + + self.training_progress_bar = QProgressBar() + self.training_progress_bar.setVisible(False) + group_layout.addWidget(self.training_progress_bar) + + self.training_log = QTextEdit() + self.training_log.setReadOnly(True) + self.training_log.setMaximumHeight(200) + group_layout.addWidget(self.training_log) + + group.setLayout(group_layout) + return group + + def _get_dataset_search_roots(self) -> List[Path]: + roots: List[Path] = [] + default_root = Path("data/datasets").expanduser() + if default_root.exists(): + roots.append(default_root.resolve()) + + repo_root = self.config_manager.get_image_repository_path() + if repo_root: + repo_path = Path(repo_root).expanduser() + if repo_path.exists(): + resolved = repo_path.resolve() + if resolved not in roots: + roots.append(resolved) + + return roots + + def _discover_datasets(self): + """Populate combo box with discovered data.yaml files.""" + discovered: List[Path] = [] + for root in self._get_dataset_search_roots(): + try: + for yaml_path in root.rglob("*.yaml"): + if yaml_path.name.lower() not in {"data.yaml", "dataset.yaml"}: + continue + discovered.append(yaml_path.resolve()) + except Exception as exc: + logger.warning(f"Unable to scan {root}: {exc}") + + unique_paths: List[Path] = [] + seen: set[str] = set() + for path in sorted(discovered): + normalized = str(path) + if normalized not in seen: + seen.add(normalized) + unique_paths.append(path) + + current_selection = self.dataset_path_edit.text().strip() + + self.dataset_combo.blockSignals(True) + self.dataset_combo.clear() + self.dataset_combo.addItem("Select discovered dataset…", 
None) + + for path in unique_paths: + display_name = f"{path.parent.name} ({path})" + self.dataset_combo.addItem(display_name, str(path)) + + self.dataset_combo.setEnabled(bool(unique_paths)) + self.dataset_combo.blockSignals(False) + + if current_selection: + self._select_combo_entry_for_path(current_selection) + + def _select_combo_entry_for_path(self, target_path: str): + try: + normalized_target = str(Path(target_path).expanduser().resolve()) + except Exception: + normalized_target = str(target_path) + + for idx in range(1, self.dataset_combo.count()): + data = self.dataset_combo.itemData(idx) + if not data: + continue + try: + normalized_item = str(Path(data).expanduser().resolve()) + except Exception: + normalized_item = str(data) + + if normalized_item == normalized_target: + self.dataset_combo.setCurrentIndex(idx) + return + + self.dataset_combo.setCurrentIndex(0) + + def _on_dataset_combo_changed(self, index: int): + dataset_path = self.dataset_combo.itemData(index) + if dataset_path: + self._set_dataset_path(dataset_path) + + def _browse_dataset(self): + """Open a file dialog to manually select data.yaml.""" + start_dir = self.config_manager.get( + "training.last_dataset_dir", "data/datasets" + ) + start_path = Path(start_dir).expanduser() + if not start_path.exists(): + start_path = Path.cwd() + + file_path, _ = QFileDialog.getOpenFileName( + self, + "Select YOLO data.yaml", + str(start_path), + "YAML Files (*.yaml *.yml)", + ) + + if file_path: + self._set_dataset_path(file_path) + + def _generate_data_yaml(self): + """Compose a data.yaml file using database metadata and dataset folders.""" + dataset_root = self._determine_dataset_root() + if dataset_root is None: + QMessageBox.warning( + self, + "Dataset Root Missing", + "Unable to locate a dataset root. Please create data/datasets or browse to an existing dataset first.", + ) + return + + self.generate_yaml_button.setEnabled(False) + try: + output_path = self.db_manager.compose_data_yaml(str(dataset_root)) + except ValueError as exc: + logger.error(f"data.yaml generation failed: {exc}") + self._display_dataset_error(str(exc)) + QMessageBox.critical(self, "data.yaml Generation Failed", str(exc)) + return + except Exception as exc: + logger.exception("Unexpected error while generating data.yaml") + self._display_dataset_error( + "Unexpected error while generating data.yaml. Check logs for details." + ) + QMessageBox.critical( + self, + "data.yaml Generation Failed", + "An unexpected error occurred. 
Please inspect the logs for details.", + ) + return + finally: + self.generate_yaml_button.setEnabled(True) + + QMessageBox.information( + self, + "data.yaml Generated", + f"data.yaml saved to:\n{output_path}", + ) + self._set_dataset_path(output_path) + self._display_dataset_success(f"Generated data.yaml at {output_path}") + + def _determine_dataset_root(self) -> Optional[Path]: + """Infer the dataset root directory for generating data.yaml.""" + path_text = self.dataset_path_edit.text().strip() + if path_text: + candidate = Path(path_text).expanduser().parent + if candidate.exists(): + return candidate + + last_dir = self.config_manager.get("training.last_dataset_dir", "") + if last_dir: + candidate = Path(last_dir).expanduser() + if candidate.exists(): + return candidate + + default_root = Path("data/datasets").expanduser() + if default_root.exists(): + return default_root + + return None + + def _set_dataset_path(self, path: str, persist: bool = True): + path_obj = Path(path).expanduser() + if not path_obj.exists(): + self._display_dataset_error(f"File not found: {path_obj}") + return + + if path_obj.suffix.lower() not in {".yaml", ".yml"}: + self._display_dataset_error("Please select a YAML file.") + return + + self.dataset_path_edit.setText(str(path_obj)) + + try: + self._update_dataset_info(path_obj) + except ValueError as exc: + logger.error(f"Dataset parsing failed: {exc}") + self._display_dataset_error(str(exc)) + QMessageBox.warning(self, "Invalid Dataset", f"{exc}") + return + + if persist: + self.config_manager.set("training.last_dataset_yaml", str(path_obj)) + self.config_manager.set("training.last_dataset_dir", str(path_obj.parent)) + self.config_manager.save_config() + + self._select_combo_entry_for_path(str(path_obj)) + + def _load_saved_dataset(self): + saved_path = self.config_manager.get("training.last_dataset_yaml", "") + if saved_path and Path(saved_path).expanduser().exists(): + self._set_dataset_path(saved_path, persist=False) + elif self.dataset_combo.isEnabled() and self.dataset_combo.count() > 1: + self.dataset_combo.setCurrentIndex(1) + + def _update_dataset_info(self, yaml_path: Path): + info = self._parse_dataset_yaml(yaml_path) + self.selected_dataset = info + + self.dataset_root_label.setText(info["root"]) # type: ignore[arg-type] + self.train_count_label.setText( + self._format_split_info(info["splits"].get("train")) + ) + self.val_count_label.setText(self._format_split_info(info["splits"].get("val"))) + self.test_count_label.setText( + self._format_split_info(info["splits"].get("test")) + ) + self.num_classes_label.setText(str(info["num_classes"])) + class_names = ", ".join(info["class_names"]) or "–" + self.class_names_label.setText(class_names) + + warnings = info.get("warnings", []) + if warnings: + warning_text = "Dataset loaded with warnings:\n- " + "\n- ".join(warnings) + self._display_dataset_warning(warning_text) + else: + self._display_dataset_success("Dataset ready for training.") + + def _format_split_info(self, split: Optional[Dict[str, Any]]) -> str: + if not split: + return "n/a" + path = split.get("path") or "n/a" + count = split.get("count", 0) + return f"{count} @ {path}" + + def _parse_dataset_yaml(self, yaml_path: Path) -> Dict[str, Any]: + try: + with open(yaml_path, "r", encoding="utf-8") as handle: + data = yaml.safe_load(handle) or {} + except yaml.YAMLError as exc: + raise ValueError(f"Invalid YAML syntax in {yaml_path}: {exc}") from exc + except OSError as exc: + raise ValueError(f"Unable to read {yaml_path}: {exc}") from exc + + 
base_entry = data.get("path") + if base_entry: + base_path = Path(base_entry) + if not base_path.is_absolute(): + base_path = (yaml_path.parent / base_path).resolve() + else: + base_path = base_path.resolve() + else: + base_path = yaml_path.parent.resolve() + + warnings: List[str] = [] + splits: Dict[str, Dict[str, Any]] = {} + for split_name in ("train", "val", "test"): + split_value = data.get(split_name) + split_info = {"path": "", "count": 0} + if split_value: + split_path = Path(split_value) + if not split_path.is_absolute(): + split_path = (base_path / split_value).resolve() + else: + split_path = split_path.resolve() + split_info["path"] = str(split_path) + + if split_path.exists(): + split_info["count"] = self._count_images(split_path) + if split_info["count"] == 0: + warnings.append( + f"No images found for {split_name} split at {split_path}" + ) + else: + warnings.append( + f"{split_name.capitalize()} path does not exist: {split_path}" + ) + else: + if split_name in ("train", "val"): + warnings.append( + f"{split_name.capitalize()} split missing in data.yaml" + ) + splits[split_name] = split_info + + names_list = self._normalize_class_names(data.get("names")) + + nc_value = data.get("nc") + if nc_value is not None: + try: + nc_value = int(nc_value) + except (ValueError, TypeError): + warnings.append("Invalid 'nc' value detected; using class name count.") + nc_value = len(names_list) + else: + nc_value = len(names_list) + + if not names_list and nc_value: + names_list = [f"class_{idx}" for idx in range(int(nc_value))] + elif nc_value and len(names_list) not in (0, int(nc_value)): + warnings.append( + f"Number of class names ({len(names_list)}) does not match nc={nc_value}" + ) + + dataset_name = data.get("name") or base_path.name + + return { + "yaml_path": str(yaml_path), + "root": str(base_path), + "name": dataset_name, + "splits": splits, + "class_names": names_list, + "num_classes": int(nc_value) if nc_value else len(names_list), + "warnings": warnings, + } + + def _normalize_class_names(self, names_field: Any) -> List[str]: + if isinstance(names_field, dict): + + def sort_key(item): + key, _ = item + try: + return int(key) + except (ValueError, TypeError): + return key + + return [str(name) for _, name in sorted(names_field.items(), key=sort_key)] + + if isinstance(names_field, (list, tuple)): + return [str(name) for name in names_field] + + return [] + + def _count_images(self, directory: Path) -> int: + count = 0 + try: + for file_path in directory.rglob("*"): + if file_path.is_file(): + suffix = file_path.suffix.lower() + if not self.allowed_extensions or suffix in self.allowed_extensions: + count += 1 + except Exception as exc: + logger.warning(f"Unable to count images in {directory}: {exc}") + return count + + def _export_labels_from_database(self, dataset_info: Dict[str, Any]) -> None: + """Write YOLO txt labels for dataset splits using database annotations.""" + + splits = dataset_info.get("splits", {}) + if not splits: + return + + class_index_map = self._build_class_index_map(dataset_info) + if not class_index_map: + self._append_training_log( + "Skipping label export: dataset classes do not match database entries." 
+ ) + return + + dataset_root_str = dataset_info.get("root") + dataset_yaml_path = dataset_info.get("yaml_path") + dataset_yaml = ( + Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None + ) + dataset_root: Optional[Path] + if dataset_root_str: + dataset_root = Path(dataset_root_str).resolve() + else: + dataset_root = dataset_yaml.parent.resolve() if dataset_yaml else None + + split_messages: List[str] = [] + for split_name in ("train", "val", "test"): + split_entry = splits.get(split_name) or {} + images_dir_str = split_entry.get("path") + if not images_dir_str: + continue + + images_dir = Path(images_dir_str) + if not images_dir.exists(): + continue + + stats = self._export_labels_for_split( + split_name=split_name, + images_dir=images_dir, + dataset_root=dataset_root or images_dir.resolve(), + class_index_map=class_index_map, + ) + if not stats: + continue + + message = ( + f"[{split_name}] Exported {stats['total_annotations']} annotations " + f"across {stats['processed_images']} image(s)." + ) + if stats["registered_images"]: + message += f" {stats['registered_images']} image(s) had database-backed annotations." + if stats["missing_records"]: + message += f" {stats['missing_records']} image(s) had no database entry; empty label files were written." + split_messages.append(message) + + for msg in split_messages: + self._append_training_log(msg) + + if dataset_yaml: + self._clear_rgb_cache_for_dataset(dataset_yaml) + + def _export_labels_for_split( + self, + split_name: str, + images_dir: Path, + dataset_root: Path, + class_index_map: Dict[int, int], + ) -> Optional[Dict[str, int]]: + labels_dir = self._infer_labels_dir(images_dir) + labels_dir.mkdir(parents=True, exist_ok=True) + + processed_images = 0 + registered_images = 0 + missing_records = 0 + total_annotations = 0 + + for image_file in images_dir.rglob("*"): + if not image_file.is_file(): + continue + suffix = image_file.suffix.lower() + if self.allowed_extensions and suffix not in self.allowed_extensions: + continue + + processed_images += 1 + label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix( + ".txt" + ) + label_path.parent.mkdir(parents=True, exist_ok=True) + + found, annotation_entries = self._fetch_annotations_for_image( + image_file, dataset_root, images_dir, class_index_map + ) + if found: + registered_images += 1 + else: + missing_records += 1 + + annotations_written = 0 + with open(label_path, "w", encoding="utf-8") as handle: + for entry in annotation_entries: + polygon = entry.get("polygon") or [] + if polygon: + coords = " ".join(f"{value:.6f}" for value in polygon) + handle.write(f"{entry['class_idx']} {coords}\n") + annotations_written += 1 + elif entry.get("bbox"): + x_center, y_center, width, height = entry["bbox"] + handle.write( + f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n" + ) + annotations_written += 1 + + total_annotations += annotations_written + + cache_reset_root = labels_dir.parent + self._invalidate_split_cache(cache_reset_root) + + if processed_images == 0: + self._append_training_log( + f"[{split_name}] No images found to export labels for." 
+ ) + return None + + return { + "split": split_name, + "processed_images": processed_images, + "registered_images": registered_images, + "missing_records": missing_records, + "total_annotations": total_annotations, + } + + def _build_class_index_map(self, dataset_info: Dict[str, Any]) -> Dict[int, int]: + class_names = dataset_info.get("class_names") or [] + name_to_index = {name: idx for idx, name in enumerate(class_names)} + mapping: Dict[int, int] = {} + try: + db_classes = self.db_manager.get_object_classes() + except Exception as exc: + logger.error(f"Failed to read object classes from database: {exc}") + return mapping + + for cls in db_classes: + idx = name_to_index.get(cls.get("class_name")) + if idx is not None: + mapping[cls["id"]] = idx + return mapping + + def _fetch_annotations_for_image( + self, + image_path: Path, + dataset_root: Path, + images_dir: Path, + class_index_map: Dict[int, int], + ) -> Tuple[bool, List[Dict[str, Any]]]: + resolved_image = image_path.resolve() + candidates: List[str] = [] + + for base in (dataset_root, images_dir): + try: + relative = resolved_image.relative_to(base.resolve()).as_posix() + candidates.append(relative) + except ValueError: + continue + + candidates.append(resolved_image.name) + + image_row: Optional[Dict[str, Any]] = None + seen: set[str] = set() + for candidate in candidates: + normalized = candidate.replace("\\", "/") + if normalized in seen: + continue + seen.add(normalized) + image_row = self.db_manager.get_image_by_path(normalized) + if image_row: + break + + if not image_row: + return False, [] + + annotations = self.db_manager.get_annotations_for_image(image_row["id"]) or [] + yolo_entries: List[Dict[str, Any]] = [] + + for ann in annotations: + class_idx = class_index_map.get(ann.get("class_id")) + if class_idx is None: + continue + + x_min = self._clamp01(float(ann.get("x_min", 0.0))) + y_min = self._clamp01(float(ann.get("y_min", 0.0))) + x_max = self._clamp01(float(ann.get("x_max", 0.0))) + y_max = self._clamp01(float(ann.get("y_max", 0.0))) + + width = max(0.0, x_max - x_min) + height = max(0.0, y_max - y_min) + bbox_tuple: Optional[Tuple[float, float, float, float]] = None + if width > 0 and height > 0: + x_center = x_min + width / 2.0 + y_center = y_min + height / 2.0 + bbox_tuple = (x_center, y_center, width, height) + + polygon = self._convert_segmentation_mask_to_polygon( + ann.get("segmentation_mask"), (x_min, y_min, x_max, y_max) + ) + + if not bbox_tuple and not polygon: + continue + + yolo_entries.append( + { + "class_idx": class_idx, + "bbox": bbox_tuple, + "polygon": polygon, + } + ) + + return True, yolo_entries + + def _convert_segmentation_mask_to_polygon( + self, + mask_data: Any, + bbox: Optional[Tuple[float, float, float, float]] = None, + ) -> List[float]: + if not isinstance(mask_data, list): + return [] + + candidates: List[Tuple[float, List[float]]] = [] + for order in ("yx", "xy"): + coords: List[float] = [] + xs: List[float] = [] + ys: List[float] = [] + for point in mask_data: + if not isinstance(point, (list, tuple)) or len(point) != 2: + continue + first = float(point[0]) + second = float(point[1]) + if order == "yx": + x_val = self._clamp01(second) + y_val = self._clamp01(first) + else: + x_val = self._clamp01(first) + y_val = self._clamp01(second) + coords.extend([x_val, y_val]) + xs.append(x_val) + ys.append(y_val) + + if len(coords) < 6: + continue + + score = 0.0 + if bbox: + x_min, y_min, x_max, y_max = bbox + score = ( + abs((min(xs) if xs else 0.0) - x_min) + + abs((max(xs) if xs else 
0.0) - x_max) + + abs((min(ys) if ys else 0.0) - y_min) + + abs((max(ys) if ys else 0.0) - y_max) + ) + + candidates.append((score, coords)) + + if not candidates: + return [] + + candidates.sort(key=lambda item: item[0]) + return candidates[0][1] + + @staticmethod + def _clamp01(value: float) -> float: + if value < 0.0: + return 0.0 + if value > 1.0: + return 1.0 + return value + + def _prepare_dataset_for_training( + self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None + ) -> Path: + dataset_info = dataset_info or ( + self.selected_dataset + if self.selected_dataset + and self.selected_dataset.get("yaml_path") == str(dataset_yaml) + else self._parse_dataset_yaml(dataset_yaml) + ) + + train_split = dataset_info.get("splits", {}).get("train") or {} + images_path_str = train_split.get("path") + if not images_path_str: + return dataset_yaml + + images_path = Path(images_path_str) + if not images_path.exists(): + return dataset_yaml + + if not self._dataset_requires_rgb_conversion(images_path): + return dataset_yaml + + cache_root = self._get_rgb_cache_root(dataset_yaml) + rgb_yaml = cache_root / "data.yaml" + if rgb_yaml.exists(): + self._append_training_log( + f"Detected grayscale dataset; reusing RGB cache at {cache_root}" + ) + return rgb_yaml + + self._append_training_log( + f"Detected grayscale dataset; creating RGB cache at {cache_root}" + ) + self._build_rgb_dataset(cache_root, dataset_info) + return rgb_yaml + + def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path: + cache_base = Path("data/datasets/_rgb_cache") + cache_base.mkdir(parents=True, exist_ok=True) + key = hashlib.md5(str(dataset_yaml.parent.resolve()).encode()).hexdigest()[:8] + return cache_base / f"{dataset_yaml.parent.name}_{key}" + + def _clear_rgb_cache_for_dataset(self, dataset_yaml: Path): + cache_root = self._get_rgb_cache_root(dataset_yaml) + if cache_root.exists(): + try: + shutil.rmtree(cache_root) + logger.debug(f"Removed RGB cache at {cache_root}") + except OSError as exc: + logger.warning(f"Failed to remove RGB cache {cache_root}: {exc}") + + def _dataset_requires_rgb_conversion(self, images_dir: Path) -> bool: + sample_image = self._find_first_image(images_dir) + if not sample_image: + return False + try: + with PILImage.open(sample_image) as img: + return img.mode.upper() != "RGB" + except Exception as exc: + logger.warning(f"Failed to inspect image {sample_image}: {exc}") + return False + + def _find_first_image(self, directory: Path) -> Optional[Path]: + if not directory.exists(): + return None + for path in directory.rglob("*"): + if path.is_file() and path.suffix.lower() in self.allowed_extensions: + return path + return None + + def _build_rgb_dataset(self, cache_root: Path, dataset_info: Dict[str, Any]): + if cache_root.exists(): + shutil.rmtree(cache_root) + cache_root.mkdir(parents=True, exist_ok=True) + + splits = dataset_info.get("splits", {}) + for split_name in ("train", "val", "test"): + split_entry = splits.get(split_name) + if not split_entry: + continue + images_src = Path(split_entry.get("path", "")) + if not images_src.exists(): + continue + images_dst = cache_root / split_name / "images" + self._convert_images_to_rgb(images_src, images_dst) + + labels_src = self._infer_labels_dir(images_src) + if labels_src.exists(): + labels_dst = cache_root / split_name / "labels" + self._copy_labels(labels_src, labels_dst) + + class_names = dataset_info.get("class_names") or [] + names_map = {idx: name for idx, name in enumerate(class_names)} + num_classes = 
dataset_info.get("num_classes") or len(class_names) + + yaml_payload: Dict[str, Any] = { + "path": cache_root.as_posix(), + "names": names_map, + "nc": num_classes, + } + + for split_name in ("train", "val", "test"): + images_dir = cache_root / split_name / "images" + if images_dir.exists(): + yaml_payload[split_name] = f"{split_name}/images" + + with open(cache_root / "data.yaml", "w", encoding="utf-8") as handle: + yaml.safe_dump(yaml_payload, handle, sort_keys=False) + + def _convert_images_to_rgb(self, src_dir: Path, dst_dir: Path): + for src in src_dir.rglob("*"): + if not src.is_file() or src.suffix.lower() not in self.allowed_extensions: + continue + relative = src.relative_to(src_dir) + dst = dst_dir / relative + dst.parent.mkdir(parents=True, exist_ok=True) + try: + with PILImage.open(src) as img: + rgb_img = img.convert("RGB") + rgb_img.save(dst) + except Exception as exc: + logger.warning(f"Failed to convert {src} to RGB: {exc}") + + def _copy_labels(self, labels_src: Path, labels_dst: Path): + label_files = list(labels_src.rglob("*.txt")) + for label_file in label_files: + relative = label_file.relative_to(labels_src) + dst = labels_dst / relative + dst.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(label_file, dst) + + def _infer_labels_dir(self, images_dir: Path) -> Path: + return images_dir.parent / "labels" + + def _invalidate_split_cache(self, split_root: Path): + for cache_name in ("labels.cache", "images.cache"): + cache_path = split_root / cache_name + if cache_path.exists(): + try: + cache_path.unlink() + logger.debug(f"Removed stale cache file: {cache_path}") + except OSError as exc: + logger.warning(f"Failed to remove cache {cache_path}: {exc}") + + def _collect_training_params(self) -> Dict[str, Any]: + model_name = self.model_name_edit.text().strip() or "custom_model" + model_version = self.model_version_edit.text().strip() or "v1" + base_model = self.base_model_edit.text().strip() or self.config_manager.get( + "models.default_base_model", "yolov8s-seg.pt" + ) + save_dir = self.save_dir_edit.text().strip() or self.config_manager.get( + "models.models_directory", "data/models" + ) + save_dir_path = Path(save_dir).expanduser() + save_dir_path.mkdir(parents=True, exist_ok=True) + run_name = f"{model_name}_{model_version}".replace(" ", "_") + + return { + "model_name": model_name, + "model_version": model_version, + "base_model": base_model, + "save_dir": save_dir_path.as_posix(), + "run_name": run_name, + "epochs": self.epochs_spin.value(), + "batch": self.batch_spin.value(), + "imgsz": self.imgsz_spin.value(), + "patience": self.patience_spin.value(), + "lr0": self.lr_spin.value(), + } + + def _start_training(self): + if self.training_worker and self.training_worker.isRunning(): + return + # Ensure any previous worker objects are fully cleaned up before starting + self._cleanup_training_worker() + + dataset_yaml = self.dataset_path_edit.text().strip() + if not dataset_yaml: + QMessageBox.warning( + self, + "Dataset Required", + "Please select or generate a data.yaml file first.", + ) + return + + dataset_path = Path(dataset_yaml).expanduser() + if not dataset_path.exists(): + QMessageBox.warning( + self, "Invalid Dataset", "Selected data.yaml file does not exist." 
+ ) + return + + dataset_info = ( + self.selected_dataset + if self.selected_dataset + and self.selected_dataset.get("yaml_path") == str(dataset_path) + else self._parse_dataset_yaml(dataset_path) + ) + + self.training_log.clear() + self._export_labels_from_database(dataset_info) + + dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info) + if dataset_to_use != dataset_path: + self._append_training_log( + f"Using RGB-converted dataset at {dataset_to_use.parent}" + ) + + params = self._collect_training_params() + self._active_training_params = params + self._training_cancelled = False + + self._append_training_log( + f"Starting training run '{params['run_name']}' using {params['base_model']}" + ) + + self.training_progress_bar.setVisible(True) + self.training_progress_bar.setMaximum(params["epochs"]) + self.training_progress_bar.setValue(0) + self._set_training_state(True) + + self.training_worker = TrainingWorker( + data_yaml=dataset_to_use.as_posix(), + base_model=params["base_model"], + epochs=params["epochs"], + batch=params["batch"], + imgsz=params["imgsz"], + patience=params["patience"], + lr0=params["lr0"], + save_dir=params["save_dir"], + run_name=params["run_name"], + ) + self.training_worker.progress.connect(self._on_training_progress) + self.training_worker.finished.connect(self._on_training_finished) + self.training_worker.error.connect(self._on_training_error) + self.training_worker.start() + + def _stop_training(self): + if self.training_worker and self.training_worker.isRunning(): + self._training_cancelled = True + self._append_training_log( + "Stop requested. Waiting for the current epoch to finish..." + ) + self.training_worker.stop() + self.stop_training_button.setEnabled(False) + + def _disconnect_training_worker_signals(self, worker: TrainingWorker): + for signal, slot in ( + (worker.progress, self._on_training_progress), + (worker.finished, self._on_training_finished), + (worker.error, self._on_training_error), + ): + try: + signal.disconnect(slot) + except (TypeError, RuntimeError): + pass + + def _cleanup_training_worker( + self, + worker: Optional[TrainingWorker] = None, + *, + request_stop: bool = False, + wait_timeout_ms: int = 30000, + ): + worker = worker or self.training_worker + if not worker: + return + + if worker is self.training_worker: + self.training_worker = None + + self._disconnect_training_worker_signals(worker) + + if request_stop and worker.isRunning(): + worker.stop() + + if worker.isRunning(): + if not worker.wait(wait_timeout_ms): + logger.warning( + "Training worker did not finish within %sms", wait_timeout_ms + ) + + worker.deleteLater() + + def shutdown(self): + """Stop any running training worker before the tab is destroyed.""" + if self.training_worker and self.training_worker.isRunning(): + logger.info("Shutting down training worker before exit") + self._append_training_log("Stopping training before application exit…") + self._cleanup_training_worker(request_stop=True, wait_timeout_ms=120000) + self._training_cancelled = True + else: + self._cleanup_training_worker() + + self._set_training_state(False) + self.training_progress_bar.setVisible(False) + + def _on_training_progress( + self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any] + ): + self.training_progress_bar.setMaximum(total_epochs) + self.training_progress_bar.setValue(current_epoch) + parts = [f"Epoch {current_epoch}/{total_epochs}"] + if metrics: + metric_text = ", ".join( + f"{key}: {value:.4f}" for key, value in metrics.items() + ) + 
parts.append(metric_text) + self._append_training_log(" | ".join(parts)) + + def _on_training_finished(self, results: Dict[str, Any]): + self._cleanup_training_worker() + self._set_training_state(False) + self.training_progress_bar.setVisible(False) + + if self._training_cancelled: + self._append_training_log("Training cancelled by user.") + QMessageBox.information(self, "Training Cancelled", "Training was stopped.") + self._training_cancelled = False + self._active_training_params = None + return + + self._append_training_log("Training completed successfully.") + try: + self._register_trained_model(results) + except Exception as exc: + logger.error(f"Failed to register trained model: {exc}") + QMessageBox.warning( + self, + "Model Registration Failed", + f"Model trained but not registered: {exc}", + ) + else: + QMessageBox.information( + self, "Training Complete", "Training finished successfully." + ) + + def _on_training_error(self, message: str): + self._cleanup_training_worker() + self._set_training_state(False) + self.training_progress_bar.setVisible(False) + self._training_cancelled = False + self._active_training_params = None + self._append_training_log(f"ERROR: {message}") + QMessageBox.critical(self, "Training Failed", message) + + def _register_trained_model(self, results: Dict[str, Any]): + if not self._active_training_params: + return + + params = self._active_training_params + model_path = results.get("best_model_path") or results.get("last_model_path") + if not model_path: + raise ValueError("Training results did not include a model path.") + + training_params = { + "epochs": params["epochs"], + "batch": params["batch"], + "imgsz": params["imgsz"], + "patience": params["patience"], + "lr0": params["lr0"], + "run_name": params["run_name"], + } + + model_id = self.db_manager.add_model( + model_name=params["model_name"], + model_version=params["model_version"], + model_path=model_path, + base_model=params["base_model"], + training_params=training_params, + metrics=results.get("metrics"), + ) + + self._append_training_log( + f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}" + ) + self._active_training_params = None + + def _set_training_state(self, is_training: bool): + self.start_training_button.setEnabled(not is_training) + self.stop_training_button.setEnabled(is_training) + self.generate_yaml_button.setEnabled(not is_training) + self.dataset_combo.setEnabled(not is_training) + self.browse_button.setEnabled(not is_training) + self.rescan_button.setEnabled(not is_training) + self.model_name_edit.setEnabled(not is_training) + self.model_version_edit.setEnabled(not is_training) + self.base_model_edit.setEnabled(not is_training) + self.base_model_browse_button.setEnabled(not is_training) + self.save_dir_edit.setEnabled(not is_training) + self.save_dir_browse_button.setEnabled(not is_training) + self.epochs_spin.setEnabled(not is_training) + self.batch_spin.setEnabled(not is_training) + self.imgsz_spin.setEnabled(not is_training) + self.patience_spin.setEnabled(not is_training) + self.lr_spin.setEnabled(not is_training) + + def _append_training_log(self, message: str): + timestamp = datetime.now().strftime("%H:%M:%S") + self.training_log.append(f"[{timestamp}] {message}") + + def _browse_base_model(self): + start_path = self.base_model_edit.text().strip() or "." 
+ file_path, _ = QFileDialog.getOpenFileName( + self, + "Select Base Model Weights", + start_path, + "PyTorch weights (*.pt *.pth)", + ) + if file_path: + self.base_model_edit.setText(file_path) + + def _browse_save_dir(self): + start_path = self.save_dir_edit.text().strip() or "data/models" + directory = QFileDialog.getExistingDirectory( + self, "Select Save Directory", start_path + ) + if directory: + self.save_dir_edit.setText(directory) + + def _display_dataset_success(self, message: str): + self.dataset_status_label.setStyleSheet(self._status_styles["success"]) + self.dataset_status_label.setText(message) + + def _display_dataset_warning(self, message: str): + self.dataset_status_label.setStyleSheet(self._status_styles["warning"]) + self.dataset_status_label.setText(message) + + def _display_dataset_error(self, message: str): + self.dataset_status_label.setStyleSheet(self._status_styles["error"]) + self.dataset_status_label.setText(message) + def refresh(self): """Refresh the tab.""" - pass + if self.training_worker and self.training_worker.isRunning(): + self._append_training_log("Refresh skipped while training is running.") + return + + self._discover_datasets() + current_path = self.dataset_path_edit.text().strip() + if current_path: + self._set_dataset_path(current_path, persist=False) + else: + self._load_saved_dataset() diff --git a/src/model/yolo_wrapper.py b/src/model/yolo_wrapper.py index fa1fd8a..cbcf01e 100644 --- a/src/model/yolo_wrapper.py +++ b/src/model/yolo_wrapper.py @@ -55,6 +55,7 @@ class YOLOWrapper: save_dir: str = "data/models", name: str = "custom_model", resume: bool = False, + callbacks: Optional[Dict[str, Callable]] = None, **kwargs, ) -> Dict[str, Any]: """ @@ -69,6 +70,7 @@ class YOLOWrapper: save_dir: Directory to save trained model name: Name for the training run resume: Resume training from last checkpoint + callbacks: Optional Ultralytics callback dictionary **kwargs: Additional training arguments Returns: -- 2.49.1
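
Usage sketch for the new DatabaseManager.compose_data_yaml() helper (illustrative only: the DatabaseManager constructor argument and the dataset paths below are assumptions, not taken from this patch):

    from src.database.db_manager import DatabaseManager

    # Assumed constructor argument; adjust to the project's actual signature.
    db = DatabaseManager("data/annotations.db")

    # Explicit split overrides are optional; when omitted, the helper infers
    # train/val/test image directories under dataset_root and raises if the
    # train or val split cannot be located.
    yaml_path = db.compose_data_yaml(
        dataset_root="data/datasets/my_dataset",
        splits={"train": "images/train", "val": "images/val"},
    )
    print(yaml_path)  # e.g. data/datasets/my_dataset/data.yaml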
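
The callbacks parameter added to YOLOWrapper.train() is declared in this patch, but the code that forwards it to Ultralytics lies outside the diff context. A minimal sketch of how the mapping could be registered, assuming the wrapper delegates to ultralytics.YOLO and its add_callback() API (train_with_callbacks and its arguments are illustrative names, not part of this patch):

    from typing import Any, Callable, Dict, Optional
    from ultralytics import YOLO

    def train_with_callbacks(
        weights: str,
        data_yaml: str,
        callbacks: Optional[Dict[str, Callable]] = None,
        **train_kwargs: Any,
    ):
        model = YOLO(weights)
        # Register each event -> handler pair (e.g. "on_fit_epoch_end") so the
        # TrainingWorker's progress hook fires once per epoch.
        for event, handler in (callbacks or {}).items():
            model.add_callback(event, handler)
        return model.train(data=data_yaml, **train_kwargs)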