"""
Training tab for the microscopy object detection application.
Handles model training with YOLO.
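
Typical embedding (illustrative sketch; ``db_manager`` and ``config_manager``
are the application's existing DatabaseManager and ConfigManager instances):

    tab = TrainingTab(db_manager, config_manager)
    tab_widget.addTab(tab, "Training")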
"""
import hashlib
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import yaml
from PIL import Image as PILImage
from PySide6.QtCore import QObject, Qt, QThread, Signal
from PySide6.QtWidgets import (
QWidget,
QVBoxLayout,
QLabel,
QGroupBox,
QHBoxLayout,
QPushButton,
QLineEdit,
QFileDialog,
QComboBox,
QFormLayout,
QMessageBox,
QTextEdit,
QProgressBar,
QSpinBox,
QDoubleSpinBox,
)
from src.database.db_manager import DatabaseManager
from src.model.yolo_wrapper import YOLOWrapper
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
logger = get_logger(__name__)
DEFAULT_IMAGE_EXTENSIONS = {
".jpg",
".jpeg",
".png",
".tif",
".tiff",
".bmp",
}
class TrainingWorker(QThread):
"""Background worker that runs YOLO training without blocking the UI."""
progress = Signal(int, int, dict) # current_epoch, total_epochs, metrics
finished = Signal(dict)
error = Signal(str)
def __init__(
self,
data_yaml: str,
base_model: str,
epochs: int,
batch: int,
imgsz: int,
patience: int,
lr0: float,
save_dir: str,
run_name: str,
        parent: Optional[QObject] = None,
):
super().__init__(parent)
self.data_yaml = data_yaml
self.base_model = base_model
self.epochs = epochs
self.batch = batch
self.imgsz = imgsz
self.patience = patience
self.lr0 = lr0
self.save_dir = save_dir
self.run_name = run_name
self._stop_requested = False
def stop(self):
"""Request the training process to stop at the next epoch boundary."""
self._stop_requested = True
self.requestInterruption()
def run(self):
"""Execute YOLO training and emit progress/finished signals."""
wrapper = YOLOWrapper(self.base_model)
def on_epoch_end(trainer):
current_epoch = getattr(trainer, "epoch", 0) + 1
metrics: Dict[str, float] = {}
            loss_items = getattr(trainer, "loss_items", None)
            # loss_items is typically a tensor; truthiness of a multi-element
            # tensor raises, so test against None and length explicitly.
            if loss_items is not None and len(loss_items) > 0:
                metrics["loss"] = float(loss_items[-1])
self.progress.emit(current_epoch, self.epochs, metrics)
            if self.isInterruptionRequested() or self._stop_requested:
                # Cooperative stop: the trainer halts at the next epoch boundary.
                trainer.stop_training = True
callbacks = {"on_fit_epoch_end": on_epoch_end}
try:
results = wrapper.train(
data_yaml=self.data_yaml,
epochs=self.epochs,
imgsz=self.imgsz,
batch=self.batch,
patience=self.patience,
save_dir=self.save_dir,
name=self.run_name,
lr0=self.lr0,
callbacks=callbacks,
)
self.finished.emit(results)
except Exception as exc:
self.error.emit(str(exc))
class TrainingTab(QWidget):
"""Training tab for model training."""
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
super().__init__(parent)
self.db_manager = db_manager
self.config_manager = config_manager
self.selected_dataset: Optional[Dict[str, Any]] = None
self.allowed_extensions = {
ext.lower() for ext in self.config_manager.get_allowed_extensions()
} or DEFAULT_IMAGE_EXTENSIONS
self._status_styles = {
"default": "",
"success": "color: #2e7d32; font-weight: 600;",
"warning": "color: #b26a00; font-weight: 600;",
"error": "color: #b00020; font-weight: 600;",
}
self.training_worker: Optional[TrainingWorker] = None
self._active_training_params: Optional[Dict[str, Any]] = None
self._training_cancelled = False
self._setup_ui()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
layout.addWidget(self._create_dataset_group())
layout.addWidget(self._create_training_controls_group())
layout.addStretch()
self.setLayout(layout)
self._discover_datasets()
self._load_saved_dataset()
def _create_dataset_group(self) -> QGroupBox:
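        """Build the dataset selection group: discovery combo box, manual browse, data.yaml generation, and a summary panel."""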
group = QGroupBox("Dataset Selection")
group_layout = QVBoxLayout()
        description = QLabel(
            "Select a YOLO-format data.yaml file to use for model training. "
            "You can pick from discovered datasets or browse manually."
        )
description.setWordWrap(True)
group_layout.addWidget(description)
combo_layout = QHBoxLayout()
combo_label = QLabel("Discovered datasets:")
combo_layout.addWidget(combo_label)
self.dataset_combo = QComboBox()
self.dataset_combo.currentIndexChanged.connect(self._on_dataset_combo_changed)
combo_layout.addWidget(self.dataset_combo, 1)
self.rescan_button = QPushButton("Rescan")
self.rescan_button.clicked.connect(self._discover_datasets)
combo_layout.addWidget(self.rescan_button)
group_layout.addLayout(combo_layout)
path_layout = QHBoxLayout()
self.dataset_path_edit = QLineEdit()
self.dataset_path_edit.setPlaceholderText("Select a data.yaml file to continue")
self.dataset_path_edit.setReadOnly(True)
path_layout.addWidget(self.dataset_path_edit)
self.browse_button = QPushButton("Browse…")
self.browse_button.clicked.connect(self._browse_dataset)
path_layout.addWidget(self.browse_button)
group_layout.addLayout(path_layout)
action_layout = QHBoxLayout()
action_layout.addStretch()
self.generate_yaml_button = QPushButton("Generate data.yaml")
self.generate_yaml_button.clicked.connect(self._generate_data_yaml)
action_layout.addWidget(self.generate_yaml_button)
group_layout.addLayout(action_layout)
self.dataset_status_label = QLabel("No dataset selected.")
self.dataset_status_label.setWordWrap(True)
group_layout.addWidget(self.dataset_status_label)
summary_group = QGroupBox("Dataset Summary")
summary_layout = QFormLayout()
self.dataset_root_label = QLabel("")
self.dataset_root_label.setWordWrap(True)
self.dataset_root_label.setTextInteractionFlags(Qt.TextSelectableByMouse)
summary_layout.addRow("Root:", self.dataset_root_label)
self.train_count_label = QLabel("")
summary_layout.addRow("Train Images:", self.train_count_label)
self.val_count_label = QLabel("")
summary_layout.addRow("Val Images:", self.val_count_label)
self.test_count_label = QLabel("")
summary_layout.addRow("Test Images:", self.test_count_label)
self.num_classes_label = QLabel("")
summary_layout.addRow("Classes:", self.num_classes_label)
self.class_names_label = QLabel("")
self.class_names_label.setWordWrap(True)
summary_layout.addRow("Class Names:", self.class_names_label)
summary_group.setLayout(summary_layout)
group_layout.addWidget(summary_group)
group.setLayout(group_layout)
return group
def _create_training_controls_group(self) -> QGroupBox:
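        """Build the training controls group: run metadata, hyperparameter form, start/stop buttons, progress bar, and log."""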
group = QGroupBox("Training Controls")
group_layout = QVBoxLayout()
form_layout = QFormLayout()
self.model_name_edit = QLineEdit("custom_model")
form_layout.addRow("Model Name:", self.model_name_edit)
self.model_version_edit = QLineEdit("v1")
form_layout.addRow("Version:", self.model_version_edit)
default_base_model = self.config_manager.get(
"models.default_base_model", "yolov8s-seg.pt"
)
base_model_layout = QHBoxLayout()
self.base_model_edit = QLineEdit(default_base_model)
base_model_layout.addWidget(self.base_model_edit)
self.base_model_browse_button = QPushButton("Browse…")
self.base_model_browse_button.clicked.connect(self._browse_base_model)
base_model_layout.addWidget(self.base_model_browse_button)
form_layout.addRow("Base Model (.pt):", base_model_layout)
models_dir = self.config_manager.get("models.models_directory", "data/models")
save_dir_layout = QHBoxLayout()
self.save_dir_edit = QLineEdit(models_dir)
save_dir_layout.addWidget(self.save_dir_edit)
self.save_dir_browse_button = QPushButton("Browse…")
self.save_dir_browse_button.clicked.connect(self._browse_save_dir)
save_dir_layout.addWidget(self.save_dir_browse_button)
form_layout.addRow("Save Directory:", save_dir_layout)
training_defaults = self.config_manager.get_section("training")
self.epochs_spin = QSpinBox()
self.epochs_spin.setRange(1, 1000)
self.epochs_spin.setValue(int(training_defaults.get("default_epochs", 10)))
form_layout.addRow("Epochs:", self.epochs_spin)
self.batch_spin = QSpinBox()
self.batch_spin.setRange(1, 256)
self.batch_spin.setValue(int(training_defaults.get("default_batch_size", 16)))
form_layout.addRow("Batch Size:", self.batch_spin)
self.imgsz_spin = QSpinBox()
self.imgsz_spin.setRange(320, 2048)
self.imgsz_spin.setSingleStep(32)
self.imgsz_spin.setValue(int(training_defaults.get("default_imgsz", 640)))
form_layout.addRow("Image Size:", self.imgsz_spin)
self.patience_spin = QSpinBox()
self.patience_spin.setRange(1, 200)
self.patience_spin.setValue(int(training_defaults.get("default_patience", 50)))
form_layout.addRow("Patience:", self.patience_spin)
self.lr_spin = QDoubleSpinBox()
self.lr_spin.setDecimals(5)
self.lr_spin.setRange(0.00001, 0.1)
self.lr_spin.setSingleStep(0.0005)
self.lr_spin.setValue(float(training_defaults.get("default_lr0", 0.01)))
form_layout.addRow("Learning Rate (lr0):", self.lr_spin)
group_layout.addLayout(form_layout)
button_layout = QHBoxLayout()
self.start_training_button = QPushButton("Start Training")
self.start_training_button.clicked.connect(self._start_training)
button_layout.addWidget(self.start_training_button)
self.stop_training_button = QPushButton("Stop")
self.stop_training_button.setEnabled(False)
self.stop_training_button.clicked.connect(self._stop_training)
button_layout.addWidget(self.stop_training_button)
button_layout.addStretch()
group_layout.addLayout(button_layout)
self.training_progress_bar = QProgressBar()
self.training_progress_bar.setVisible(False)
group_layout.addWidget(self.training_progress_bar)
self.training_log = QTextEdit()
self.training_log.setReadOnly(True)
self.training_log.setMaximumHeight(200)
group_layout.addWidget(self.training_log)
group.setLayout(group_layout)
return group
def _get_dataset_search_roots(self) -> List[Path]:
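        """Return existing directories to scan for datasets: data/datasets plus the configured image repository."""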
roots: List[Path] = []
default_root = Path("data/datasets").expanduser()
if default_root.exists():
roots.append(default_root.resolve())
repo_root = self.config_manager.get_image_repository_path()
if repo_root:
repo_path = Path(repo_root).expanduser()
if repo_path.exists():
resolved = repo_path.resolve()
if resolved not in roots:
roots.append(resolved)
return roots
def _discover_datasets(self):
"""Populate combo box with discovered data.yaml files."""
discovered: List[Path] = []
for root in self._get_dataset_search_roots():
try:
for yaml_path in root.rglob("*.yaml"):
if yaml_path.name.lower() not in {"data.yaml", "dataset.yaml"}:
continue
discovered.append(yaml_path.resolve())
except Exception as exc:
logger.warning(f"Unable to scan {root}: {exc}")
unique_paths: List[Path] = []
seen: set[str] = set()
for path in sorted(discovered):
normalized = str(path)
if normalized not in seen:
seen.add(normalized)
unique_paths.append(path)
current_selection = self.dataset_path_edit.text().strip()
self.dataset_combo.blockSignals(True)
self.dataset_combo.clear()
self.dataset_combo.addItem("Select discovered dataset…", None)
for path in unique_paths:
display_name = f"{path.parent.name} ({path})"
self.dataset_combo.addItem(display_name, str(path))
self.dataset_combo.setEnabled(bool(unique_paths))
self.dataset_combo.blockSignals(False)
if current_selection:
self._select_combo_entry_for_path(current_selection)
def _select_combo_entry_for_path(self, target_path: str):
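        """Select the combo entry matching target_path, falling back to the placeholder entry when no entry matches."""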
try:
normalized_target = str(Path(target_path).expanduser().resolve())
except Exception:
normalized_target = str(target_path)
for idx in range(1, self.dataset_combo.count()):
data = self.dataset_combo.itemData(idx)
if not data:
continue
try:
normalized_item = str(Path(data).expanduser().resolve())
except Exception:
normalized_item = str(data)
if normalized_item == normalized_target:
self.dataset_combo.setCurrentIndex(idx)
return
self.dataset_combo.setCurrentIndex(0)
def _on_dataset_combo_changed(self, index: int):
dataset_path = self.dataset_combo.itemData(index)
if dataset_path:
self._set_dataset_path(dataset_path)
def _browse_dataset(self):
"""Open a file dialog to manually select data.yaml."""
start_dir = self.config_manager.get(
"training.last_dataset_dir", "data/datasets"
)
start_path = Path(start_dir).expanduser()
if not start_path.exists():
start_path = Path.cwd()
file_path, _ = QFileDialog.getOpenFileName(
self,
"Select YOLO data.yaml",
str(start_path),
"YAML Files (*.yaml *.yml)",
)
if file_path:
self._set_dataset_path(file_path)
def _generate_data_yaml(self):
"""Compose a data.yaml file using database metadata and dataset folders."""
dataset_root = self._determine_dataset_root()
if dataset_root is None:
QMessageBox.warning(
self,
"Dataset Root Missing",
"Unable to locate a dataset root. Please create data/datasets or browse to an existing dataset first.",
)
return
self.generate_yaml_button.setEnabled(False)
try:
output_path = self.db_manager.compose_data_yaml(str(dataset_root))
except ValueError as exc:
logger.error(f"data.yaml generation failed: {exc}")
self._display_dataset_error(str(exc))
QMessageBox.critical(self, "data.yaml Generation Failed", str(exc))
return
except Exception as exc:
logger.exception("Unexpected error while generating data.yaml")
self._display_dataset_error(
"Unexpected error while generating data.yaml. Check logs for details."
)
QMessageBox.critical(
self,
"data.yaml Generation Failed",
"An unexpected error occurred. Please inspect the logs for details.",
)
return
finally:
self.generate_yaml_button.setEnabled(True)
QMessageBox.information(
self,
"data.yaml Generated",
f"data.yaml saved to:\n{output_path}",
)
self._set_dataset_path(output_path)
self._display_dataset_success(f"Generated data.yaml at {output_path}")
def _determine_dataset_root(self) -> Optional[Path]:
"""Infer the dataset root directory for generating data.yaml."""
path_text = self.dataset_path_edit.text().strip()
if path_text:
candidate = Path(path_text).expanduser().parent
if candidate.exists():
return candidate
last_dir = self.config_manager.get("training.last_dataset_dir", "")
if last_dir:
candidate = Path(last_dir).expanduser()
if candidate.exists():
return candidate
default_root = Path("data/datasets").expanduser()
if default_root.exists():
return default_root
return None
def _set_dataset_path(self, path: str, persist: bool = True):
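        """Validate the chosen YAML file, refresh the summary panel, and optionally persist the selection to config."""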
path_obj = Path(path).expanduser()
if not path_obj.exists():
self._display_dataset_error(f"File not found: {path_obj}")
return
if path_obj.suffix.lower() not in {".yaml", ".yml"}:
self._display_dataset_error("Please select a YAML file.")
return
self.dataset_path_edit.setText(str(path_obj))
try:
self._update_dataset_info(path_obj)
except ValueError as exc:
logger.error(f"Dataset parsing failed: {exc}")
self._display_dataset_error(str(exc))
QMessageBox.warning(self, "Invalid Dataset", f"{exc}")
return
if persist:
self.config_manager.set("training.last_dataset_yaml", str(path_obj))
self.config_manager.set("training.last_dataset_dir", str(path_obj.parent))
self.config_manager.save_config()
self._select_combo_entry_for_path(str(path_obj))
def _load_saved_dataset(self):
saved_path = self.config_manager.get("training.last_dataset_yaml", "")
if saved_path and Path(saved_path).expanduser().exists():
self._set_dataset_path(saved_path, persist=False)
elif self.dataset_combo.isEnabled() and self.dataset_combo.count() > 1:
self.dataset_combo.setCurrentIndex(1)
def _update_dataset_info(self, yaml_path: Path):
info = self._parse_dataset_yaml(yaml_path)
self.selected_dataset = info
self.dataset_root_label.setText(info["root"]) # type: ignore[arg-type]
self.train_count_label.setText(
self._format_split_info(info["splits"].get("train"))
)
self.val_count_label.setText(self._format_split_info(info["splits"].get("val")))
self.test_count_label.setText(
self._format_split_info(info["splits"].get("test"))
)
self.num_classes_label.setText(str(info["num_classes"]))
        class_names = ", ".join(info["class_names"])
self.class_names_label.setText(class_names)
warnings = info.get("warnings", [])
if warnings:
warning_text = "Dataset loaded with warnings:\n- " + "\n- ".join(warnings)
self._display_dataset_warning(warning_text)
else:
self._display_dataset_success("Dataset ready for training.")
def _format_split_info(self, split: Optional[Dict[str, Any]]) -> str:
if not split:
return "n/a"
path = split.get("path") or "n/a"
count = split.get("count", 0)
return f"{count} @ {path}"
def _parse_dataset_yaml(self, yaml_path: Path) -> Dict[str, Any]:
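        """Parse a YOLO data.yaml file into a summary dict.

        The returned shape is (illustrative values):

            {
                "yaml_path": ".../data.yaml",
                "root": ".../dataset",
                "name": "dataset",
                "splits": {"train": {"path": "...", "count": 120}, ...},
                "class_names": ["cell", "debris"],
                "num_classes": 2,
                "warnings": [],
            }

        Raises ValueError when the file is unreadable or contains invalid YAML.
        """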
try:
with open(yaml_path, "r", encoding="utf-8") as handle:
data = yaml.safe_load(handle) or {}
except yaml.YAMLError as exc:
raise ValueError(f"Invalid YAML syntax in {yaml_path}: {exc}") from exc
except OSError as exc:
raise ValueError(f"Unable to read {yaml_path}: {exc}") from exc
base_entry = data.get("path")
if base_entry:
base_path = Path(base_entry)
if not base_path.is_absolute():
base_path = (yaml_path.parent / base_path).resolve()
else:
base_path = base_path.resolve()
else:
base_path = yaml_path.parent.resolve()
warnings: List[str] = []
splits: Dict[str, Dict[str, Any]] = {}
for split_name in ("train", "val", "test"):
split_value = data.get(split_name)
split_info = {"path": "", "count": 0}
if split_value:
split_path = Path(split_value)
if not split_path.is_absolute():
split_path = (base_path / split_value).resolve()
else:
split_path = split_path.resolve()
split_info["path"] = str(split_path)
if split_path.exists():
split_info["count"] = self._count_images(split_path)
if split_info["count"] == 0:
warnings.append(
f"No images found for {split_name} split at {split_path}"
)
else:
warnings.append(
f"{split_name.capitalize()} path does not exist: {split_path}"
)
else:
if split_name in ("train", "val"):
warnings.append(
f"{split_name.capitalize()} split missing in data.yaml"
)
splits[split_name] = split_info
names_list = self._normalize_class_names(data.get("names"))
nc_value = data.get("nc")
if nc_value is not None:
try:
nc_value = int(nc_value)
except (ValueError, TypeError):
warnings.append("Invalid 'nc' value detected; using class name count.")
nc_value = len(names_list)
else:
nc_value = len(names_list)
if not names_list and nc_value:
names_list = [f"class_{idx}" for idx in range(int(nc_value))]
elif nc_value and len(names_list) not in (0, int(nc_value)):
warnings.append(
f"Number of class names ({len(names_list)}) does not match nc={nc_value}"
)
dataset_name = data.get("name") or base_path.name
return {
"yaml_path": str(yaml_path),
"root": str(base_path),
"name": dataset_name,
"splits": splits,
"class_names": names_list,
"num_classes": int(nc_value) if nc_value else len(names_list),
"warnings": warnings,
}
def _normalize_class_names(self, names_field: Any) -> List[str]:
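        """Normalize the YAML 'names' field, which YOLO allows as either a list or an index-to-name mapping."""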
if isinstance(names_field, dict):
def sort_key(item):
key, _ = item
try:
return int(key)
except (ValueError, TypeError):
return key
return [str(name) for _, name in sorted(names_field.items(), key=sort_key)]
if isinstance(names_field, (list, tuple)):
return [str(name) for name in names_field]
return []
def _count_images(self, directory: Path) -> int:
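        """Recursively count files under directory whose suffix is an allowed image extension (all files when no filter is configured)."""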
count = 0
try:
for file_path in directory.rglob("*"):
if file_path.is_file():
suffix = file_path.suffix.lower()
if not self.allowed_extensions or suffix in self.allowed_extensions:
count += 1
except Exception as exc:
logger.warning(f"Unable to count images in {directory}: {exc}")
return count
def _export_labels_from_database(self, dataset_info: Dict[str, Any]) -> None:
"""Write YOLO txt labels for dataset splits using database annotations."""
splits = dataset_info.get("splits", {})
if not splits:
return
class_index_map = self._build_class_index_map(dataset_info)
if not class_index_map:
self._append_training_log(
"Skipping label export: dataset classes do not match database entries."
)
return
dataset_root_str = dataset_info.get("root")
dataset_yaml_path = dataset_info.get("yaml_path")
dataset_yaml = (
Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
)
dataset_root: Optional[Path]
if dataset_root_str:
dataset_root = Path(dataset_root_str).resolve()
else:
dataset_root = dataset_yaml.parent.resolve() if dataset_yaml else None
split_messages: List[str] = []
for split_name in ("train", "val", "test"):
split_entry = splits.get(split_name) or {}
images_dir_str = split_entry.get("path")
if not images_dir_str:
continue
images_dir = Path(images_dir_str)
if not images_dir.exists():
continue
stats = self._export_labels_for_split(
split_name=split_name,
images_dir=images_dir,
dataset_root=dataset_root or images_dir.resolve(),
class_index_map=class_index_map,
)
if not stats:
continue
message = (
f"[{split_name}] Exported {stats['total_annotations']} annotations "
f"across {stats['processed_images']} image(s)."
)
if stats["registered_images"]:
message += f" {stats['registered_images']} image(s) had database-backed annotations."
if stats["missing_records"]:
message += f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
split_messages.append(message)
for msg in split_messages:
self._append_training_log(msg)
if dataset_yaml:
self._clear_rgb_cache_for_dataset(dataset_yaml)
def _export_labels_for_split(
self,
split_name: str,
images_dir: Path,
dataset_root: Path,
class_index_map: Dict[int, int],
) -> Optional[Dict[str, int]]:
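        """Write one YOLO .txt label file per image in the split and return export statistics, or None if the split has no images."""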
labels_dir = self._infer_labels_dir(images_dir)
labels_dir.mkdir(parents=True, exist_ok=True)
processed_images = 0
registered_images = 0
missing_records = 0
total_annotations = 0
for image_file in images_dir.rglob("*"):
if not image_file.is_file():
continue
suffix = image_file.suffix.lower()
if self.allowed_extensions and suffix not in self.allowed_extensions:
continue
processed_images += 1
label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(
".txt"
)
label_path.parent.mkdir(parents=True, exist_ok=True)
found, annotation_entries = self._fetch_annotations_for_image(
image_file, dataset_root, images_dir, class_index_map
)
if found:
registered_images += 1
else:
missing_records += 1
annotations_written = 0
with open(label_path, "w", encoding="utf-8") as handle:
for entry in annotation_entries:
polygon = entry.get("polygon") or []
if polygon:
coords = " ".join(f"{value:.6f}" for value in polygon)
handle.write(f"{entry['class_idx']} {coords}\n")
annotations_written += 1
elif entry.get("bbox"):
x_center, y_center, width, height = entry["bbox"]
handle.write(
f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
)
annotations_written += 1
total_annotations += annotations_written
cache_reset_root = labels_dir.parent
self._invalidate_split_cache(cache_reset_root)
if processed_images == 0:
self._append_training_log(
f"[{split_name}] No images found to export labels for."
)
return None
return {
"split": split_name,
"processed_images": processed_images,
"registered_images": registered_images,
"missing_records": missing_records,
"total_annotations": total_annotations,
}
def _build_class_index_map(self, dataset_info: Dict[str, Any]) -> Dict[int, int]:
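        """Map database class IDs to YOLO class indices by matching class names against the dataset's class list."""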
class_names = dataset_info.get("class_names") or []
name_to_index = {name: idx for idx, name in enumerate(class_names)}
mapping: Dict[int, int] = {}
try:
db_classes = self.db_manager.get_object_classes()
except Exception as exc:
logger.error(f"Failed to read object classes from database: {exc}")
return mapping
for cls in db_classes:
idx = name_to_index.get(cls.get("class_name"))
if idx is not None:
mapping[cls["id"]] = idx
return mapping
def _fetch_annotations_for_image(
self,
image_path: Path,
dataset_root: Path,
images_dir: Path,
class_index_map: Dict[int, int],
) -> Tuple[bool, List[Dict[str, Any]]]:
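        """Look up an image in the database under several candidate relative paths and convert its annotations to YOLO entries.

        Returns (found, entries), where found indicates whether a database
        record existed for the image.
        """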
resolved_image = image_path.resolve()
candidates: List[str] = []
for base in (dataset_root, images_dir):
try:
relative = resolved_image.relative_to(base.resolve()).as_posix()
candidates.append(relative)
except ValueError:
continue
candidates.append(resolved_image.name)
image_row: Optional[Dict[str, Any]] = None
seen: set[str] = set()
for candidate in candidates:
normalized = candidate.replace("\\", "/")
if normalized in seen:
continue
seen.add(normalized)
image_row = self.db_manager.get_image_by_path(normalized)
if image_row:
break
if not image_row:
return False, []
annotations = self.db_manager.get_annotations_for_image(image_row["id"]) or []
yolo_entries: List[Dict[str, Any]] = []
for ann in annotations:
class_idx = class_index_map.get(ann.get("class_id"))
if class_idx is None:
continue
x_min = self._clamp01(float(ann.get("x_min", 0.0)))
y_min = self._clamp01(float(ann.get("y_min", 0.0)))
x_max = self._clamp01(float(ann.get("x_max", 0.0)))
y_max = self._clamp01(float(ann.get("y_max", 0.0)))
width = max(0.0, x_max - x_min)
height = max(0.0, y_max - y_min)
bbox_tuple: Optional[Tuple[float, float, float, float]] = None
if width > 0 and height > 0:
x_center = x_min + width / 2.0
y_center = y_min + height / 2.0
bbox_tuple = (x_center, y_center, width, height)
polygon = self._convert_segmentation_mask_to_polygon(
ann.get("segmentation_mask"), (x_min, y_min, x_max, y_max)
)
if not bbox_tuple and not polygon:
continue
yolo_entries.append(
{
"class_idx": class_idx,
"bbox": bbox_tuple,
"polygon": polygon,
}
)
return True, yolo_entries
def _convert_segmentation_mask_to_polygon(
self,
mask_data: Any,
bbox: Optional[Tuple[float, float, float, float]] = None,
) -> List[float]:
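        """Convert a stored segmentation mask to a flat [x1, y1, x2, y2, ...] polygon.

        The stored point order is ambiguous ((y, x) vs (x, y)), so both
        interpretations are built and the one whose extent best matches the
        annotation bbox is returned.
        """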
if not isinstance(mask_data, list):
return []
candidates: List[Tuple[float, List[float]]] = []
for order in ("yx", "xy"):
coords: List[float] = []
xs: List[float] = []
ys: List[float] = []
for point in mask_data:
if not isinstance(point, (list, tuple)) or len(point) != 2:
continue
first = float(point[0])
second = float(point[1])
if order == "yx":
x_val = self._clamp01(second)
y_val = self._clamp01(first)
else:
x_val = self._clamp01(first)
y_val = self._clamp01(second)
coords.extend([x_val, y_val])
xs.append(x_val)
ys.append(y_val)
if len(coords) < 6:
continue
score = 0.0
if bbox:
x_min, y_min, x_max, y_max = bbox
score = (
abs((min(xs) if xs else 0.0) - x_min)
+ abs((max(xs) if xs else 0.0) - x_max)
+ abs((min(ys) if ys else 0.0) - y_min)
+ abs((max(ys) if ys else 0.0) - y_max)
)
candidates.append((score, coords))
if not candidates:
return []
candidates.sort(key=lambda item: item[0])
return candidates[0][1]
@staticmethod
def _clamp01(value: float) -> float:
if value < 0.0:
return 0.0
if value > 1.0:
return 1.0
return value
def _prepare_dataset_for_training(
self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
) -> Path:
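        """Return the data.yaml to train on, substituting an RGB-converted cached copy when the dataset images are not RGB."""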
dataset_info = dataset_info or (
self.selected_dataset
if self.selected_dataset
and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
else self._parse_dataset_yaml(dataset_yaml)
)
train_split = dataset_info.get("splits", {}).get("train") or {}
images_path_str = train_split.get("path")
if not images_path_str:
return dataset_yaml
images_path = Path(images_path_str)
if not images_path.exists():
return dataset_yaml
if not self._dataset_requires_rgb_conversion(images_path):
return dataset_yaml
cache_root = self._get_rgb_cache_root(dataset_yaml)
rgb_yaml = cache_root / "data.yaml"
if rgb_yaml.exists():
self._append_training_log(
f"Detected grayscale dataset; reusing RGB cache at {cache_root}"
)
return rgb_yaml
self._append_training_log(
f"Detected grayscale dataset; creating RGB cache at {cache_root}"
)
self._build_rgb_dataset(cache_root, dataset_info)
return rgb_yaml
def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
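        """Return a per-dataset cache directory, keyed by a short hash of the dataset directory's resolved path."""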
cache_base = Path("data/datasets/_rgb_cache")
cache_base.mkdir(parents=True, exist_ok=True)
key = hashlib.md5(str(dataset_yaml.parent.resolve()).encode()).hexdigest()[:8]
return cache_base / f"{dataset_yaml.parent.name}_{key}"
def _clear_rgb_cache_for_dataset(self, dataset_yaml: Path):
cache_root = self._get_rgb_cache_root(dataset_yaml)
if cache_root.exists():
try:
shutil.rmtree(cache_root)
logger.debug(f"Removed RGB cache at {cache_root}")
except OSError as exc:
logger.warning(f"Failed to remove RGB cache {cache_root}: {exc}")
def _dataset_requires_rgb_conversion(self, images_dir: Path) -> bool:
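        """Heuristic check: inspect a single sample image and assume the whole dataset shares its color mode."""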
sample_image = self._find_first_image(images_dir)
if not sample_image:
return False
try:
with PILImage.open(sample_image) as img:
return img.mode.upper() != "RGB"
except Exception as exc:
logger.warning(f"Failed to inspect image {sample_image}: {exc}")
return False
def _find_first_image(self, directory: Path) -> Optional[Path]:
if not directory.exists():
return None
for path in directory.rglob("*"):
if path.is_file() and path.suffix.lower() in self.allowed_extensions:
return path
return None
def _build_rgb_dataset(self, cache_root: Path, dataset_info: Dict[str, Any]):
if cache_root.exists():
shutil.rmtree(cache_root)
cache_root.mkdir(parents=True, exist_ok=True)
splits = dataset_info.get("splits", {})
for split_name in ("train", "val", "test"):
split_entry = splits.get(split_name)
if not split_entry:
continue
images_src = Path(split_entry.get("path", ""))
if not images_src.exists():
continue
images_dst = cache_root / split_name / "images"
self._convert_images_to_rgb(images_src, images_dst)
labels_src = self._infer_labels_dir(images_src)
if labels_src.exists():
labels_dst = cache_root / split_name / "labels"
self._copy_labels(labels_src, labels_dst)
class_names = dataset_info.get("class_names") or []
names_map = {idx: name for idx, name in enumerate(class_names)}
num_classes = dataset_info.get("num_classes") or len(class_names)
yaml_payload: Dict[str, Any] = {
"path": cache_root.as_posix(),
"names": names_map,
"nc": num_classes,
}
for split_name in ("train", "val", "test"):
images_dir = cache_root / split_name / "images"
if images_dir.exists():
yaml_payload[split_name] = f"{split_name}/images"
with open(cache_root / "data.yaml", "w", encoding="utf-8") as handle:
yaml.safe_dump(yaml_payload, handle, sort_keys=False)
def _convert_images_to_rgb(self, src_dir: Path, dst_dir: Path):
for src in src_dir.rglob("*"):
if not src.is_file() or src.suffix.lower() not in self.allowed_extensions:
continue
relative = src.relative_to(src_dir)
dst = dst_dir / relative
dst.parent.mkdir(parents=True, exist_ok=True)
try:
with PILImage.open(src) as img:
rgb_img = img.convert("RGB")
rgb_img.save(dst)
except Exception as exc:
logger.warning(f"Failed to convert {src} to RGB: {exc}")
def _copy_labels(self, labels_src: Path, labels_dst: Path):
label_files = list(labels_src.rglob("*.txt"))
for label_file in label_files:
relative = label_file.relative_to(labels_src)
dst = labels_dst / relative
dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(label_file, dst)
def _infer_labels_dir(self, images_dir: Path) -> Path:
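        """Follow the YOLO convention of a 'labels' directory that sits alongside the images directory."""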
return images_dir.parent / "labels"
def _invalidate_split_cache(self, split_root: Path):
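        """Delete Ultralytics labels.cache/images.cache files so regenerated labels are re-read on the next run."""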
for cache_name in ("labels.cache", "images.cache"):
cache_path = split_root / cache_name
if cache_path.exists():
try:
cache_path.unlink()
logger.debug(f"Removed stale cache file: {cache_path}")
except OSError as exc:
logger.warning(f"Failed to remove cache {cache_path}: {exc}")
def _collect_training_params(self) -> Dict[str, Any]:
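        """Collect training parameters from the form, applying config defaults and ensuring the save directory exists."""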
model_name = self.model_name_edit.text().strip() or "custom_model"
model_version = self.model_version_edit.text().strip() or "v1"
base_model = self.base_model_edit.text().strip() or self.config_manager.get(
"models.default_base_model", "yolov8s-seg.pt"
)
save_dir = self.save_dir_edit.text().strip() or self.config_manager.get(
"models.models_directory", "data/models"
)
save_dir_path = Path(save_dir).expanduser()
save_dir_path.mkdir(parents=True, exist_ok=True)
run_name = f"{model_name}_{model_version}".replace(" ", "_")
return {
"model_name": model_name,
"model_version": model_version,
"base_model": base_model,
"save_dir": save_dir_path.as_posix(),
"run_name": run_name,
"epochs": self.epochs_spin.value(),
"batch": self.batch_spin.value(),
"imgsz": self.imgsz_spin.value(),
"patience": self.patience_spin.value(),
"lr0": self.lr_spin.value(),
}
def _start_training(self):
if self.training_worker and self.training_worker.isRunning():
return
# Ensure any previous worker objects are fully cleaned up before starting
self._cleanup_training_worker()
dataset_yaml = self.dataset_path_edit.text().strip()
if not dataset_yaml:
QMessageBox.warning(
self,
"Dataset Required",
"Please select or generate a data.yaml file first.",
)
return
dataset_path = Path(dataset_yaml).expanduser()
if not dataset_path.exists():
QMessageBox.warning(
self, "Invalid Dataset", "Selected data.yaml file does not exist."
)
return
        try:
            dataset_info = (
                self.selected_dataset
                if self.selected_dataset
                and self.selected_dataset.get("yaml_path") == str(dataset_path)
                else self._parse_dataset_yaml(dataset_path)
            )
        except ValueError as exc:
            QMessageBox.warning(self, "Invalid Dataset", str(exc))
            return
self.training_log.clear()
self._export_labels_from_database(dataset_info)
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
if dataset_to_use != dataset_path:
self._append_training_log(
f"Using RGB-converted dataset at {dataset_to_use.parent}"
)
params = self._collect_training_params()
self._active_training_params = params
self._training_cancelled = False
self._append_training_log(
f"Starting training run '{params['run_name']}' using {params['base_model']}"
)
self.training_progress_bar.setVisible(True)
self.training_progress_bar.setMaximum(params["epochs"])
self.training_progress_bar.setValue(0)
self._set_training_state(True)
self.training_worker = TrainingWorker(
data_yaml=dataset_to_use.as_posix(),
base_model=params["base_model"],
epochs=params["epochs"],
batch=params["batch"],
imgsz=params["imgsz"],
patience=params["patience"],
lr0=params["lr0"],
save_dir=params["save_dir"],
run_name=params["run_name"],
)
self.training_worker.progress.connect(self._on_training_progress)
self.training_worker.finished.connect(self._on_training_finished)
self.training_worker.error.connect(self._on_training_error)
self.training_worker.start()
def _stop_training(self):
if self.training_worker and self.training_worker.isRunning():
self._training_cancelled = True
self._append_training_log(
"Stop requested. Waiting for the current epoch to finish..."
)
self.training_worker.stop()
self.stop_training_button.setEnabled(False)
def _disconnect_training_worker_signals(self, worker: TrainingWorker):
for signal, slot in (
(worker.progress, self._on_training_progress),
(worker.finished, self._on_training_finished),
(worker.error, self._on_training_error),
):
try:
signal.disconnect(slot)
except (TypeError, RuntimeError):
pass
def _cleanup_training_worker(
self,
worker: Optional[TrainingWorker] = None,
*,
request_stop: bool = False,
wait_timeout_ms: int = 30000,
):
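        """Detach, optionally stop, wait on, and schedule deletion of a training worker (defaults to the active one)."""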
worker = worker or self.training_worker
if not worker:
return
if worker is self.training_worker:
self.training_worker = None
self._disconnect_training_worker_signals(worker)
if request_stop and worker.isRunning():
worker.stop()
if worker.isRunning():
if not worker.wait(wait_timeout_ms):
logger.warning(
"Training worker did not finish within %sms", wait_timeout_ms
)
worker.deleteLater()
def shutdown(self):
"""Stop any running training worker before the tab is destroyed."""
if self.training_worker and self.training_worker.isRunning():
logger.info("Shutting down training worker before exit")
self._append_training_log("Stopping training before application exit…")
self._cleanup_training_worker(request_stop=True, wait_timeout_ms=120000)
self._training_cancelled = True
else:
self._cleanup_training_worker()
self._set_training_state(False)
self.training_progress_bar.setVisible(False)
def _on_training_progress(
self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]
):
self.training_progress_bar.setMaximum(total_epochs)
self.training_progress_bar.setValue(current_epoch)
parts = [f"Epoch {current_epoch}/{total_epochs}"]
if metrics:
metric_text = ", ".join(
f"{key}: {value:.4f}" for key, value in metrics.items()
)
parts.append(metric_text)
self._append_training_log(" | ".join(parts))
def _on_training_finished(self, results: Dict[str, Any]):
self._cleanup_training_worker()
self._set_training_state(False)
self.training_progress_bar.setVisible(False)
if self._training_cancelled:
self._append_training_log("Training cancelled by user.")
QMessageBox.information(self, "Training Cancelled", "Training was stopped.")
self._training_cancelled = False
self._active_training_params = None
return
self._append_training_log("Training completed successfully.")
try:
self._register_trained_model(results)
except Exception as exc:
logger.error(f"Failed to register trained model: {exc}")
QMessageBox.warning(
self,
"Model Registration Failed",
f"Model trained but not registered: {exc}",
)
else:
QMessageBox.information(
self, "Training Complete", "Training finished successfully."
)
def _on_training_error(self, message: str):
self._cleanup_training_worker()
self._set_training_state(False)
self.training_progress_bar.setVisible(False)
self._training_cancelled = False
self._active_training_params = None
self._append_training_log(f"ERROR: {message}")
QMessageBox.critical(self, "Training Failed", message)
def _register_trained_model(self, results: Dict[str, Any]):
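        """Persist the trained model in the database using the parameters captured when training started."""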
if not self._active_training_params:
return
params = self._active_training_params
model_path = results.get("best_model_path") or results.get("last_model_path")
if not model_path:
raise ValueError("Training results did not include a model path.")
training_params = {
"epochs": params["epochs"],
"batch": params["batch"],
"imgsz": params["imgsz"],
"patience": params["patience"],
"lr0": params["lr0"],
"run_name": params["run_name"],
}
model_id = self.db_manager.add_model(
model_name=params["model_name"],
model_version=params["model_version"],
model_path=model_path,
base_model=params["base_model"],
training_params=training_params,
metrics=results.get("metrics"),
)
self._append_training_log(
f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}"
)
self._active_training_params = None
def _set_training_state(self, is_training: bool):
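        """Toggle all controls that must be locked while training is running."""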
self.start_training_button.setEnabled(not is_training)
self.stop_training_button.setEnabled(is_training)
self.generate_yaml_button.setEnabled(not is_training)
self.dataset_combo.setEnabled(not is_training)
self.browse_button.setEnabled(not is_training)
self.rescan_button.setEnabled(not is_training)
self.model_name_edit.setEnabled(not is_training)
self.model_version_edit.setEnabled(not is_training)
self.base_model_edit.setEnabled(not is_training)
self.base_model_browse_button.setEnabled(not is_training)
self.save_dir_edit.setEnabled(not is_training)
self.save_dir_browse_button.setEnabled(not is_training)
self.epochs_spin.setEnabled(not is_training)
self.batch_spin.setEnabled(not is_training)
self.imgsz_spin.setEnabled(not is_training)
self.patience_spin.setEnabled(not is_training)
self.lr_spin.setEnabled(not is_training)
def _append_training_log(self, message: str):
timestamp = datetime.now().strftime("%H:%M:%S")
self.training_log.append(f"[{timestamp}] {message}")
def _browse_base_model(self):
start_path = self.base_model_edit.text().strip() or "."
file_path, _ = QFileDialog.getOpenFileName(
self,
"Select Base Model Weights",
start_path,
"PyTorch weights (*.pt *.pth)",
)
if file_path:
self.base_model_edit.setText(file_path)
def _browse_save_dir(self):
start_path = self.save_dir_edit.text().strip() or "data/models"
directory = QFileDialog.getExistingDirectory(
self, "Select Save Directory", start_path
)
if directory:
self.save_dir_edit.setText(directory)
def _display_dataset_success(self, message: str):
self.dataset_status_label.setStyleSheet(self._status_styles["success"])
self.dataset_status_label.setText(message)
def _display_dataset_warning(self, message: str):
self.dataset_status_label.setStyleSheet(self._status_styles["warning"])
self.dataset_status_label.setText(message)
def _display_dataset_error(self, message: str):
self.dataset_status_label.setStyleSheet(self._status_styles["error"])
self.dataset_status_label.setText(message)
def refresh(self):
"""Refresh the tab."""
if self.training_worker and self.training_worker.isRunning():
self._append_training_log("Refresh skipped while training is running.")
return
self._discover_datasets()
current_path = self.dataset_path_edit.text().strip()
if current_path:
self._set_dataset_path(current_path, persist=False)
else:
self._load_saved_dataset()