From 9c8931e6f3bb30b3b03aaff70d48795108c0855a Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Fri, 16 Jan 2026 13:58:02 +0200 Subject: [PATCH] Finish validation tab --- src/gui/tabs/validation_tab.py | 553 +++++++++++++++++++++++++++++++-- 1 file changed, 531 insertions(+), 22 deletions(-) diff --git a/src/gui/tabs/validation_tab.py b/src/gui/tabs/validation_tab.py index 2e7749a..3d6be69 100644 --- a/src/gui/tabs/validation_tab.py +++ b/src/gui/tabs/validation_tab.py @@ -2,45 +2,554 @@ Validation tab for the microscopy object detection application. """ -from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple + +from PySide6.QtCore import Qt, QSize +from PySide6.QtGui import QPainter, QPixmap +from PySide6.QtWidgets import ( + QWidget, + QVBoxLayout, + QLabel, + QGroupBox, + QHBoxLayout, + QPushButton, + QComboBox, + QFormLayout, + QScrollArea, + QGridLayout, + QFrame, + QTableWidget, + QTableWidgetItem, + QHeaderView, + QSplitter, + QListWidget, + QListWidgetItem, + QAbstractItemView, + QGraphicsView, + QGraphicsScene, + QGraphicsPixmapItem, +) from src.database.db_manager import DatabaseManager from src.utils.config_manager import ConfigManager +from src.utils.logger import get_logger + + +logger = get_logger(__name__) + + +@dataclass(frozen=True) +class _PlotItem: + label: str + path: Path + + +class _ZoomableImageView(QGraphicsView): + """Zoomable image viewer. + + - Mouse wheel: zoom in/out + - Left mouse drag: pan (ScrollHandDrag) + """ + + def __init__(self, parent: Optional[QWidget] = None): + super().__init__(parent) + self._scene = QGraphicsScene(self) + self.setScene(self._scene) + self._pixmap_item = QGraphicsPixmapItem() + self._scene.addItem(self._pixmap_item) + + # QGraphicsView render hints are QPainter.RenderHints. + self.setRenderHints(self.renderHints() | QPainter.RenderHint.SmoothPixmapTransform) + self.setDragMode(QGraphicsView.DragMode.ScrollHandDrag) + self.setTransformationAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse) + self.setResizeAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse) + + self._has_pixmap = False + + def clear(self) -> None: + self._pixmap_item.setPixmap(QPixmap()) + self._scene.setSceneRect(0, 0, 1, 1) + self.resetTransform() + self._has_pixmap = False + + def set_pixmap(self, pixmap: QPixmap, *, fit: bool = True) -> None: + self._pixmap_item.setPixmap(pixmap) + self._scene.setSceneRect(pixmap.rect()) + self._has_pixmap = not pixmap.isNull() + self.resetTransform() + if fit and self._has_pixmap: + self.fitInView(self._pixmap_item, Qt.AspectRatioMode.KeepAspectRatio) + + def wheelEvent(self, event) -> None: # type: ignore[override] + if not self._has_pixmap: + return + zoom_in_factor = 1.25 + zoom_out_factor = 1.0 / zoom_in_factor + factor = zoom_in_factor if event.angleDelta().y() > 0 else zoom_out_factor + self.scale(factor, factor) class ValidationTab(QWidget): - """Validation tab placeholder.""" + """Validation tab that shows stored validation metrics + plots for a selected model.""" - def __init__( - self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None - ): + def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None): super().__init__(parent) self.db_manager = db_manager self.config_manager = config_manager + self._models: List[Dict[str, Any]] = [] + self._selected_model_id: Optional[int] = None + self._plot_widgets: List[QWidget] = [] + self._plot_items: List[_PlotItem] = [] + self._setup_ui() + self.refresh() def _setup_ui(self): """Setup user interface.""" - layout = QVBoxLayout() + layout = QVBoxLayout(self) - group = QGroupBox("Validation") - group_layout = QVBoxLayout() - label = QLabel( - "Validation functionality will be implemented here.\n\n" - "Features:\n" - "- Model validation\n" - "- Metrics visualization\n" - "- Confusion matrix\n" - "- Precision-Recall curves" - ) - group_layout.addWidget(label) - group.setLayout(group_layout) + # ===== Header controls ===== + header = QGroupBox("Validation") + header_layout = QVBoxLayout() + header_row = QHBoxLayout() - layout.addWidget(group) - layout.addStretch() - self.setLayout(layout) + header_row.addWidget(QLabel("Select model:")) + + self.model_combo = QComboBox() + self.model_combo.setMinimumWidth(420) + self.model_combo.currentIndexChanged.connect(self._on_model_selected) + header_row.addWidget(self.model_combo, 1) + + self.refresh_btn = QPushButton("Refresh") + self.refresh_btn.clicked.connect(self.refresh) + header_row.addWidget(self.refresh_btn) + header_row.addStretch() + + header_layout.addLayout(header_row) + self.header_status = QLabel("No models loaded.") + self.header_status.setWordWrap(True) + header_layout.addWidget(self.header_status) + header.setLayout(header_layout) + layout.addWidget(header) + + # ===== Metrics ===== + metrics_group = QGroupBox("Validation Metrics") + metrics_layout = QVBoxLayout() + + self.metrics_form = QFormLayout() + self.metric_labels: Dict[str, QLabel] = {} + for key in ("mAP50", "mAP50-95", "precision", "recall", "fitness"): + value_label = QLabel("–") + value_label.setTextInteractionFlags(Qt.TextSelectableByMouse) + self.metric_labels[key] = value_label + self.metrics_form.addRow(f"{key}:", value_label) + metrics_layout.addLayout(self.metrics_form) + + self.per_class_table = QTableWidget(0, 3) + self.per_class_table.setHorizontalHeaderLabels(["Class", "AP", "AP50"]) + self.per_class_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch) + self.per_class_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeToContents) + self.per_class_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeToContents) + self.per_class_table.setEditTriggers(QTableWidget.NoEditTriggers) + self.per_class_table.setMinimumHeight(160) + metrics_layout.addWidget(QLabel("Per-class metrics (if available):")) + metrics_layout.addWidget(self.per_class_table) + + metrics_group.setLayout(metrics_layout) + layout.addWidget(metrics_group) + + # ===== Plots ===== + plots_group = QGroupBox("Validation Plots") + plots_layout = QVBoxLayout() + + self.plots_status = QLabel("Select a model to see validation plots.") + self.plots_status.setWordWrap(True) + plots_layout.addWidget(self.plots_status) + + self.plots_splitter = QSplitter(Qt.Orientation.Horizontal) + + # Left: selected image viewer + left_widget = QWidget() + left_layout = QVBoxLayout(left_widget) + left_layout.setContentsMargins(0, 0, 0, 0) + + self.selected_plot_title = QLabel("No image selected.") + self.selected_plot_title.setWordWrap(True) + self.selected_plot_title.setTextInteractionFlags(Qt.TextSelectableByMouse) + left_layout.addWidget(self.selected_plot_title) + + self.plot_view = _ZoomableImageView() + self.plot_view.setMinimumHeight(360) + left_layout.addWidget(self.plot_view, 1) + + self.selected_plot_path = QLabel("") + self.selected_plot_path.setWordWrap(True) + self.selected_plot_path.setStyleSheet("color: #888;") + self.selected_plot_path.setTextInteractionFlags(Qt.TextSelectableByMouse) + left_layout.addWidget(self.selected_plot_path) + + # Right: scrollable list + right_widget = QWidget() + right_layout = QVBoxLayout(right_widget) + right_layout.setContentsMargins(0, 0, 0, 0) + right_layout.addWidget(QLabel("Images:")) + + self.plots_list = QListWidget() + self.plots_list.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection) + self.plots_list.setIconSize(QSize(160, 160)) + self.plots_list.itemSelectionChanged.connect(self._on_plot_item_selected) + right_layout.addWidget(self.plots_list, 1) + + self.plots_splitter.addWidget(left_widget) + self.plots_splitter.addWidget(right_widget) + self.plots_splitter.setStretchFactor(0, 3) + self.plots_splitter.setStretchFactor(1, 1) + plots_layout.addWidget(self.plots_splitter, 1) + + plots_group.setLayout(plots_layout) + layout.addWidget(plots_group, 1) + + layout.addStretch(0) + + self._clear_metrics() + self._clear_plots() + + # ==================== Public API ==================== def refresh(self): """Refresh the tab.""" - pass + self._load_models() + self._populate_model_combo() + self._restore_or_select_default_model() + + # ==================== Internal: models ==================== + + def _load_models(self) -> None: + try: + self._models = self.db_manager.get_models() or [] + except Exception as exc: + logger.error("Failed to load models: %s", exc) + self._models = [] + + def _populate_model_combo(self) -> None: + self.model_combo.blockSignals(True) + self.model_combo.clear() + self.model_combo.addItem("Select a model…", None) + + for model in self._models: + model_id = model.get("id") + name = (model.get("model_name") or "").strip() + version = (model.get("model_version") or "").strip() + created_at = model.get("created_at") + label = f"{name} {version}".strip() + if created_at: + label = f"{label} ({created_at})" + self.model_combo.addItem(label, model_id) + + self.model_combo.blockSignals(False) + + if self._models: + self.header_status.setText(f"Loaded {len(self._models)} model(s).") + else: + self.header_status.setText("No models found. Train a model first.") + + def _restore_or_select_default_model(self) -> None: + if not self._models: + self._selected_model_id = None + self._clear_metrics() + self._clear_plots() + return + + # Keep selection if still present. + if self._selected_model_id is not None: + for idx in range(1, self.model_combo.count()): + if self.model_combo.itemData(idx) == self._selected_model_id: + self.model_combo.setCurrentIndex(idx) + return + + # Otherwise select the newest model (top of get_models ORDER BY created_at DESC). + first_model_id = self.model_combo.itemData(1) if self.model_combo.count() > 1 else None + if first_model_id is not None: + self.model_combo.setCurrentIndex(1) + + def _on_model_selected(self, index: int) -> None: + model_id = self.model_combo.itemData(index) + if not model_id: + self._selected_model_id = None + self._clear_metrics() + self._clear_plots() + self.plots_status.setText("Select a model to see validation plots.") + return + + self._selected_model_id = int(model_id) + model = self._get_model_by_id(self._selected_model_id) + if not model: + self._clear_metrics() + self._clear_plots() + self.plots_status.setText("Selected model not found.") + return + + self._render_metrics(model) + self._render_plots(model) + + def _get_model_by_id(self, model_id: int) -> Optional[Dict[str, Any]]: + for model in self._models: + if model.get("id") == model_id: + return model + try: + return self.db_manager.get_model_by_id(model_id) + except Exception: + return None + + # ==================== Internal: metrics ==================== + + def _clear_metrics(self) -> None: + for label in self.metric_labels.values(): + label.setText("–") + self.per_class_table.setRowCount(0) + + def _render_metrics(self, model: Dict[str, Any]) -> None: + self._clear_metrics() + + metrics: Dict[str, Any] = model.get("metrics") or {} + # Training tab stores metrics under results['metrics'] in training results payload. + if isinstance(metrics, dict) and "metrics" in metrics and isinstance(metrics.get("metrics"), dict): + metrics = metrics.get("metrics") or {} + + def set_metric(key: str, value: Any) -> None: + if key not in self.metric_labels: + return + if value is None: + self.metric_labels[key].setText("–") + return + try: + self.metric_labels[key].setText(f"{float(value):.4f}") + except Exception: + self.metric_labels[key].setText(str(value)) + + set_metric("mAP50", metrics.get("mAP50")) + set_metric("mAP50-95", metrics.get("mAP50-95") or metrics.get("mAP50_95") or metrics.get("mAP50-95")) + set_metric("precision", metrics.get("precision")) + set_metric("recall", metrics.get("recall")) + set_metric("fitness", metrics.get("fitness")) + + # Optional per-class metrics + class_metrics = metrics.get("class_metrics") if isinstance(metrics, dict) else None + if isinstance(class_metrics, dict) and class_metrics: + items = sorted(class_metrics.items(), key=lambda kv: str(kv[0])) + self.per_class_table.setRowCount(len(items)) + for row, (cls_name, cls_stats) in enumerate(items): + ap = (cls_stats or {}).get("ap") + ap50 = (cls_stats or {}).get("ap50") + self.per_class_table.setItem(row, 0, QTableWidgetItem(str(cls_name))) + self.per_class_table.setItem(row, 1, QTableWidgetItem(self._format_float(ap))) + self.per_class_table.setItem(row, 2, QTableWidgetItem(self._format_float(ap50))) + else: + self.per_class_table.setRowCount(0) + + @staticmethod + def _format_float(value: Any) -> str: + if value is None: + return "–" + try: + return f"{float(value):.4f}" + except Exception: + return str(value) + + # ==================== Internal: plots ==================== + + def _clear_plots(self) -> None: + # Remove legacy grid widgets (from the initial implementation). + for widget in self._plot_widgets: + widget.setParent(None) + widget.deleteLater() + self._plot_widgets = [] + + self._plot_items = [] + + if hasattr(self, "plots_list"): + self.plots_list.blockSignals(True) + self.plots_list.clear() + self.plots_list.blockSignals(False) + + if hasattr(self, "plot_view"): + self.plot_view.clear() + if hasattr(self, "selected_plot_title"): + self.selected_plot_title.setText("No image selected.") + if hasattr(self, "selected_plot_path"): + self.selected_plot_path.setText("") + + def _render_plots(self, model: Dict[str, Any]) -> None: + self._clear_plots() + + plot_dirs = self._infer_run_directories(model) + plot_items = self._discover_plot_items(plot_dirs) + + if not plot_items: + dirs_text = "\n".join(str(p) for p in plot_dirs if p) + self.plots_status.setText( + "No validation plot images found for this model.\n\n" + "Searched directories:\n" + (dirs_text or "(none)") + ) + return + + self._plot_items = list(plot_items) + self.plots_status.setText(f"Found {len(plot_items)} plot image(s). Select one to view/zoom.") + + self.plots_list.blockSignals(True) + self.plots_list.clear() + for idx, item in enumerate(self._plot_items): + qitem = QListWidgetItem(item.label) + qitem.setData(Qt.ItemDataRole.UserRole, idx) + + pix = QPixmap(str(item.path)) + if not pix.isNull(): + thumb = pix.scaled( + self.plots_list.iconSize(), + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation, + ) + qitem.setIcon(thumb) + self.plots_list.addItem(qitem) + self.plots_list.blockSignals(False) + + if self.plots_list.count() > 0: + self.plots_list.setCurrentRow(0) + + def _on_plot_item_selected(self) -> None: + if not self._plot_items: + return + + selected = self.plots_list.selectedItems() + if not selected: + return + + idx = selected[0].data(Qt.ItemDataRole.UserRole) + try: + idx_int = int(idx) + except Exception: + return + if idx_int < 0 or idx_int >= len(self._plot_items): + return + + plot = self._plot_items[idx_int] + self.selected_plot_title.setText(plot.label) + self.selected_plot_path.setText(str(plot.path)) + + pix = QPixmap(str(plot.path)) + if pix.isNull(): + self.plot_view.clear() + return + self.plot_view.set_pixmap(pix, fit=True) + + def _infer_run_directories(self, model: Dict[str, Any]) -> List[Path]: + dirs: List[Path] = [] + + # 1) Infer from model_path: ...//weights/best.pt -> + model_path = model.get("model_path") + if model_path: + try: + p = Path(str(model_path)).expanduser() + if p.name.lower().endswith(".pt"): + # If it lives under weights/, use parent.parent. + if p.parent.name == "weights" and p.parent.parent.exists(): + dirs.append(p.parent.parent) + elif p.parent.exists(): + dirs.append(p.parent) + except Exception: + pass + + # 2) Look at training_params.stage_results[].results.save_dir + training_params = model.get("training_params") or {} + stage_results = None + if isinstance(training_params, dict): + stage_results = training_params.get("stage_results") + if isinstance(stage_results, list): + for stage in stage_results: + results = (stage or {}).get("results") + save_dir = (results or {}).get("save_dir") if isinstance(results, dict) else None + if save_dir: + try: + save_path = Path(str(save_dir)).expanduser() + if save_path.exists(): + dirs.append(save_path) + except Exception: + continue + + # Deduplicate while preserving order. + unique: List[Path] = [] + seen: set[str] = set() + for d in dirs: + try: + resolved = str(d.resolve()) + except Exception: + resolved = str(d) + if resolved not in seen and d.exists() and d.is_dir(): + seen.add(resolved) + unique.append(d) + return unique + + def _discover_plot_items(self, directories: Sequence[Path]) -> List[_PlotItem]: + # Prefer canonical Ultralytics filenames first, then fall back to any png/jpg. + preferred_names = [ + "results.png", + "results.jpg", + "confusion_matrix.png", + "confusion_matrix_normalized.png", + "labels.jpg", + "labels.png", + "BoxPR_curve.png", + "BoxP_curve.png", + "BoxR_curve.png", + "BoxF1_curve.png", + "MaskPR_curve.png", + "MaskP_curve.png", + "MaskR_curve.png", + "MaskF1_curve.png", + "val_batch0_pred.jpg", + "val_batch0_labels.jpg", + ] + + found: List[_PlotItem] = [] + seen: set[str] = set() + + for d in directories: + # 1) Preferred + for name in preferred_names: + p = d / name + if p.exists() and p.is_file(): + key = str(p) + if key in seen: + continue + seen.add(key) + found.append(_PlotItem(label=f"{name} (from {d.name})", path=p)) + + # 2) Curated globs + for pattern in ("train_batch*.jpg", "val_batch*.jpg", "*curve*.png"): + for p in sorted(d.glob(pattern)): + if not p.is_file(): + continue + key = str(p) + if key in seen: + continue + seen.add(key) + found.append(_PlotItem(label=f"{p.name} (from {d.name})", path=p)) + + # 3) Fallback: any top-level png/jpg (excluding weights dir contents) + for ext in ("*.png", "*.jpg", "*.jpeg", "*.webp"): + for p in sorted(d.glob(ext)): + if not p.is_file(): + continue + key = str(p) + if key in seen: + continue + seen.add(key) + found.append(_PlotItem(label=f"{p.name} (from {d.name})", path=p)) + + # Keep list bounded to avoid UI overload for huge runs. + return found[:60]