""" Training tab for the microscopy object detection application. Handles model training with YOLO. """ import hashlib import shutil from datetime import datetime from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import yaml from PIL import Image as PILImage from PySide6.QtCore import Qt, QThread, Signal from PySide6.QtWidgets import ( QWidget, QVBoxLayout, QLabel, QGroupBox, QHBoxLayout, QPushButton, QLineEdit, QFileDialog, QComboBox, QFormLayout, QMessageBox, QTextEdit, QProgressBar, QSpinBox, QDoubleSpinBox, ) from src.database.db_manager import DatabaseManager from src.model.yolo_wrapper import YOLOWrapper from src.utils.config_manager import ConfigManager from src.utils.logger import get_logger logger = get_logger(__name__) DEFAULT_IMAGE_EXTENSIONS = { ".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp", } class TrainingWorker(QThread): """Background worker that runs YOLO training without blocking the UI.""" progress = Signal(int, int, dict) # current_epoch, total_epochs, metrics finished = Signal(dict) error = Signal(str) def __init__( self, data_yaml: str, base_model: str, epochs: int, batch: int, imgsz: int, patience: int, lr0: float, save_dir: str, run_name: str, parent: Optional[QThread] = None, ): super().__init__(parent) self.data_yaml = data_yaml self.base_model = base_model self.epochs = epochs self.batch = batch self.imgsz = imgsz self.patience = patience self.lr0 = lr0 self.save_dir = save_dir self.run_name = run_name self._stop_requested = False def stop(self): """Request the training process to stop at the next epoch boundary.""" self._stop_requested = True self.requestInterruption() def run(self): """Execute YOLO training and emit progress/finished signals.""" wrapper = YOLOWrapper(self.base_model) def on_epoch_end(trainer): current_epoch = getattr(trainer, "epoch", 0) + 1 metrics: Dict[str, float] = {} loss_items = getattr(trainer, "loss_items", None) if loss_items: metrics["loss"] = float(loss_items[-1]) self.progress.emit(current_epoch, self.epochs, metrics) if self.isInterruptionRequested() or self._stop_requested: setattr(trainer, "stop_training", True) callbacks = {"on_fit_epoch_end": on_epoch_end} try: results = wrapper.train( data_yaml=self.data_yaml, epochs=self.epochs, imgsz=self.imgsz, batch=self.batch, patience=self.patience, save_dir=self.save_dir, name=self.run_name, lr0=self.lr0, callbacks=callbacks, ) self.finished.emit(results) except Exception as exc: self.error.emit(str(exc)) class TrainingTab(QWidget): """Training tab for model training.""" def __init__( self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None ): super().__init__(parent) self.db_manager = db_manager self.config_manager = config_manager self.selected_dataset: Optional[Dict[str, Any]] = None self.allowed_extensions = { ext.lower() for ext in self.config_manager.get_allowed_extensions() } or DEFAULT_IMAGE_EXTENSIONS self._status_styles = { "default": "", "success": "color: #2e7d32; font-weight: 600;", "warning": "color: #b26a00; font-weight: 600;", "error": "color: #b00020; font-weight: 600;", } self.training_worker: Optional[TrainingWorker] = None self._active_training_params: Optional[Dict[str, Any]] = None self._training_cancelled = False self._setup_ui() def _setup_ui(self): """Setup user interface.""" layout = QVBoxLayout() layout.addWidget(self._create_dataset_group()) layout.addWidget(self._create_training_controls_group()) layout.addStretch() self.setLayout(layout) self._discover_datasets() self._load_saved_dataset() def 
class TrainingTab(QWidget):
    """Training tab for model training."""

    def __init__(
        self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
    ):
        super().__init__(parent)
        self.db_manager = db_manager
        self.config_manager = config_manager
        self.selected_dataset: Optional[Dict[str, Any]] = None
        self.allowed_extensions = {
            ext.lower() for ext in self.config_manager.get_allowed_extensions()
        } or DEFAULT_IMAGE_EXTENSIONS
        self._status_styles = {
            "default": "",
            "success": "color: #2e7d32; font-weight: 600;",
            "warning": "color: #b26a00; font-weight: 600;",
            "error": "color: #b00020; font-weight: 600;",
        }
        self.training_worker: Optional[TrainingWorker] = None
        self._active_training_params: Optional[Dict[str, Any]] = None
        self._training_cancelled = False
        self._setup_ui()

    def _setup_ui(self):
        """Set up the user interface."""
        layout = QVBoxLayout()
        layout.addWidget(self._create_dataset_group())
        layout.addWidget(self._create_training_controls_group())
        layout.addStretch()
        self.setLayout(layout)
        self._discover_datasets()
        self._load_saved_dataset()

    def _create_dataset_group(self) -> QGroupBox:
        """Build the dataset selection group, including the summary panel."""
        group = QGroupBox("Dataset Selection")
        group_layout = QVBoxLayout()

        description = QLabel(
            "Select a YOLO-format data.yaml file to power model training. "
            "You can pick from discovered datasets or browse manually."
        )
        description.setWordWrap(True)
        group_layout.addWidget(description)

        combo_layout = QHBoxLayout()
        combo_label = QLabel("Discovered datasets:")
        combo_layout.addWidget(combo_label)
        self.dataset_combo = QComboBox()
        self.dataset_combo.currentIndexChanged.connect(self._on_dataset_combo_changed)
        combo_layout.addWidget(self.dataset_combo, 1)
        self.rescan_button = QPushButton("Rescan")
        self.rescan_button.clicked.connect(self._discover_datasets)
        combo_layout.addWidget(self.rescan_button)
        group_layout.addLayout(combo_layout)

        path_layout = QHBoxLayout()
        self.dataset_path_edit = QLineEdit()
        self.dataset_path_edit.setPlaceholderText("Select a data.yaml file to continue")
        self.dataset_path_edit.setReadOnly(True)
        path_layout.addWidget(self.dataset_path_edit)
        self.browse_button = QPushButton("Browse…")
        self.browse_button.clicked.connect(self._browse_dataset)
        path_layout.addWidget(self.browse_button)
        group_layout.addLayout(path_layout)

        action_layout = QHBoxLayout()
        action_layout.addStretch()
        self.generate_yaml_button = QPushButton("Generate data.yaml")
        self.generate_yaml_button.clicked.connect(self._generate_data_yaml)
        action_layout.addWidget(self.generate_yaml_button)
        group_layout.addLayout(action_layout)

        self.dataset_status_label = QLabel("No dataset selected.")
        self.dataset_status_label.setWordWrap(True)
        group_layout.addWidget(self.dataset_status_label)

        summary_group = QGroupBox("Dataset Summary")
        summary_layout = QFormLayout()
        self.dataset_root_label = QLabel("–")
        self.dataset_root_label.setWordWrap(True)
        self.dataset_root_label.setTextInteractionFlags(Qt.TextSelectableByMouse)
        summary_layout.addRow("Root:", self.dataset_root_label)
        self.train_count_label = QLabel("–")
        summary_layout.addRow("Train Images:", self.train_count_label)
        self.val_count_label = QLabel("–")
        summary_layout.addRow("Val Images:", self.val_count_label)
        self.test_count_label = QLabel("–")
        summary_layout.addRow("Test Images:", self.test_count_label)
        self.num_classes_label = QLabel("–")
        summary_layout.addRow("Classes:", self.num_classes_label)
        self.class_names_label = QLabel("–")
        self.class_names_label.setWordWrap(True)
        summary_layout.addRow("Class Names:", self.class_names_label)
        summary_group.setLayout(summary_layout)
        group_layout.addWidget(summary_group)

        group.setLayout(group_layout)
        return group
    def _create_training_controls_group(self) -> QGroupBox:
        """Build the training controls group: hyperparameters, buttons, log."""
        group = QGroupBox("Training Controls")
        group_layout = QVBoxLayout()

        form_layout = QFormLayout()
        self.model_name_edit = QLineEdit("custom_model")
        form_layout.addRow("Model Name:", self.model_name_edit)
        self.model_version_edit = QLineEdit("v1")
        form_layout.addRow("Version:", self.model_version_edit)

        default_base_model = self.config_manager.get(
            "models.default_base_model", "yolov8s-seg.pt"
        )
        base_model_layout = QHBoxLayout()
        self.base_model_edit = QLineEdit(default_base_model)
        base_model_layout.addWidget(self.base_model_edit)
        self.base_model_browse_button = QPushButton("Browse…")
        self.base_model_browse_button.clicked.connect(self._browse_base_model)
        base_model_layout.addWidget(self.base_model_browse_button)
        form_layout.addRow("Base Model (.pt):", base_model_layout)

        models_dir = self.config_manager.get("models.models_directory", "data/models")
        save_dir_layout = QHBoxLayout()
        self.save_dir_edit = QLineEdit(models_dir)
        save_dir_layout.addWidget(self.save_dir_edit)
        self.save_dir_browse_button = QPushButton("Browse…")
        self.save_dir_browse_button.clicked.connect(self._browse_save_dir)
        save_dir_layout.addWidget(self.save_dir_browse_button)
        form_layout.addRow("Save Directory:", save_dir_layout)

        training_defaults = self.config_manager.get_section("training")
        self.epochs_spin = QSpinBox()
        self.epochs_spin.setRange(1, 1000)
        self.epochs_spin.setValue(int(training_defaults.get("default_epochs", 10)))
        form_layout.addRow("Epochs:", self.epochs_spin)

        self.batch_spin = QSpinBox()
        self.batch_spin.setRange(1, 256)
        self.batch_spin.setValue(int(training_defaults.get("default_batch_size", 16)))
        form_layout.addRow("Batch Size:", self.batch_spin)

        self.imgsz_spin = QSpinBox()
        self.imgsz_spin.setRange(320, 2048)
        self.imgsz_spin.setSingleStep(32)
        self.imgsz_spin.setValue(int(training_defaults.get("default_imgsz", 640)))
        form_layout.addRow("Image Size:", self.imgsz_spin)

        self.patience_spin = QSpinBox()
        self.patience_spin.setRange(1, 200)
        self.patience_spin.setValue(int(training_defaults.get("default_patience", 50)))
        form_layout.addRow("Patience:", self.patience_spin)

        self.lr_spin = QDoubleSpinBox()
        self.lr_spin.setDecimals(5)
        self.lr_spin.setRange(0.00001, 0.1)
        self.lr_spin.setSingleStep(0.0005)
        self.lr_spin.setValue(float(training_defaults.get("default_lr0", 0.01)))
        form_layout.addRow("Learning Rate (lr0):", self.lr_spin)
        group_layout.addLayout(form_layout)

        button_layout = QHBoxLayout()
        self.start_training_button = QPushButton("Start Training")
        self.start_training_button.clicked.connect(self._start_training)
        button_layout.addWidget(self.start_training_button)
        self.stop_training_button = QPushButton("Stop")
        self.stop_training_button.setEnabled(False)
        self.stop_training_button.clicked.connect(self._stop_training)
        button_layout.addWidget(self.stop_training_button)
        button_layout.addStretch()
        group_layout.addLayout(button_layout)

        self.training_progress_bar = QProgressBar()
        self.training_progress_bar.setVisible(False)
        group_layout.addWidget(self.training_progress_bar)

        self.training_log = QTextEdit()
        self.training_log.setReadOnly(True)
        self.training_log.setMaximumHeight(200)
        group_layout.addWidget(self.training_log)

        group.setLayout(group_layout)
        return group

    def _get_dataset_search_roots(self) -> List[Path]:
        """Return existing directories that should be scanned for datasets."""
        roots: List[Path] = []
        default_root = Path("data/datasets").expanduser()
        if default_root.exists():
            roots.append(default_root.resolve())
        repo_root = self.config_manager.get_image_repository_path()
        if repo_root:
            repo_path = Path(repo_root).expanduser()
            if repo_path.exists():
                resolved = repo_path.resolve()
                if resolved not in roots:
                    roots.append(resolved)
        return roots

    def _discover_datasets(self):
        """Populate the combo box with discovered data.yaml files."""
        discovered: List[Path] = []
        for root in self._get_dataset_search_roots():
            try:
                for yaml_path in root.rglob("*.yaml"):
                    if yaml_path.name.lower() not in {"data.yaml", "dataset.yaml"}:
                        continue
                    discovered.append(yaml_path.resolve())
            except Exception as exc:
                logger.warning(f"Unable to scan {root}: {exc}")

        unique_paths: List[Path] = []
        seen: set[str] = set()
        for path in sorted(discovered):
            normalized = str(path)
            if normalized not in seen:
                seen.add(normalized)
                unique_paths.append(path)

        current_selection = self.dataset_path_edit.text().strip()
        self.dataset_combo.blockSignals(True)
        self.dataset_combo.clear()
        self.dataset_combo.addItem("Select discovered dataset…", None)
        for path in unique_paths:
            display_name = f"{path.parent.name} ({path})"
            self.dataset_combo.addItem(display_name, str(path))
        self.dataset_combo.setEnabled(bool(unique_paths))
        self.dataset_combo.blockSignals(False)
        if current_selection:
            self._select_combo_entry_for_path(current_selection)
    def _select_combo_entry_for_path(self, target_path: str):
        """Select the combo entry matching target_path, or reset to index 0."""
        try:
            normalized_target = str(Path(target_path).expanduser().resolve())
        except Exception:
            normalized_target = str(target_path)
        for idx in range(1, self.dataset_combo.count()):
            data = self.dataset_combo.itemData(idx)
            if not data:
                continue
            try:
                normalized_item = str(Path(data).expanduser().resolve())
            except Exception:
                normalized_item = str(data)
            if normalized_item == normalized_target:
                self.dataset_combo.setCurrentIndex(idx)
                return
        self.dataset_combo.setCurrentIndex(0)

    def _on_dataset_combo_changed(self, index: int):
        dataset_path = self.dataset_combo.itemData(index)
        if dataset_path:
            self._set_dataset_path(dataset_path)

    def _browse_dataset(self):
        """Open a file dialog to manually select a data.yaml file."""
        start_dir = self.config_manager.get(
            "training.last_dataset_dir", "data/datasets"
        )
        start_path = Path(start_dir).expanduser()
        if not start_path.exists():
            start_path = Path.cwd()
        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "Select YOLO data.yaml",
            str(start_path),
            "YAML Files (*.yaml *.yml)",
        )
        if file_path:
            self._set_dataset_path(file_path)

    def _generate_data_yaml(self):
        """Compose a data.yaml file using database metadata and dataset folders."""
        dataset_root = self._determine_dataset_root()
        if dataset_root is None:
            QMessageBox.warning(
                self,
                "Dataset Root Missing",
                "Unable to locate a dataset root. Please create data/datasets "
                "or browse to an existing dataset first.",
            )
            return

        self.generate_yaml_button.setEnabled(False)
        try:
            output_path = self.db_manager.compose_data_yaml(str(dataset_root))
        except ValueError as exc:
            logger.error(f"data.yaml generation failed: {exc}")
            self._display_dataset_error(str(exc))
            QMessageBox.critical(self, "data.yaml Generation Failed", str(exc))
            return
        except Exception:
            logger.exception("Unexpected error while generating data.yaml")
            self._display_dataset_error(
                "Unexpected error while generating data.yaml. Check logs for details."
            )
            QMessageBox.critical(
                self,
                "data.yaml Generation Failed",
                "An unexpected error occurred. Please inspect the logs for details.",
            )
            return
        finally:
            self.generate_yaml_button.setEnabled(True)

        QMessageBox.information(
            self,
            "data.yaml Generated",
            f"data.yaml saved to:\n{output_path}",
        )
        self._set_dataset_path(output_path)
        self._display_dataset_success(f"Generated data.yaml at {output_path}")
    def _determine_dataset_root(self) -> Optional[Path]:
        """Infer the dataset root directory for generating data.yaml."""
        path_text = self.dataset_path_edit.text().strip()
        if path_text:
            candidate = Path(path_text).expanduser().parent
            if candidate.exists():
                return candidate
        last_dir = self.config_manager.get("training.last_dataset_dir", "")
        if last_dir:
            candidate = Path(last_dir).expanduser()
            if candidate.exists():
                return candidate
        default_root = Path("data/datasets").expanduser()
        if default_root.exists():
            return default_root
        return None

    def _set_dataset_path(self, path: str, persist: bool = True):
        """Validate, display, and optionally persist the selected dataset path."""
        path_obj = Path(path).expanduser()
        if not path_obj.exists():
            self._display_dataset_error(f"File not found: {path_obj}")
            return
        if path_obj.suffix.lower() not in {".yaml", ".yml"}:
            self._display_dataset_error("Please select a YAML file.")
            return
        self.dataset_path_edit.setText(str(path_obj))
        try:
            self._update_dataset_info(path_obj)
        except ValueError as exc:
            logger.error(f"Dataset parsing failed: {exc}")
            self._display_dataset_error(str(exc))
            QMessageBox.warning(self, "Invalid Dataset", str(exc))
            return
        if persist:
            self.config_manager.set("training.last_dataset_yaml", str(path_obj))
            self.config_manager.set("training.last_dataset_dir", str(path_obj.parent))
            self.config_manager.save_config()
        self._select_combo_entry_for_path(str(path_obj))

    def _load_saved_dataset(self):
        """Restore the last-used dataset, or fall back to the first discovery."""
        saved_path = self.config_manager.get("training.last_dataset_yaml", "")
        if saved_path and Path(saved_path).expanduser().exists():
            self._set_dataset_path(saved_path, persist=False)
        elif self.dataset_combo.isEnabled() and self.dataset_combo.count() > 1:
            self.dataset_combo.setCurrentIndex(1)

    def _update_dataset_info(self, yaml_path: Path):
        """Parse the dataset and refresh the summary labels and status text."""
        info = self._parse_dataset_yaml(yaml_path)
        self.selected_dataset = info
        self.dataset_root_label.setText(info["root"])
        self.train_count_label.setText(
            self._format_split_info(info["splits"].get("train"))
        )
        self.val_count_label.setText(self._format_split_info(info["splits"].get("val")))
        self.test_count_label.setText(
            self._format_split_info(info["splits"].get("test"))
        )
        self.num_classes_label.setText(str(info["num_classes"]))
        class_names = ", ".join(info["class_names"]) or "–"
        self.class_names_label.setText(class_names)
        warnings = info.get("warnings", [])
        if warnings:
            warning_text = "Dataset loaded with warnings:\n- " + "\n- ".join(warnings)
            self._display_dataset_warning(warning_text)
        else:
            self._display_dataset_success("Dataset ready for training.")

    def _format_split_info(self, split: Optional[Dict[str, Any]]) -> str:
        if not split:
            return "n/a"
        path = split.get("path") or "n/a"
        count = split.get("count", 0)
        return f"{count} @ {path}"

    def _parse_dataset_yaml(self, yaml_path: Path) -> Dict[str, Any]:
        """Parse a YOLO data.yaml into root, splits, classes, and warnings."""
        try:
            with open(yaml_path, "r", encoding="utf-8") as handle:
                data = yaml.safe_load(handle) or {}
        except yaml.YAMLError as exc:
            raise ValueError(f"Invalid YAML syntax in {yaml_path}: {exc}") from exc
        except OSError as exc:
            raise ValueError(f"Unable to read {yaml_path}: {exc}") from exc

        base_entry = data.get("path")
        if base_entry:
            base_path = Path(base_entry)
            if not base_path.is_absolute():
                base_path = (yaml_path.parent / base_path).resolve()
            else:
                base_path = base_path.resolve()
        else:
            base_path = yaml_path.parent.resolve()

        warnings: List[str] = []
        splits: Dict[str, Dict[str, Any]] = {}
        for split_name in ("train", "val", "test"):
            split_value = data.get(split_name)
            split_info = {"path": "", "count": 0}
            if split_value:
                split_path = Path(split_value)
                if not split_path.is_absolute():
                    split_path = (base_path / split_value).resolve()
                else:
                    split_path = split_path.resolve()
                split_info["path"] = str(split_path)
                if split_path.exists():
                    split_info["count"] = self._count_images(split_path)
                    if split_info["count"] == 0:
                        warnings.append(
                            f"No images found for {split_name} split at {split_path}"
                        )
                else:
                    warnings.append(
                        f"{split_name.capitalize()} path does not exist: {split_path}"
                    )
            else:
                if split_name in ("train", "val"):
                    warnings.append(
                        f"{split_name.capitalize()} split missing in data.yaml"
                    )
            splits[split_name] = split_info

        names_list = self._normalize_class_names(data.get("names"))
        nc_value = data.get("nc")
        if nc_value is not None:
            try:
                nc_value = int(nc_value)
            except (ValueError, TypeError):
                warnings.append("Invalid 'nc' value detected; using class name count.")
                nc_value = len(names_list)
        else:
            nc_value = len(names_list)

        if not names_list and nc_value:
            names_list = [f"class_{idx}" for idx in range(int(nc_value))]
        elif nc_value and len(names_list) not in (0, int(nc_value)):
            warnings.append(
                f"Number of class names ({len(names_list)}) does not match nc={nc_value}"
            )

        dataset_name = data.get("name") or base_path.name
        return {
            "yaml_path": str(yaml_path),
            "root": str(base_path),
            "name": dataset_name,
            "splits": splits,
            "class_names": names_list,
            "num_classes": int(nc_value) if nc_value else len(names_list),
            "warnings": warnings,
        }
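    # For reference, a typical data.yaml accepted by _parse_dataset_yaml above
    # (paths and class names are illustrative):
    #
    #   path: /absolute/or/relative/dataset/root
    #   train: train/images
    #   val: val/images
    #   test: test/images        # optional
    #   nc: 2
    #   names:
    #     0: cell
    #     1: debris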
    def _normalize_class_names(self, names_field: Any) -> List[str]:
        """Normalize the YAML 'names' field (dict or list) into an ordered list."""
        if isinstance(names_field, dict):

            def sort_key(item):
                key, _ = item
                try:
                    return int(key)
                except (ValueError, TypeError):
                    return key

            return [str(name) for _, name in sorted(names_field.items(), key=sort_key)]
        if isinstance(names_field, (list, tuple)):
            return [str(name) for name in names_field]
        return []

    def _count_images(self, directory: Path) -> int:
        """Recursively count files with an allowed image extension."""
        count = 0
        try:
            for file_path in directory.rglob("*"):
                if file_path.is_file():
                    suffix = file_path.suffix.lower()
                    if not self.allowed_extensions or suffix in self.allowed_extensions:
                        count += 1
        except Exception as exc:
            logger.warning(f"Unable to count images in {directory}: {exc}")
        return count
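    # The export helpers below write standard YOLO txt labels, one object per
    # line, with all coordinates normalized to [0, 1]:
    #
    #   <class_idx> <x_center> <y_center> <width> <height>      # bounding box
    #   <class_idx> <x1> <y1> <x2> <y2> ... <xn> <yn>           # segmentation polygon
    #
    # Illustrative bbox line: "0 0.512300 0.433100 0.210000 0.180000".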
    def _export_labels_from_database(self, dataset_info: Dict[str, Any]) -> None:
        """Write YOLO txt labels for dataset splits using database annotations."""
        splits = dataset_info.get("splits", {})
        if not splits:
            return
        class_index_map = self._build_class_index_map(dataset_info)
        if not class_index_map:
            self._append_training_log(
                "Skipping label export: dataset classes do not match database entries."
            )
            return

        dataset_root_str = dataset_info.get("root")
        dataset_yaml_path = dataset_info.get("yaml_path")
        dataset_yaml = (
            Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
        )
        dataset_root: Optional[Path]
        if dataset_root_str:
            dataset_root = Path(dataset_root_str).resolve()
        else:
            dataset_root = dataset_yaml.parent.resolve() if dataset_yaml else None

        split_messages: List[str] = []
        for split_name in ("train", "val", "test"):
            split_entry = splits.get(split_name) or {}
            images_dir_str = split_entry.get("path")
            if not images_dir_str:
                continue
            images_dir = Path(images_dir_str)
            if not images_dir.exists():
                continue
            stats = self._export_labels_for_split(
                split_name=split_name,
                images_dir=images_dir,
                dataset_root=dataset_root or images_dir.resolve(),
                class_index_map=class_index_map,
            )
            if not stats:
                continue
            message = (
                f"[{split_name}] Exported {stats['total_annotations']} annotations "
                f"across {stats['processed_images']} image(s)."
            )
            if stats["registered_images"]:
                message += (
                    f" {stats['registered_images']} image(s) had "
                    "database-backed annotations."
                )
            if stats["missing_records"]:
                message += (
                    f" {stats['missing_records']} image(s) had no database entry; "
                    "empty label files were written."
                )
            split_messages.append(message)

        for msg in split_messages:
            self._append_training_log(msg)
        if dataset_yaml:
            self._clear_rgb_cache_for_dataset(dataset_yaml)

    def _export_labels_for_split(
        self,
        split_name: str,
        images_dir: Path,
        dataset_root: Path,
        class_index_map: Dict[int, int],
    ) -> Optional[Dict[str, int]]:
        """Export labels for one split; return per-split statistics or None."""
        labels_dir = self._infer_labels_dir(images_dir)
        labels_dir.mkdir(parents=True, exist_ok=True)

        processed_images = 0
        registered_images = 0
        missing_records = 0
        total_annotations = 0

        for image_file in images_dir.rglob("*"):
            if not image_file.is_file():
                continue
            suffix = image_file.suffix.lower()
            if self.allowed_extensions and suffix not in self.allowed_extensions:
                continue
            processed_images += 1
            label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(
                ".txt"
            )
            label_path.parent.mkdir(parents=True, exist_ok=True)
            found, annotation_entries = self._fetch_annotations_for_image(
                image_file, dataset_root, images_dir, class_index_map
            )
            if found:
                registered_images += 1
            else:
                missing_records += 1

            annotations_written = 0
            with open(label_path, "w", encoding="utf-8") as handle:
                for entry in annotation_entries:
                    polygon = entry.get("polygon") or []
                    if polygon:
                        coords = " ".join(f"{value:.6f}" for value in polygon)
                        handle.write(f"{entry['class_idx']} {coords}\n")
                        annotations_written += 1
                    elif entry.get("bbox"):
                        x_center, y_center, width, height = entry["bbox"]
                        handle.write(
                            f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} "
                            f"{width:.6f} {height:.6f}\n"
                        )
                        annotations_written += 1
            total_annotations += annotations_written

        cache_reset_root = labels_dir.parent
        self._invalidate_split_cache(cache_reset_root)

        if processed_images == 0:
            self._append_training_log(
                f"[{split_name}] No images found to export labels for."
            )
            return None

        return {
            "split": split_name,
            "processed_images": processed_images,
            "registered_images": registered_images,
            "missing_records": missing_records,
            "total_annotations": total_annotations,
        }
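    # _build_class_index_map (below) translates database class ids into the
    # contiguous indices YOLO expects. Illustrative example with hypothetical
    # ids: if the database holds {id: 3, class_name: "cell"} and
    # {id: 7, class_name: "debris"}, and data.yaml lists names [cell, debris],
    # the resulting map is {3: 0, 7: 1}.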
    def _build_class_index_map(self, dataset_info: Dict[str, Any]) -> Dict[int, int]:
        """Map database class ids to YOLO class indices via matching names."""
        class_names = dataset_info.get("class_names") or []
        name_to_index = {name: idx for idx, name in enumerate(class_names)}
        mapping: Dict[int, int] = {}
        try:
            db_classes = self.db_manager.get_object_classes()
        except Exception as exc:
            logger.error(f"Failed to read object classes from database: {exc}")
            return mapping
        for cls in db_classes:
            idx = name_to_index.get(cls.get("class_name"))
            if idx is not None:
                mapping[cls["id"]] = idx
        return mapping

    def _fetch_annotations_for_image(
        self,
        image_path: Path,
        dataset_root: Path,
        images_dir: Path,
        class_index_map: Dict[int, int],
    ) -> Tuple[bool, List[Dict[str, Any]]]:
        """Look up an image in the database and convert its annotations.

        Returns (found, entries), where found indicates a database record exists.
        """
        resolved_image = image_path.resolve()
        candidates: List[str] = []
        for base in (dataset_root, images_dir):
            try:
                relative = resolved_image.relative_to(base.resolve()).as_posix()
                candidates.append(relative)
            except ValueError:
                continue
        candidates.append(resolved_image.name)

        image_row: Optional[Dict[str, Any]] = None
        seen: set[str] = set()
        for candidate in candidates:
            normalized = candidate.replace("\\", "/")
            if normalized in seen:
                continue
            seen.add(normalized)
            image_row = self.db_manager.get_image_by_path(normalized)
            if image_row:
                break
        if not image_row:
            return False, []

        annotations = self.db_manager.get_annotations_for_image(image_row["id"]) or []
        yolo_entries: List[Dict[str, Any]] = []
        for ann in annotations:
            class_idx = class_index_map.get(ann.get("class_id"))
            if class_idx is None:
                continue
            x_min = self._clamp01(float(ann.get("x_min", 0.0)))
            y_min = self._clamp01(float(ann.get("y_min", 0.0)))
            x_max = self._clamp01(float(ann.get("x_max", 0.0)))
            y_max = self._clamp01(float(ann.get("y_max", 0.0)))
            width = max(0.0, x_max - x_min)
            height = max(0.0, y_max - y_min)
            bbox_tuple: Optional[Tuple[float, float, float, float]] = None
            if width > 0 and height > 0:
                x_center = x_min + width / 2.0
                y_center = y_min + height / 2.0
                bbox_tuple = (x_center, y_center, width, height)
            polygon = self._convert_segmentation_mask_to_polygon(
                ann.get("segmentation_mask"), (x_min, y_min, x_max, y_max)
            )
            if not bbox_tuple and not polygon:
                continue
            yolo_entries.append(
                {
                    "class_idx": class_idx,
                    "bbox": bbox_tuple,
                    "polygon": polygon,
                }
            )
        return True, yolo_entries

    def _convert_segmentation_mask_to_polygon(
        self,
        mask_data: Any,
        bbox: Optional[Tuple[float, float, float, float]] = None,
    ) -> List[float]:
        """Convert stored mask points to a flat, normalized YOLO polygon.

        The stored point order may be (y, x) or (x, y); both interpretations
        are scored against the bounding box and the closer match wins.
        """
        if not isinstance(mask_data, list):
            return []
        candidates: List[Tuple[float, List[float]]] = []
        for order in ("yx", "xy"):
            coords: List[float] = []
            xs: List[float] = []
            ys: List[float] = []
            for point in mask_data:
                if not isinstance(point, (list, tuple)) or len(point) != 2:
                    continue
                first = float(point[0])
                second = float(point[1])
                if order == "yx":
                    x_val = self._clamp01(second)
                    y_val = self._clamp01(first)
                else:
                    x_val = self._clamp01(first)
                    y_val = self._clamp01(second)
                coords.extend([x_val, y_val])
                xs.append(x_val)
                ys.append(y_val)
            if len(coords) < 6:
                continue
            score = 0.0
            if bbox:
                x_min, y_min, x_max, y_max = bbox
                score = (
                    abs((min(xs) if xs else 0.0) - x_min)
                    + abs((max(xs) if xs else 0.0) - x_max)
                    + abs((min(ys) if ys else 0.0) - y_min)
                    + abs((max(ys) if ys else 0.0) - y_max)
                )
            candidates.append((score, coords))
        if not candidates:
            return []
        candidates.sort(key=lambda item: item[0])
        return candidates[0][1]

    @staticmethod
    def _clamp01(value: float) -> float:
        """Clamp a float into the inclusive range [0.0, 1.0]."""
        if value < 0.0:
            return 0.0
        if value > 1.0:
            return 1.0
        return value
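    # Worked example for the order check above (values are illustrative): for
    # bbox (x_min=0.2, y_min=0.6, x_max=0.4, y_max=0.8) and stored points
    # [(0.6, 0.2), (0.8, 0.4), (0.6, 0.4)], the "yx" reading spans x in
    # [0.2, 0.4] and y in [0.6, 0.8] (score 0.0), while the "xy" reading spans
    # x in [0.6, 0.8] and y in [0.2, 0.4] (score 1.6), so "yx" wins.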
    def _prepare_dataset_for_training(
        self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
    ) -> Path:
        """Return the data.yaml to train on, building an RGB cache if needed."""
        dataset_info = dataset_info or (
            self.selected_dataset
            if self.selected_dataset
            and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
            else self._parse_dataset_yaml(dataset_yaml)
        )
        train_split = dataset_info.get("splits", {}).get("train") or {}
        images_path_str = train_split.get("path")
        if not images_path_str:
            return dataset_yaml
        images_path = Path(images_path_str)
        if not images_path.exists():
            return dataset_yaml
        if not self._dataset_requires_rgb_conversion(images_path):
            return dataset_yaml

        cache_root = self._get_rgb_cache_root(dataset_yaml)
        rgb_yaml = cache_root / "data.yaml"
        if rgb_yaml.exists():
            self._append_training_log(
                f"Detected grayscale dataset; reusing RGB cache at {cache_root}"
            )
            return rgb_yaml
        self._append_training_log(
            f"Detected grayscale dataset; creating RGB cache at {cache_root}"
        )
        self._build_rgb_dataset(cache_root, dataset_info)
        return rgb_yaml

    def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
        """Return a stable, per-dataset cache directory for RGB conversions."""
        cache_base = Path("data/datasets/_rgb_cache")
        cache_base.mkdir(parents=True, exist_ok=True)
        key = hashlib.md5(str(dataset_yaml.parent.resolve()).encode()).hexdigest()[:8]
        return cache_base / f"{dataset_yaml.parent.name}_{key}"

    def _clear_rgb_cache_for_dataset(self, dataset_yaml: Path):
        cache_root = self._get_rgb_cache_root(dataset_yaml)
        if cache_root.exists():
            try:
                shutil.rmtree(cache_root)
                logger.debug(f"Removed RGB cache at {cache_root}")
            except OSError as exc:
                logger.warning(f"Failed to remove RGB cache {cache_root}: {exc}")

    def _dataset_requires_rgb_conversion(self, images_dir: Path) -> bool:
        """Inspect a sample image; any non-RGB mode triggers conversion."""
        sample_image = self._find_first_image(images_dir)
        if not sample_image:
            return False
        try:
            with PILImage.open(sample_image) as img:
                return img.mode.upper() != "RGB"
        except Exception as exc:
            logger.warning(f"Failed to inspect image {sample_image}: {exc}")
            return False

    def _find_first_image(self, directory: Path) -> Optional[Path]:
        if not directory.exists():
            return None
        for path in directory.rglob("*"):
            if path.is_file() and path.suffix.lower() in self.allowed_extensions:
                return path
        return None

    def _build_rgb_dataset(self, cache_root: Path, dataset_info: Dict[str, Any]):
        """Create an RGB copy of the dataset plus a matching data.yaml."""
        if cache_root.exists():
            shutil.rmtree(cache_root)
        cache_root.mkdir(parents=True, exist_ok=True)

        splits = dataset_info.get("splits", {})
        for split_name in ("train", "val", "test"):
            split_entry = splits.get(split_name)
            if not split_entry:
                continue
            images_src = Path(split_entry.get("path", ""))
            if not images_src.exists():
                continue
            images_dst = cache_root / split_name / "images"
            self._convert_images_to_rgb(images_src, images_dst)
            labels_src = self._infer_labels_dir(images_src)
            if labels_src.exists():
                labels_dst = cache_root / split_name / "labels"
                self._copy_labels(labels_src, labels_dst)

        class_names = dataset_info.get("class_names") or []
        names_map = {idx: name for idx, name in enumerate(class_names)}
        num_classes = dataset_info.get("num_classes") or len(class_names)
        yaml_payload: Dict[str, Any] = {
            "path": cache_root.as_posix(),
            "names": names_map,
            "nc": num_classes,
        }
        for split_name in ("train", "val", "test"):
            images_dir = cache_root / split_name / "images"
            if images_dir.exists():
                yaml_payload[split_name] = f"{split_name}/images"
        with open(cache_root / "data.yaml", "w", encoding="utf-8") as handle:
            yaml.safe_dump(yaml_payload, handle, sort_keys=False)
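    # The RGB cache written above follows this layout (which splits appear
    # depends on the source dataset; names shown are illustrative):
    #
    #   data/datasets/_rgb_cache/<dataset>_<md5-8>/
    #     data.yaml            # path, nc, names, plus train/val/test entries
    #     train/images/...     # RGB-converted copies of the source images
    #     train/labels/...     # copied YOLO txt labels
    #     val/images/...  val/labels/...  (and test/, when present)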
    def _convert_images_to_rgb(self, src_dir: Path, dst_dir: Path):
        """Copy every image under src_dir to dst_dir, converted to RGB."""
        for src in src_dir.rglob("*"):
            if not src.is_file() or src.suffix.lower() not in self.allowed_extensions:
                continue
            relative = src.relative_to(src_dir)
            dst = dst_dir / relative
            dst.parent.mkdir(parents=True, exist_ok=True)
            try:
                with PILImage.open(src) as img:
                    rgb_img = img.convert("RGB")
                    rgb_img.save(dst)
            except Exception as exc:
                logger.warning(f"Failed to convert {src} to RGB: {exc}")

    def _copy_labels(self, labels_src: Path, labels_dst: Path):
        label_files = list(labels_src.rglob("*.txt"))
        for label_file in label_files:
            relative = label_file.relative_to(labels_src)
            dst = labels_dst / relative
            dst.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(label_file, dst)

    def _infer_labels_dir(self, images_dir: Path) -> Path:
        """Assume the YOLO convention of a labels/ directory beside images/."""
        return images_dir.parent / "labels"

    def _invalidate_split_cache(self, split_root: Path):
        """Delete stale Ultralytics cache files so new labels are picked up."""
        for cache_name in ("labels.cache", "images.cache"):
            cache_path = split_root / cache_name
            if cache_path.exists():
                try:
                    cache_path.unlink()
                    logger.debug(f"Removed stale cache file: {cache_path}")
                except OSError as exc:
                    logger.warning(f"Failed to remove cache {cache_path}: {exc}")

    def _collect_training_params(self) -> Dict[str, Any]:
        """Gather training parameters from the form, applying fallbacks."""
        model_name = self.model_name_edit.text().strip() or "custom_model"
        model_version = self.model_version_edit.text().strip() or "v1"
        base_model = self.base_model_edit.text().strip() or self.config_manager.get(
            "models.default_base_model", "yolov8s-seg.pt"
        )
        save_dir = self.save_dir_edit.text().strip() or self.config_manager.get(
            "models.models_directory", "data/models"
        )
        save_dir_path = Path(save_dir).expanduser()
        save_dir_path.mkdir(parents=True, exist_ok=True)
        run_name = f"{model_name}_{model_version}".replace(" ", "_")
        return {
            "model_name": model_name,
            "model_version": model_version,
            "base_model": base_model,
            "save_dir": save_dir_path.as_posix(),
            "run_name": run_name,
            "epochs": self.epochs_spin.value(),
            "batch": self.batch_spin.value(),
            "imgsz": self.imgsz_spin.value(),
            "patience": self.patience_spin.value(),
            "lr0": self.lr_spin.value(),
        }
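    # Example for run_name above (illustrative values): model name "My Model"
    # and version "v2" yield run_name "My_Model_v2". Since save_dir and
    # run_name are passed to YOLOWrapper.train as save_dir/name, the run
    # output is expected under <save_dir>/My_Model_v2, though the exact layout
    # depends on the wrapper.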
    def _start_training(self):
        if self.training_worker and self.training_worker.isRunning():
            return
        # Ensure any previous worker objects are fully cleaned up before starting
        self._cleanup_training_worker()

        dataset_yaml = self.dataset_path_edit.text().strip()
        if not dataset_yaml:
            QMessageBox.warning(
                self,
                "Dataset Required",
                "Please select or generate a data.yaml file first.",
            )
            return
        dataset_path = Path(dataset_yaml).expanduser()
        if not dataset_path.exists():
            QMessageBox.warning(
                self, "Invalid Dataset", "Selected data.yaml file does not exist."
            )
            return

        try:
            dataset_info = (
                self.selected_dataset
                if self.selected_dataset
                and self.selected_dataset.get("yaml_path") == str(dataset_path)
                else self._parse_dataset_yaml(dataset_path)
            )
        except ValueError as exc:
            QMessageBox.warning(self, "Invalid Dataset", str(exc))
            return

        self.training_log.clear()
        self._export_labels_from_database(dataset_info)
        dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
        if dataset_to_use != dataset_path:
            self._append_training_log(
                f"Using RGB-converted dataset at {dataset_to_use.parent}"
            )

        params = self._collect_training_params()
        self._active_training_params = params
        self._training_cancelled = False
        self._append_training_log(
            f"Starting training run '{params['run_name']}' using {params['base_model']}"
        )
        self.training_progress_bar.setVisible(True)
        self.training_progress_bar.setMaximum(params["epochs"])
        self.training_progress_bar.setValue(0)
        self._set_training_state(True)

        self.training_worker = TrainingWorker(
            data_yaml=dataset_to_use.as_posix(),
            base_model=params["base_model"],
            epochs=params["epochs"],
            batch=params["batch"],
            imgsz=params["imgsz"],
            patience=params["patience"],
            lr0=params["lr0"],
            save_dir=params["save_dir"],
            run_name=params["run_name"],
        )
        self.training_worker.progress.connect(self._on_training_progress)
        self.training_worker.training_finished.connect(self._on_training_finished)
        self.training_worker.error.connect(self._on_training_error)
        self.training_worker.start()

    def _stop_training(self):
        if self.training_worker and self.training_worker.isRunning():
            self._training_cancelled = True
            self._append_training_log(
                "Stop requested. Waiting for the current epoch to finish..."
            )
            self.training_worker.stop()
            self.stop_training_button.setEnabled(False)

    def _disconnect_training_worker_signals(self, worker: TrainingWorker):
        for signal, slot in (
            (worker.progress, self._on_training_progress),
            (worker.training_finished, self._on_training_finished),
            (worker.error, self._on_training_error),
        ):
            try:
                signal.disconnect(slot)
            except (TypeError, RuntimeError):
                pass

    def _cleanup_training_worker(
        self,
        worker: Optional[TrainingWorker] = None,
        *,
        request_stop: bool = False,
        wait_timeout_ms: int = 30000,
    ):
        """Disconnect, optionally stop, wait on, and schedule deletion of a worker."""
        worker = worker or self.training_worker
        if not worker:
            return
        if worker is self.training_worker:
            self.training_worker = None
        self._disconnect_training_worker_signals(worker)
        if request_stop and worker.isRunning():
            worker.stop()
        if worker.isRunning():
            if not worker.wait(wait_timeout_ms):
                logger.warning(
                    "Training worker did not finish within %sms", wait_timeout_ms
                )
        worker.deleteLater()

    def shutdown(self):
        """Stop any running training worker before the tab is destroyed."""
        if self.training_worker and self.training_worker.isRunning():
            logger.info("Shutting down training worker before exit")
            self._append_training_log("Stopping training before application exit…")
            self._cleanup_training_worker(request_stop=True, wait_timeout_ms=120000)
            self._training_cancelled = True
        else:
            self._cleanup_training_worker()
        self._set_training_state(False)
        self.training_progress_bar.setVisible(False)
    def _on_training_progress(
        self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]
    ):
        self.training_progress_bar.setMaximum(total_epochs)
        self.training_progress_bar.setValue(current_epoch)
        parts = [f"Epoch {current_epoch}/{total_epochs}"]
        if metrics:
            metric_text = ", ".join(
                f"{key}: {value:.4f}" for key, value in metrics.items()
            )
            parts.append(metric_text)
        self._append_training_log(" | ".join(parts))

    def _on_training_finished(self, results: Dict[str, Any]):
        self._cleanup_training_worker()
        self._set_training_state(False)
        self.training_progress_bar.setVisible(False)
        if self._training_cancelled:
            self._append_training_log("Training cancelled by user.")
            QMessageBox.information(self, "Training Cancelled", "Training was stopped.")
            self._training_cancelled = False
            self._active_training_params = None
            return
        self._append_training_log("Training completed successfully.")
        try:
            self._register_trained_model(results)
        except Exception as exc:
            logger.error(f"Failed to register trained model: {exc}")
            QMessageBox.warning(
                self,
                "Model Registration Failed",
                f"Model trained but not registered: {exc}",
            )
        else:
            QMessageBox.information(
                self, "Training Complete", "Training finished successfully."
            )

    def _on_training_error(self, message: str):
        self._cleanup_training_worker()
        self._set_training_state(False)
        self.training_progress_bar.setVisible(False)
        self._training_cancelled = False
        self._active_training_params = None
        self._append_training_log(f"ERROR: {message}")
        QMessageBox.critical(self, "Training Failed", message)

    def _register_trained_model(self, results: Dict[str, Any]):
        """Persist the trained model and its parameters in the database."""
        if not self._active_training_params:
            return
        params = self._active_training_params
        model_path = results.get("best_model_path") or results.get("last_model_path")
        if not model_path:
            raise ValueError("Training results did not include a model path.")
        training_params = {
            "epochs": params["epochs"],
            "batch": params["batch"],
            "imgsz": params["imgsz"],
            "patience": params["patience"],
            "lr0": params["lr0"],
            "run_name": params["run_name"],
        }
        model_id = self.db_manager.add_model(
            model_name=params["model_name"],
            model_version=params["model_version"],
            model_path=model_path,
            base_model=params["base_model"],
            training_params=training_params,
            metrics=results.get("metrics"),
        )
        self._append_training_log(
            f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}"
        )
        self._active_training_params = None
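    # The results dict consumed above comes from YOLOWrapper.train; the keys
    # this module reads are "best_model_path", "last_model_path", and
    # "metrics". Illustrative shape (paths and values are placeholders):
    #
    #   {
    #       "best_model_path": "data/models/demo_v1/weights/best.pt",
    #       "last_model_path": "data/models/demo_v1/weights/last.pt",
    #       "metrics": {"mAP50": 0.91},
    #   }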
    def _set_training_state(self, is_training: bool):
        self.start_training_button.setEnabled(not is_training)
        self.stop_training_button.setEnabled(is_training)
        self.generate_yaml_button.setEnabled(not is_training)
        self.dataset_combo.setEnabled(not is_training)
        self.browse_button.setEnabled(not is_training)
        self.rescan_button.setEnabled(not is_training)
        self.model_name_edit.setEnabled(not is_training)
        self.model_version_edit.setEnabled(not is_training)
        self.base_model_edit.setEnabled(not is_training)
        self.base_model_browse_button.setEnabled(not is_training)
        self.save_dir_edit.setEnabled(not is_training)
        self.save_dir_browse_button.setEnabled(not is_training)
        self.epochs_spin.setEnabled(not is_training)
        self.batch_spin.setEnabled(not is_training)
        self.imgsz_spin.setEnabled(not is_training)
        self.patience_spin.setEnabled(not is_training)
        self.lr_spin.setEnabled(not is_training)

    def _append_training_log(self, message: str):
        timestamp = datetime.now().strftime("%H:%M:%S")
        self.training_log.append(f"[{timestamp}] {message}")

    def _browse_base_model(self):
        start_path = self.base_model_edit.text().strip() or "."
        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "Select Base Model Weights",
            start_path,
            "PyTorch weights (*.pt *.pth)",
        )
        if file_path:
            self.base_model_edit.setText(file_path)

    def _browse_save_dir(self):
        start_path = self.save_dir_edit.text().strip() or "data/models"
        directory = QFileDialog.getExistingDirectory(
            self, "Select Save Directory", start_path
        )
        if directory:
            self.save_dir_edit.setText(directory)

    def _display_dataset_success(self, message: str):
        self.dataset_status_label.setStyleSheet(self._status_styles["success"])
        self.dataset_status_label.setText(message)

    def _display_dataset_warning(self, message: str):
        self.dataset_status_label.setStyleSheet(self._status_styles["warning"])
        self.dataset_status_label.setText(message)

    def _display_dataset_error(self, message: str):
        self.dataset_status_label.setStyleSheet(self._status_styles["error"])
        self.dataset_status_label.setText(message)

    def refresh(self):
        """Refresh the tab."""
        if self.training_worker and self.training_worker.isRunning():
            self._append_training_log("Refresh skipped while training is running.")
            return
        self._discover_datasets()
        current_path = self.dataset_path_edit.text().strip()
        if current_path:
            self._set_dataset_path(current_path, persist=False)
        else:
            self._load_saved_dataset()