"""
Detection tab for the microscopy object detection application.

Handles single image and batch detection.
"""

from pathlib import Path

from PySide6.QtWidgets import (
    QWidget,
    QVBoxLayout,
    QHBoxLayout,
    QPushButton,
    QLabel,
    QComboBox,
    QSlider,
    QFileDialog,
    QMessageBox,
    QProgressBar,
    QTextEdit,
    QGroupBox,
    QFormLayout,
)
from PySide6.QtCore import Qt, QThread, Signal

from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
from src.utils.file_utils import get_image_files
from src.model.inference import InferenceEngine

logger = get_logger(__name__)


class DetectionWorker(QThread):
    """Worker thread for running detection.

    Results and errors are reported back through signals rather than by
    touching any widgets from the background thread.
    """

    progress = Signal(int, int, str)  # current, total, message
    # NOTE(review): this name shadows QThread's built-in ``finished`` signal.
    # It is kept for compatibility with existing connections; a distinct name
    # (e.g. ``detection_finished``) would be less ambiguous.
    finished = Signal(list)  # results
    error = Signal(str)  # error message

    def __init__(self, engine, image_paths, repo_root, conf):
        """Store everything ``run`` needs.

        Args:
            engine: Inference engine exposing ``detect_batch``.
            image_paths: Images to process.
            repo_root: Repository root passed through to ``detect_batch``.
            conf: Confidence threshold in [0, 1].
        """
        super().__init__()
        self.engine = engine
        self.image_paths = image_paths
        self.repo_root = repo_root
        self.conf = conf

    def run(self):
        """Run detection in the background thread and emit the outcome."""
        try:
            # ``progress.emit`` is handed to the engine as its progress
            # callback, so it is expected to call it as (current, total, msg).
            results = self.engine.detect_batch(
                self.image_paths, self.repo_root, self.conf, self.progress.emit
            )
            self.finished.emit(results)
        except Exception as e:
            # Top-level boundary of the thread: report instead of dying silently.
            logger.error(f"Detection error: {e}")
            self.error.emit(str(e))


class DetectionTab(QWidget):
    """Detection tab for single image and batch detection."""

    def __init__(
        self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
    ):
        """Build the tab and populate the model list.

        Args:
            db_manager: Database access used to list/register models.
            config_manager: Application configuration access.
            parent: Optional Qt parent widget.
        """
        super().__init__(parent)
        self.db_manager = db_manager
        self.config_manager = config_manager
        self.inference_engine = None
        self.current_model_id = None
        # Currently running (or last) worker thread; None until first run.
        self.worker = None
        # Base-model path -> database id, so the base model is registered at
        # most once per session instead of on every detection run.
        self._base_model_db_ids = {}

        self._setup_ui()
        self._connect_signals()
        self._load_models()

    def _setup_ui(self):
        """Create and lay out all widgets of the tab."""
        layout = QVBoxLayout()

        # Model selection group
        model_group = QGroupBox("Model Selection")
        model_layout = QFormLayout()

        self.model_combo = QComboBox()
        self.model_combo.addItem("No models available", None)
        model_layout.addRow("Model:", self.model_combo)

        model_group.setLayout(model_layout)
        layout.addWidget(model_group)

        # Detection settings group
        settings_group = QGroupBox("Detection Settings")
        settings_layout = QFormLayout()

        # Confidence threshold: the slider works in hundredths (0-100) and is
        # divided by 100 wherever the threshold is read.
        conf_layout = QHBoxLayout()
        self.conf_slider = QSlider(Qt.Horizontal)
        self.conf_slider.setRange(0, 100)
        self.conf_slider.setValue(25)
        self.conf_slider.setTickPosition(QSlider.TicksBelow)
        self.conf_slider.setTickInterval(10)
        conf_layout.addWidget(self.conf_slider)

        self.conf_label = QLabel("0.25")
        conf_layout.addWidget(self.conf_label)

        settings_layout.addRow("Confidence:", conf_layout)
        settings_group.setLayout(settings_layout)
        layout.addWidget(settings_group)

        # Action buttons
        button_layout = QHBoxLayout()

        self.single_image_btn = QPushButton("Detect Single Image")
        self.single_image_btn.clicked.connect(self._detect_single_image)
        button_layout.addWidget(self.single_image_btn)

        self.batch_btn = QPushButton("Detect Batch (Folder)")
        self.batch_btn.clicked.connect(self._detect_batch)
        button_layout.addWidget(self.batch_btn)

        layout.addLayout(button_layout)

        # Progress bar (only shown while a detection is running)
        self.progress_bar = QProgressBar()
        self.progress_bar.setVisible(False)
        layout.addWidget(self.progress_bar)

        # Results display
        results_group = QGroupBox("Detection Results")
        results_layout = QVBoxLayout()

        self.results_text = QTextEdit()
        self.results_text.setReadOnly(True)
        self.results_text.setMaximumHeight(200)
        results_layout.addWidget(self.results_text)

        results_group.setLayout(results_layout)
        layout.addWidget(results_group)

        layout.addStretch()
        self.setLayout(layout)

    def _connect_signals(self):
        """Connect widget signals to their slots."""
        self.conf_slider.valueChanged.connect(self._update_confidence_label)
        self.model_combo.currentIndexChanged.connect(self._on_model_changed)

    def _load_models(self):
        """Populate the model combo box.

        The configured base model is always offered first, followed by every
        trained model returned by ``db_manager.get_models()``.
        """
        try:
            models = self.db_manager.get_models()
            self.model_combo.clear()

            # The pretrained base model does not need a database entry to be
            # usable (one is created on demand in _run_detection), so offer it
            # even when no trained models exist yet.
            base_model = self.config_manager.get(
                "models.default_base_model", "yolov8s.pt"
            )
            self.model_combo.addItem(
                f"Base Model ({base_model})", {"id": 0, "path": base_model}
            )

            # Add trained models
            for model in models or []:
                display_name = f"{model['model_name']} v{model['model_version']}"
                self.model_combo.addItem(display_name, model)

            # refresh() may call this mid-run; keep the actions disabled until
            # the running worker reports back.
            self._set_buttons_enabled(not self._is_detection_running())

        except Exception as e:
            logger.error(f"Error loading models: {e}")
            QMessageBox.warning(self, "Error", f"Failed to load models:\n{str(e)}")

    def _on_model_changed(self, index: int):
        """Track the database id of the selected model (None for the base model)."""
        model_data = self.model_combo.itemData(index)
        if model_data and model_data["id"] != 0:
            self.current_model_id = model_data["id"]
        else:
            self.current_model_id = None

    def _update_confidence_label(self, value: int):
        """Show the slider value as a 0.00-1.00 confidence threshold."""
        conf = value / 100.0
        self.conf_label.setText(f"{conf:.2f}")

    def _detect_single_image(self):
        """Ask for one image file and run detection on it."""
        repo_path = self.config_manager.get_image_repository_path()
        start_dir = repo_path or ""

        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "Select Image",
            start_dir,
            "Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
        )

        if not file_path:
            return

        self._run_detection([file_path])

    def _detect_batch(self):
        """Ask for a folder and run detection on its images (non-recursive)."""
        repo_path = self.config_manager.get_image_repository_path()
        start_dir = repo_path or ""

        folder_path = QFileDialog.getExistingDirectory(self, "Select Folder", start_dir)

        if not folder_path:
            return

        allowed_ext = self.config_manager.get_allowed_extensions()
        image_files = get_image_files(folder_path, allowed_ext, recursive=False)

        if not image_files:
            QMessageBox.information(
                self, "No Images", "No image files found in selected folder."
            )
            return

        # Confirm batch processing
        reply = QMessageBox.question(
            self,
            "Confirm Batch Detection",
            f"Process {len(image_files)} images?",
            QMessageBox.Yes | QMessageBox.No,
        )

        if reply == QMessageBox.Yes:
            self._run_detection(image_files)

    def _is_detection_running(self) -> bool:
        """Return True while a detection worker thread is still running."""
        return self.worker is not None and self.worker.isRunning()

    def _ensure_base_model_id(self, model_path: str):
        """Return a database id for the base model at ``model_path``.

        The database entry is created only the first time this path is used
        in the session. Earlier code created a new row on every run.
        """
        model_id = self._base_model_db_ids.get(model_path)
        if model_id is None:
            model_id = self.db_manager.add_model(
                model_name="Base Model",
                model_version="pretrained",
                model_path=model_path,
                base_model=model_path,
            )
            if model_id is not None:
                self._base_model_db_ids[model_path] = model_id
        return model_id

    def _run_detection(self, image_paths: list):
        """Start a background detection run over ``image_paths``.

        Uses the model selected in the combo box and the slider's confidence
        threshold. Results arrive through the worker's signals.
        """
        if not image_paths:
            return

        # Replacing self.worker while the previous thread still runs would
        # drop the last reference to a live QThread.
        if self._is_detection_running():
            QMessageBox.information(
                self, "Detection Running", "A detection is already in progress."
            )
            return

        try:
            # Get selected model
            model_data = self.model_combo.currentData()
            if not model_data:
                QMessageBox.warning(self, "No Model", "Please select a model first.")
                return

            model_path = model_data["path"]
            model_id = model_data["id"]

            # id 0 marks the synthetic base-model entry: results need a real
            # database id, so register it on demand.
            if model_id == 0:
                model_id = self._ensure_base_model_id(model_path)

            self.inference_engine = InferenceEngine(
                model_path, self.db_manager, model_id
            )

            conf = self.conf_slider.value() / 100.0

            # Fall back to the first image's folder when no repository is set.
            repo_root = self.config_manager.get_image_repository_path()
            if not repo_root:
                repo_root = str(Path(image_paths[0]).parent)

            # Reset the bar explicitly: changing only the maximum keeps a
            # stale in-range value from the previous run.
            self.progress_bar.setRange(0, len(image_paths))
            self.progress_bar.setValue(0)
            self.progress_bar.setVisible(True)
            self._set_buttons_enabled(False)

            # Create and start worker thread
            self.worker = DetectionWorker(
                self.inference_engine, image_paths, repo_root, conf
            )
            self.worker.progress.connect(self._on_progress)
            self.worker.finished.connect(self._on_detection_finished)
            self.worker.error.connect(self._on_detection_error)
            self.worker.start()

        except Exception as e:
            logger.error(f"Error starting detection: {e}")
            QMessageBox.critical(self, "Error", f"Failed to start detection:\n{str(e)}")
            self.progress_bar.setVisible(False)
            self._set_buttons_enabled(True)

    def _on_progress(self, current: int, total: int, message: str):
        """Advance the progress bar and log the worker's progress message."""
        self.progress_bar.setValue(current)
        self.results_text.append(f"[{current}/{total}] {message}")

    def _on_detection_finished(self, results: list):
        """Summarize a completed run and re-enable the controls."""
        self.progress_bar.setVisible(False)
        self._set_buttons_enabled(True)

        # Failed results may carry no "count"; treat those as zero detections.
        total_detections = sum(r.get("count", 0) for r in results)
        successful = sum(1 for r in results if r.get("success", False))

        summary = "\n=== Detection Complete ===\n"
        summary += f"Processed: {len(results)} images\n"
        summary += f"Successful: {successful}\n"
        summary += f"Total detections: {total_detections}\n"

        self.results_text.append(summary)

        QMessageBox.information(
            self,
            "Detection Complete",
            f"Processed {len(results)} images\n{total_detections} objects detected",
        )

    def _on_detection_error(self, error_msg: str):
        """Report a worker failure and re-enable the controls."""
        self.progress_bar.setVisible(False)
        self._set_buttons_enabled(True)

        self.results_text.append(f"\nERROR: {error_msg}")
        QMessageBox.critical(self, "Detection Error", error_msg)

    def _set_buttons_enabled(self, enabled: bool):
        """Enable/disable the action buttons and the model selector together."""
        self.single_image_btn.setEnabled(enabled)
        self.batch_btn.setEnabled(enabled)
        self.model_combo.setEnabled(enabled)

    def refresh(self):
        """Reload the model list and clear the results log."""
        self._load_models()
        self.results_text.clear()