Files
object-segmentation/src/gui/tabs/detection_tab.py
2025-12-10 16:55:28 +02:00

466 lines
16 KiB
Python

"""
Detection tab for the microscopy object detection application.
Handles single image and batch detection.
"""
from PySide6.QtWidgets import (
QWidget,
QVBoxLayout,
QHBoxLayout,
QPushButton,
QLabel,
QComboBox,
QSlider,
QFileDialog,
QMessageBox,
QProgressBar,
QTextEdit,
QGroupBox,
QFormLayout,
)
from PySide6.QtCore import Qt, QThread, Signal
from pathlib import Path
from typing import Optional
from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
from src.utils.file_utils import get_image_files
from src.model.inference import InferenceEngine
# Module-level logger; used by DetectionWorker.run and DetectionTab's error paths.
logger = get_logger(__name__)
class DetectionWorker(QThread):
    """Worker thread for running detection.

    Runs ``engine.detect_batch`` off the GUI thread and reports back via signals:

    * ``progress(int, int, str)`` - forwarded from the engine's progress callback
      as ``(current, total, message)``.
    * ``finished(list)`` - emitted with the engine's result list on success.
    * ``error(str)`` - emitted with the exception text if detection raises.
    """

    # NOTE(review): ``finished`` shadows QThread's built-in ``finished`` signal.
    # Consider renaming it (together with every ``.finished.connect`` caller)
    # to avoid ambiguity about which signal a slot is attached to.
    progress = Signal(int, int, str)  # current, total, message
    finished = Signal(list)  # results
    error = Signal(str)  # error message

    def __init__(self, engine, image_paths, repo_root, conf):
        super().__init__()
        # Capture everything the background run needs up front.
        self.engine, self.image_paths = engine, image_paths
        self.repo_root, self.conf = repo_root, conf

    def run(self):
        """Run detection in background thread."""
        try:
            # The engine reports progress through the signal's emit method.
            report_progress = self.progress.emit
            batch_results = self.engine.detect_batch(
                self.image_paths, self.repo_root, self.conf, report_progress
            )
            self.finished.emit(batch_results)
        except Exception as exc:
            logger.error(f"Detection error: {exc}")
            self.error.emit(str(exc))
class DetectionTab(QWidget):
    """Detection tab for single image and batch detection.

    The model combo lists, in order:

    * the configured base model (``models.default_base_model``), stored with id 0;
    * models returned by ``db_manager.get_models()``;
    * ``.pt`` files found under the models directory that are not already
      listed, stored with id None.

    Detection runs in a background ``DetectionWorker`` thread. The action
    buttons and the model combo stay disabled until the worker reports success
    or an error.
    """

    def __init__(
        self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
    ):
        super().__init__(parent)
        self.db_manager = db_manager
        self.config_manager = config_manager
        self.inference_engine = None
        self.current_model_id = None
        # Worker for the most recent detection run. It starts as None so the
        # attribute always exists, even before the first run.
        self.worker: Optional[DetectionWorker] = None
        self._setup_ui()
        self._connect_signals()
        self._load_models()

    def _setup_ui(self):
        """Setup user interface."""
        layout = QVBoxLayout()

        # Model selection group
        model_group = QGroupBox("Model Selection")
        model_layout = QFormLayout()
        self.model_combo = QComboBox()
        self.model_combo.addItem("No models available", None)
        model_layout.addRow("Model:", self.model_combo)
        model_group.setLayout(model_layout)
        layout.addWidget(model_group)

        # Detection settings group
        settings_group = QGroupBox("Detection Settings")
        settings_layout = QFormLayout()

        # Confidence threshold: the slider works in hundredths (0-100) and is
        # shown as a 0.00-1.00 value in the label next to it.
        conf_layout = QHBoxLayout()
        self.conf_slider = QSlider(Qt.Horizontal)
        self.conf_slider.setRange(0, 100)
        self.conf_slider.setValue(25)
        self.conf_slider.setTickPosition(QSlider.TicksBelow)
        self.conf_slider.setTickInterval(10)
        conf_layout.addWidget(self.conf_slider)
        self.conf_label = QLabel("0.25")
        conf_layout.addWidget(self.conf_label)
        settings_layout.addRow("Confidence:", conf_layout)
        settings_group.setLayout(settings_layout)
        layout.addWidget(settings_group)

        # Action buttons
        button_layout = QHBoxLayout()
        self.single_image_btn = QPushButton("Detect Single Image")
        self.single_image_btn.clicked.connect(self._detect_single_image)
        button_layout.addWidget(self.single_image_btn)
        self.batch_btn = QPushButton("Detect Batch (Folder)")
        self.batch_btn.clicked.connect(self._detect_batch)
        button_layout.addWidget(self.batch_btn)
        layout.addLayout(button_layout)

        # Progress bar (only visible while a detection is running)
        self.progress_bar = QProgressBar()
        self.progress_bar.setVisible(False)
        layout.addWidget(self.progress_bar)

        # Results display
        results_group = QGroupBox("Detection Results")
        results_layout = QVBoxLayout()
        self.results_text = QTextEdit()
        self.results_text.setReadOnly(True)
        self.results_text.setMaximumHeight(200)
        results_layout.addWidget(self.results_text)
        results_group.setLayout(results_layout)
        layout.addWidget(results_group)

        layout.addStretch()
        self.setLayout(layout)

    def _connect_signals(self):
        """Connect signals and slots."""
        self.conf_slider.valueChanged.connect(self._update_confidence_label)
        self.model_combo.currentIndexChanged.connect(self._on_model_changed)

    def _is_detection_running(self) -> bool:
        """Return True while the current detection worker thread is executing."""
        return self.worker is not None and self.worker.isRunning()

    def _load_models(self):
        """Load available models from database and local storage.

        Each combo item's data is a dict with at least ``id`` and ``path``.
        Local files whose normalized path matches the base model or a
        database model are skipped.
        """
        try:
            self.model_combo.clear()
            models = self.db_manager.get_models()
            has_models = False
            known_paths = set()

            # Add base model option first (always available)
            base_model = self.config_manager.get(
                "models.default_base_model", "yolov8s-seg.pt"
            )
            if base_model:
                base_data = {
                    "id": 0,
                    "path": base_model,
                    "model_name": Path(base_model).stem or "Base Model",
                    "model_version": "pretrained",
                    "base_model": base_model,
                    "source": "base",
                }
                self.model_combo.addItem(f"Base Model ({base_model})", base_data)
                known_paths.add(self._normalize_model_path(base_model))
                has_models = True

            # Add trained models from database
            for model in models:
                display_name = f"{model['model_name']} v{model['model_version']}"
                model_data = {**model, "path": model.get("model_path")}
                normalized = self._normalize_model_path(model_data.get("path"))
                if normalized:
                    known_paths.add(normalized)
                self.model_combo.addItem(display_name, model_data)
                has_models = True

            # Discover local model files not yet in the database
            for model_path in self._discover_local_models():
                normalized = self._normalize_model_path(model_path)
                if normalized in known_paths:
                    continue
                display_name = f"Local Model ({Path(model_path).stem})"
                model_data = {
                    "id": None,
                    "path": str(model_path),
                    "model_name": Path(model_path).stem,
                    "model_version": "local",
                    "base_model": Path(model_path).stem,
                    "source": "local",
                }
                self.model_combo.addItem(display_name, model_data)
                known_paths.add(normalized)
                has_models = True

            if not has_models:
                self.model_combo.addItem("No models available", None)
            # Don't re-enable the controls under a running detection (for
            # example when refresh() is called mid-run). The worker's
            # finished/error handlers re-enable them when it is done.
            self._set_buttons_enabled(has_models and not self._is_detection_running())
        except Exception as e:
            logger.error(f"Error loading models: {e}")
            QMessageBox.warning(self, "Error", f"Failed to load models:\n{str(e)}")

    def _on_model_changed(self, index: int):
        """Track the database id of the newly selected model.

        ``current_model_id`` is None for the base model (id 0), for
        unregistered local models (id None), for items without an ``id`` key,
        and when the combo is empty (``index == -1``).
        """
        model_data = self.model_combo.itemData(index)
        model_id = model_data.get("id") if model_data else None
        self.current_model_id = model_id if model_id != 0 else None

    def _update_confidence_label(self, value: int):
        """Update confidence label."""
        conf = value / 100.0
        self.conf_label.setText(f"{conf:.2f}")

    def _dialog_start_dir(self) -> str:
        """Return the configured image repository path, or "" when unset."""
        repo_path = self.config_manager.get_image_repository_path()
        return repo_path if repo_path else ""

    def _detect_single_image(self):
        """Detect objects in a single image."""
        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "Select Image",
            self._dialog_start_dir(),
            "Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
        )
        if not file_path:
            return
        self._run_detection([file_path])

    def _detect_batch(self):
        """Detect objects in batch (folder, non-recursive)."""
        folder_path = QFileDialog.getExistingDirectory(
            self, "Select Folder", self._dialog_start_dir()
        )
        if not folder_path:
            return

        # Get all image files
        allowed_ext = self.config_manager.get_allowed_extensions()
        image_files = get_image_files(folder_path, allowed_ext, recursive=False)
        if not image_files:
            QMessageBox.information(
                self, "No Images", "No image files found in selected folder."
            )
            return

        # Confirm batch processing
        reply = QMessageBox.question(
            self,
            "Confirm Batch Detection",
            f"Process {len(image_files)} images?",
            QMessageBox.Yes | QMessageBox.No,
        )
        if reply == QMessageBox.Yes:
            self._run_detection(image_files)

    def _run_detection(self, image_paths: list):
        """Validate the selected model and start detection on *image_paths*.

        Models without a usable database id (the base model with id 0 and
        local files with id None) are registered through
        ``_ensure_model_record`` first. Results arrive asynchronously through
        ``_on_progress``, ``_on_detection_finished`` and ``_on_detection_error``.

        Args:
            image_paths: Image files to process. Must be non-empty.
        """
        try:
            # Get selected model
            model_data = self.model_combo.currentData()
            if not model_data:
                QMessageBox.warning(self, "No Model", "Please select a model first.")
                return

            model_path = model_data.get("path")
            if not model_path:
                QMessageBox.warning(
                    self, "Invalid Model", "Selected model is missing a file path."
                )
                return

            if not Path(model_path).exists():
                QMessageBox.critical(
                    self,
                    "Model Not Found",
                    f"The selected model file could not be found:\n{model_path}",
                )
                return

            # Ensure we have a database entry for the selected model
            model_id = model_data.get("id")
            if model_id in (None, 0):
                model_id = self._ensure_model_record(model_data)
                if not model_id:
                    QMessageBox.critical(
                        self,
                        "Model Registration Failed",
                        "Unable to register the selected model in the database.",
                    )
                    return

            normalized_model_path = self._normalize_model_path(model_path) or model_path
            self.inference_engine = InferenceEngine(
                normalized_model_path, self.db_manager, model_id
            )

            conf = self.conf_slider.value() / 100.0

            # Fall back to the first image's folder when no repository is set
            repo_root = self.config_manager.get_image_repository_path()
            if not repo_root:
                repo_root = str(Path(image_paths[0]).parent)

            # The previous worker emits its result just before run() returns.
            # Wait for it so replacing self.worker cannot destroy a QThread
            # that is still running.
            if self._is_detection_running():
                self.worker.wait()

            # Show the progress bar and clear any value left from the last run
            self.progress_bar.setMaximum(len(image_paths))
            self.progress_bar.reset()
            self.progress_bar.setVisible(True)
            self._set_buttons_enabled(False)

            # Create and start worker thread
            self.worker = DetectionWorker(
                self.inference_engine, image_paths, repo_root, conf
            )
            self.worker.progress.connect(self._on_progress)
            self.worker.finished.connect(self._on_detection_finished)
            self.worker.error.connect(self._on_detection_error)
            self.worker.start()
        except Exception as e:
            logger.error(f"Error starting detection: {e}")
            # The bar may already be visible if the failure happened late
            self.progress_bar.setVisible(False)
            QMessageBox.critical(self, "Error", f"Failed to start detection:\n{str(e)}")
            self._set_buttons_enabled(True)

    def _on_progress(self, current: int, total: int, message: str):
        """Handle progress update."""
        self.progress_bar.setValue(current)
        self.results_text.append(f"[{current}/{total}] {message}")

    def _on_detection_finished(self, results: list):
        """Summarize a completed run in the results pane and a message box.

        Result dicts are read with ``.get`` so that entries missing ``count``
        or ``success`` (for example failed images) do not break the summary.
        """
        self.progress_bar.setVisible(False)
        self._set_buttons_enabled(True)

        total_detections = sum(r.get("count", 0) for r in results)
        successful = sum(1 for r in results if r.get("success", False))
        summary = (
            "\n=== Detection Complete ===\n"
            f"Processed: {len(results)} images\n"
            f"Successful: {successful}\n"
            f"Total detections: {total_detections}\n"
        )
        self.results_text.append(summary)
        QMessageBox.information(
            self,
            "Detection Complete",
            f"Processed {len(results)} images\n{total_detections} objects detected",
        )

    def _on_detection_error(self, error_msg: str):
        """Handle detection error."""
        self.progress_bar.setVisible(False)
        self._set_buttons_enabled(True)
        self.results_text.append(f"\nERROR: {error_msg}")
        QMessageBox.critical(self, "Detection Error", error_msg)

    def _set_buttons_enabled(self, enabled: bool):
        """Enable/disable the action buttons and the model selector together."""
        self.single_image_btn.setEnabled(enabled)
        self.batch_btn.setEnabled(enabled)
        self.model_combo.setEnabled(enabled)

    def _discover_local_models(self) -> list:
        """Recursively scan the models directory for ``.pt`` files.

        Returns:
            ``Path`` objects sorted case-insensitively by full path. The list
            is empty when the directory is unset, missing or unreadable.
        """
        models_dir = self.config_manager.get_models_directory()
        if not models_dir:
            return []
        models_path = Path(models_dir)
        if not models_path.exists():
            return []
        try:
            return sorted(
                (p for p in models_path.rglob("*.pt") if p.is_file()),
                key=lambda p: str(p).lower(),
            )
        except Exception as e:
            logger.warning(f"Error discovering local models: {e}")
            return []

    def _normalize_model_path(self, path_value) -> str:
        """Return a normalized absolute path string for comparison.

        Empty or None input gives "". If ``Path.resolve`` fails, the plain
        string form of *path_value* is returned.
        """
        if not path_value:
            return ""
        try:
            return str(Path(path_value).resolve())
        except Exception:
            return str(path_value)

    def _ensure_model_record(self, model_data: dict) -> Optional[int]:
        """Return the database id for *model_data*'s file, registering it if needed.

        An existing record is matched by normalized path or by the raw path
        string. Otherwise a new record is added with the normalized path.

        Returns:
            The model id, or None when *model_data* has no path or a database
            call fails (the failure is logged).
        """
        model_path = model_data.get("path")
        if not model_path:
            return None
        normalized_target = self._normalize_model_path(model_path)
        try:
            for model in self.db_manager.get_models():
                existing_path = model.get("model_path")
                if not existing_path:
                    continue
                if (
                    self._normalize_model_path(existing_path) == normalized_target
                    or existing_path == model_path
                ):
                    return model["id"]

            model_name = (
                model_data.get("model_name") or Path(model_path).stem or "Custom Model"
            )
            model_version = (
                model_data.get("model_version") or model_data.get("source") or "local"
            )
            base_model = model_data.get(
                "base_model",
                self.config_manager.get("models.default_base_model", "yolov8s-seg.pt"),
            )
            return self.db_manager.add_model(
                model_name=model_name,
                model_version=model_version,
                model_path=normalized_target,
                base_model=base_model,
            )
        except Exception as e:
            logger.error(f"Failed to ensure model record for {model_path}: {e}")
            return None

    def refresh(self):
        """Reload the model list and clear the results pane."""
        self._load_models()
        self.results_text.clear()