"""
Detection tab for the microscopy object detection application.
Handles single image and batch detection.
"""

from PySide6.QtWidgets import (
    QWidget,
    QVBoxLayout,
    QHBoxLayout,
    QPushButton,
    QLabel,
    QComboBox,
    QSlider,
    QFileDialog,
    QMessageBox,
    QProgressBar,
    QTextEdit,
    QGroupBox,
    QFormLayout,
)
from PySide6.QtCore import Qt, QThread, Signal
from pathlib import Path
from typing import Optional

from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
from src.utils.file_utils import get_image_files
from src.model.inference import InferenceEngine


logger = get_logger(__name__)


class DetectionWorker(QThread):
    """Worker thread for running detection.

    Calls ``engine.detect_batch`` off the GUI thread and reports back through
    the ``progress``, ``finished`` and ``error`` signals.
    """

    progress = Signal(int, int, str)  # current, total, message
    # NOTE: this name shadows QThread's built-in ``finished`` signal with one
    # that carries the result list emitted at the end of ``run``.
    finished = Signal(list)  # results
    error = Signal(str)  # error message

    def __init__(self, engine, image_paths, repo_root, conf):
        """Store the detection inputs for use in ``run``.

        Args:
            engine: Object exposing ``detect_batch(paths, repo_root, conf, cb)``.
            image_paths: Images to process.
            repo_root: Repository root passed through to the engine.
            conf: Confidence threshold in [0, 1] (from the UI slider).
        """
        super().__init__()
        self.engine = engine
        self.image_paths = image_paths
        self.repo_root = repo_root
        self.conf = conf

    def run(self):
        """Run detection in background thread.

        Emits ``finished(results)`` on success or ``error(message)`` if the
        engine raises.
        """
        try:
            results = self.engine.detect_batch(
                self.image_paths, self.repo_root, self.conf, self.progress.emit
            )
            self.finished.emit(results)
        except Exception as e:
            logger.error(f"Detection error: {e}")
            self.error.emit(str(e))


class DetectionTab(QWidget):
    """Detection tab for single image and batch detection."""

    def __init__(
        self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
    ):
        super().__init__(parent)
        self.db_manager = db_manager
        self.config_manager = config_manager
        self.inference_engine = None
        self.current_model_id = None
        # Background worker for the current/last detection run; None until the
        # first run starts. Kept as an attribute so the QThread is not
        # garbage-collected while it is running.
        self.worker: Optional[DetectionWorker] = None

        self._setup_ui()
        self._connect_signals()
        self._load_models()

    def _setup_ui(self):
        """Setup user interface."""
        layout = QVBoxLayout()

        # Model selection group
        model_group = QGroupBox("Model Selection")
        model_layout = QFormLayout()

        self.model_combo = QComboBox()
        self.model_combo.addItem("No models available", None)
        model_layout.addRow("Model:", self.model_combo)

        model_group.setLayout(model_layout)
        layout.addWidget(model_group)

        # Detection settings group
        settings_group = QGroupBox("Detection Settings")
        settings_layout = QFormLayout()

        # Confidence threshold (slider is 0-100, shown/used as 0.00-1.00)
        conf_layout = QHBoxLayout()
        self.conf_slider = QSlider(Qt.Horizontal)
        self.conf_slider.setRange(0, 100)
        self.conf_slider.setValue(25)
        self.conf_slider.setTickPosition(QSlider.TicksBelow)
        self.conf_slider.setTickInterval(10)
        conf_layout.addWidget(self.conf_slider)

        self.conf_label = QLabel("0.25")
        conf_layout.addWidget(self.conf_label)

        settings_layout.addRow("Confidence:", conf_layout)
        settings_group.setLayout(settings_layout)
        layout.addWidget(settings_group)

        # Action buttons
        button_layout = QHBoxLayout()

        self.single_image_btn = QPushButton("Detect Single Image")
        self.single_image_btn.clicked.connect(self._detect_single_image)
        button_layout.addWidget(self.single_image_btn)

        self.batch_btn = QPushButton("Detect Batch (Folder)")
        self.batch_btn.clicked.connect(self._detect_batch)
        button_layout.addWidget(self.batch_btn)

        layout.addLayout(button_layout)

        # Progress bar
        self.progress_bar = QProgressBar()
        self.progress_bar.setVisible(False)
        layout.addWidget(self.progress_bar)

        # Results display
        results_group = QGroupBox("Detection Results")
        results_layout = QVBoxLayout()

        self.results_text = QTextEdit()
        self.results_text.setReadOnly(True)
        self.results_text.setMaximumHeight(200)
        results_layout.addWidget(self.results_text)

        results_group.setLayout(results_layout)
        layout.addWidget(results_group)

        layout.addStretch()
        self.setLayout(layout)

    def _connect_signals(self):
        """Connect signals and slots."""
        self.conf_slider.valueChanged.connect(self._update_confidence_label)
        self.model_combo.currentIndexChanged.connect(self._on_model_changed)

    def _load_models(self):
        """Load available models from database and local storage.

        Populates the combo box with, in order: the configured base model,
        trained models from the database, and ``.pt`` files found in the models
        directory that are not already listed (compared by resolved path).
        Controls are disabled when nothing is available, and are never
        re-enabled while a detection run is still in progress.
        """
        try:
            self.model_combo.clear()

            models = self.db_manager.get_models()
            has_models = False
            known_paths = set()

            # Add base model option first (always available)
            base_model = self.config_manager.get(
                "models.default_base_model", "yolov8s-seg.pt"
            )
            if base_model:
                base_data = {
                    "id": 0,
                    "path": base_model,
                    "model_name": Path(base_model).stem or "Base Model",
                    "model_version": "pretrained",
                    "base_model": base_model,
                    "source": "base",
                }
                self.model_combo.addItem(f"Base Model ({base_model})", base_data)
                known_paths.add(self._normalize_model_path(base_model))
                has_models = True

            # Add trained models from database
            for model in models:
                display_name = f"{model['model_name']} v{model['model_version']}"
                model_data = {**model, "path": model.get("model_path")}
                normalized = self._normalize_model_path(model_data.get("path"))
                if normalized:
                    known_paths.add(normalized)
                self.model_combo.addItem(display_name, model_data)
                has_models = True

            # Discover local model files not yet in the database
            local_models = self._discover_local_models()
            for model_path in local_models:
                normalized = self._normalize_model_path(model_path)
                if normalized in known_paths:
                    continue

                display_name = f"Local Model ({Path(model_path).stem})"
                model_data = {
                    "id": None,
                    "path": str(model_path),
                    "model_name": Path(model_path).stem,
                    "model_version": "local",
                    "base_model": Path(model_path).stem,
                    "source": "local",
                }
                self.model_combo.addItem(display_name, model_data)
                known_paths.add(normalized)
                has_models = True

            if not has_models:
                self.model_combo.addItem("No models available", None)
                self._set_buttons_enabled(False)
            elif not self._is_detection_running():
                # Don't unlock the controls mid-run (e.g. via refresh()):
                # starting a second run would replace self.worker while the
                # first thread is still executing.
                self._set_buttons_enabled(True)

        except Exception as e:
            logger.error(f"Error loading models: {e}")
            QMessageBox.warning(self, "Error", f"Failed to load models:\n{str(e)}")

    def _on_model_changed(self, index: int):
        """Handle model selection change.

        Stores the selected model's database id in ``current_model_id``; the
        base model (id 0), local models (id None) and empty entries map to None.
        """
        model_data = self.model_combo.itemData(index)
        if model_data and model_data.get("id") != 0:
            self.current_model_id = model_data.get("id")
        else:
            self.current_model_id = None

    def _update_confidence_label(self, value: int):
        """Update confidence label from the 0-100 slider value."""
        conf = value / 100.0
        self.conf_label.setText(f"{conf:.2f}")

    def _detect_single_image(self):
        """Detect objects in a single image."""
        # Get image file
        repo_path = self.config_manager.get_image_repository_path()
        start_dir = repo_path if repo_path else ""
        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "Select Image",
            start_dir,
            "Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
        )

        if not file_path:
            return

        # Run detection
        self._run_detection([file_path])

    def _detect_batch(self):
        """Detect objects in batch (folder, non-recursive)."""
        # Get folder
        repo_path = self.config_manager.get_image_repository_path()
        start_dir = repo_path if repo_path else ""
        folder_path = QFileDialog.getExistingDirectory(self, "Select Folder", start_dir)

        if not folder_path:
            return

        # Get all image files
        allowed_ext = self.config_manager.get_allowed_extensions()
        image_files = get_image_files(folder_path, allowed_ext, recursive=False)

        if not image_files:
            QMessageBox.information(
                self, "No Images", "No image files found in selected folder."
            )
            return

        # Confirm batch processing
        reply = QMessageBox.question(
            self,
            "Confirm Batch Detection",
            f"Process {len(image_files)} images?",
            QMessageBox.Yes | QMessageBox.No,
        )

        if reply == QMessageBox.Yes:
            self._run_detection(image_files)

    def _run_detection(self, image_paths: list):
        """Run detection on image list in a background worker.

        Validates the selected model, registers it in the database if it has no
        record yet, builds an ``InferenceEngine`` and starts a
        ``DetectionWorker``. Errors while starting are reported in a dialog and
        the UI is returned to its idle state.

        Args:
            image_paths: Non-empty list of image file paths to process.
        """
        if self._is_detection_running():
            # Controls are disabled during a run; this is a safety net so the
            # running worker is never replaced.
            return

        try:
            # Get selected model
            model_data = self.model_combo.currentData()
            if not model_data:
                QMessageBox.warning(self, "No Model", "Please select a model first.")
                return

            model_path = model_data.get("path")
            if not model_path:
                QMessageBox.warning(
                    self, "Invalid Model", "Selected model is missing a file path."
                )
                return

            # NOTE(review): this also applies to the base model entry, so a
            # bare weight name (e.g. the default) must exist relative to the
            # CWD — confirm whether auto-download by name is intended.
            if not Path(model_path).exists():
                QMessageBox.critical(
                    self,
                    "Model Not Found",
                    f"The selected model file could not be found:\n{model_path}",
                )
                return

            model_id = model_data.get("id")

            # Ensure we have a database entry for the selected model
            if model_id in (None, 0):
                model_id = self._ensure_model_record(model_data)
                if not model_id:
                    QMessageBox.critical(
                        self,
                        "Model Registration Failed",
                        "Unable to register the selected model in the database.",
                    )
                    return

            normalized_model_path = self._normalize_model_path(model_path) or model_path

            # Create inference engine
            self.inference_engine = InferenceEngine(
                normalized_model_path, self.db_manager, model_id
            )

            # Get confidence threshold
            conf = self.conf_slider.value() / 100.0

            # Get repository root
            repo_root = self.config_manager.get_image_repository_path()
            if not repo_root:
                repo_root = str(Path(image_paths[0]).parent)

            # Show progress bar, starting from zero rather than wherever the
            # previous run left it.
            self.progress_bar.setVisible(True)
            self.progress_bar.setMaximum(len(image_paths))
            self.progress_bar.setValue(0)
            self._set_buttons_enabled(False)

            # Create and start worker thread
            self.worker = DetectionWorker(
                self.inference_engine, image_paths, repo_root, conf
            )
            self.worker.progress.connect(self._on_progress)
            self.worker.finished.connect(self._on_detection_finished)
            self.worker.error.connect(self._on_detection_error)
            self.worker.start()

        except Exception as e:
            logger.error(f"Error starting detection: {e}")
            QMessageBox.critical(self, "Error", f"Failed to start detection:\n{str(e)}")
            # Return the UI to its idle state.
            self.progress_bar.setVisible(False)
            self._set_buttons_enabled(True)

    def _on_progress(self, current: int, total: int, message: str):
        """Handle progress update from the worker."""
        self.progress_bar.setValue(current)
        self.results_text.append(f"[{current}/{total}] {message}")

    def _on_detection_finished(self, results: list):
        """Handle detection completion and show a summary.

        Args:
            results: Per-image results from ``detect_batch``; presumably dicts
                with ``count`` and ``success`` keys — missing keys are treated
                as 0 / False.
        """
        self.progress_bar.setVisible(False)
        self._set_buttons_enabled(True)

        # Calculate statistics (tolerate result dicts without a "count",
        # e.g. for images that failed).
        total_detections = sum(r.get("count", 0) for r in results)
        successful = sum(1 for r in results if r.get("success", False))

        summary = f"\n=== Detection Complete ===\n"
        summary += f"Processed: {len(results)} images\n"
        summary += f"Successful: {successful}\n"
        summary += f"Total detections: {total_detections}\n"

        self.results_text.append(summary)

        QMessageBox.information(
            self,
            "Detection Complete",
            f"Processed {len(results)} images\n{total_detections} objects detected",
        )

    def _on_detection_error(self, error_msg: str):
        """Handle detection error reported by the worker."""
        self.progress_bar.setVisible(False)
        self._set_buttons_enabled(True)

        self.results_text.append(f"\nERROR: {error_msg}")
        QMessageBox.critical(self, "Detection Error", error_msg)

    def _set_buttons_enabled(self, enabled: bool):
        """Enable/disable action buttons and the model selector."""
        self.single_image_btn.setEnabled(enabled)
        self.batch_btn.setEnabled(enabled)
        self.model_combo.setEnabled(enabled)

    def _is_detection_running(self) -> bool:
        """Return True while the worker started by this tab is still running."""
        worker = getattr(self, "worker", None)
        return worker is not None and worker.isRunning()

    def _discover_local_models(self) -> list:
        """Scan the models directory (recursively) for standalone .pt files.

        Returns:
            Sorted list of ``Path`` objects (case-insensitive by path string);
            empty if the directory is unset, missing, or cannot be scanned.
        """
        models_dir = self.config_manager.get_models_directory()
        if not models_dir:
            return []

        models_path = Path(models_dir)
        if not models_path.exists():
            return []

        try:
            return sorted(
                [p for p in models_path.rglob("*.pt") if p.is_file()],
                key=lambda p: str(p).lower(),
            )
        except Exception as e:
            logger.warning(f"Error discovering local models: {e}")
            return []

    def _normalize_model_path(self, path_value) -> str:
        """Return a normalized absolute path string for comparison.

        Empty/None input yields ""; if resolving fails the raw value is
        returned as a string.
        """
        if not path_value:
            return ""
        try:
            return str(Path(path_value).resolve())
        except Exception:
            return str(path_value)

    def _ensure_model_record(self, model_data: dict) -> Optional[int]:
        """Ensure a database record exists for the selected model.

        Reuses an existing record whose path matches (raw or resolved);
        otherwise adds a new one using the resolved path.

        Returns:
            The model's database id, or None if it has no path or the database
            operation fails (the failure is logged).
        """
        model_path = model_data.get("path")
        if not model_path:
            return None

        normalized_target = self._normalize_model_path(model_path)

        try:
            existing_models = self.db_manager.get_models()
            for model in existing_models:
                existing_path = model.get("model_path")
                if not existing_path:
                    continue
                normalized_existing = self._normalize_model_path(existing_path)
                if (
                    normalized_existing == normalized_target
                    or existing_path == model_path
                ):
                    return model["id"]

            model_name = (
                model_data.get("model_name") or Path(model_path).stem or "Custom Model"
            )
            model_version = (
                model_data.get("model_version") or model_data.get("source") or "local"
            )
            base_model = model_data.get(
                "base_model",
                self.config_manager.get("models.default_base_model", "yolov8s-seg.pt"),
            )

            return self.db_manager.add_model(
                model_name=model_name,
                model_version=model_version,
                model_path=normalized_target,
                base_model=base_model,
            )
        except Exception as e:
            logger.error(f"Failed to ensure model record for {model_path}: {e}")
            return None

    def refresh(self):
        """Refresh the tab: reload the model list and clear the results log."""
        self._load_models()
        self.results_text.clear()