345 lines
11 KiB
Python
345 lines
11 KiB
Python
|
|
"""
Detection tab for the microscopy object detection application.

Handles single image and batch detection.
"""
|
||
|
|
|
||
|
|
from pathlib import Path

from PySide6.QtCore import Qt, QThread, Signal
from PySide6.QtWidgets import (
    QWidget,
    QVBoxLayout,
    QHBoxLayout,
    QPushButton,
    QLabel,
    QComboBox,
    QSlider,
    QFileDialog,
    QMessageBox,
    QProgressBar,
    QTextEdit,
    QGroupBox,
    QFormLayout,
)

from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
from src.utils.file_utils import get_image_files
from src.model.inference import InferenceEngine
|
||
|
|
|
||
|
|
|
||
|
|
# Module-level logger used by both DetectionWorker and DetectionTab.
logger = get_logger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class DetectionWorker(QThread):
    """Thread that runs ``engine.detect_batch`` and reports via signals.

    Signals:
        progress: ``(current, total, message)``; its ``emit`` is handed to
            ``detect_batch`` as the fourth positional argument.
        finished: carries the value returned by ``detect_batch``.
        error: carries ``str(exception)`` when detection raises.
    """

    # NOTE(review): ``finished`` shares its name with QThread's own
    # ``finished`` signal; kept as-is because callers connect to it by name.
    progress = Signal(int, int, str)  # current, total, message
    finished = Signal(list)  # results
    error = Signal(str)  # error message

    def __init__(self, engine, image_paths, repo_root, conf):
        """Store the inputs that ``run`` later passes to ``detect_batch``."""
        super().__init__()
        self.engine, self.image_paths = engine, image_paths
        self.repo_root, self.conf = repo_root, conf

    def run(self):
        """Run the batch detection and emit ``finished`` or ``error``."""
        try:
            detections = self.engine.detect_batch(
                self.image_paths,
                self.repo_root,
                self.conf,
                self.progress.emit,
            )
            self.finished.emit(detections)
        except Exception as exc:
            # Boundary of the worker thread: report the failure through the
            # ``error`` signal instead of letting it escape ``run``.
            logger.error(f"Detection error: {exc}")
            self.error.emit(str(exc))
|
||
|
|
|
||
|
|
|
||
|
|
class DetectionTab(QWidget):
    """Detection tab for single image and batch detection.

    The tab lets the user pick a model and a confidence threshold, choose a
    single image or a folder, and then runs detection on a
    ``DetectionWorker`` thread while showing progress and a textual summary.
    """

    # Values used for the pretrained base-model entry.
    _DEFAULT_BASE_MODEL = "yolov8s.pt"
    _BASE_MODEL_NAME = "Base Model"
    _BASE_MODEL_VERSION = "pretrained"

    def __init__(
        self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
    ):
        """Build the UI, wire the signals and populate the model list.

        Args:
            db_manager: Source of stored models (``get_models``/``add_model``).
            config_manager: Provides the repository path, allowed extensions
                and the default base model name.
            parent: Optional parent widget.
        """
        super().__init__(parent)
        self.db_manager = db_manager
        self.config_manager = config_manager
        self.inference_engine = None
        self.current_model_id = None
        # Set before _load_models() so the "is a run in progress" check works.
        self.worker = None

        self._setup_ui()
        self._connect_signals()
        self._load_models()

    def _setup_ui(self) -> None:
        """Create the widgets and lay them out top to bottom."""
        layout = QVBoxLayout()

        # Model selection group
        model_group = QGroupBox("Model Selection")
        model_layout = QFormLayout()

        self.model_combo = QComboBox()
        self.model_combo.addItem("No models available", None)
        model_layout.addRow("Model:", self.model_combo)

        model_group.setLayout(model_layout)
        layout.addWidget(model_group)

        # Detection settings group
        settings_group = QGroupBox("Detection Settings")
        settings_layout = QFormLayout()

        # Confidence threshold: slider works in integer percent (0-100).
        conf_layout = QHBoxLayout()
        self.conf_slider = QSlider(Qt.Horizontal)
        self.conf_slider.setRange(0, 100)
        self.conf_slider.setValue(25)
        self.conf_slider.setTickPosition(QSlider.TicksBelow)
        self.conf_slider.setTickInterval(10)
        conf_layout.addWidget(self.conf_slider)

        self.conf_label = QLabel("0.25")
        conf_layout.addWidget(self.conf_label)

        settings_layout.addRow("Confidence:", conf_layout)
        settings_group.setLayout(settings_layout)
        layout.addWidget(settings_group)

        # Action buttons (connected in _connect_signals)
        button_layout = QHBoxLayout()

        self.single_image_btn = QPushButton("Detect Single Image")
        button_layout.addWidget(self.single_image_btn)

        self.batch_btn = QPushButton("Detect Batch (Folder)")
        button_layout.addWidget(self.batch_btn)

        layout.addLayout(button_layout)

        # Progress bar, hidden until a run starts
        self.progress_bar = QProgressBar()
        self.progress_bar.setVisible(False)
        layout.addWidget(self.progress_bar)

        # Results display
        results_group = QGroupBox("Detection Results")
        results_layout = QVBoxLayout()

        self.results_text = QTextEdit()
        self.results_text.setReadOnly(True)
        self.results_text.setMaximumHeight(200)
        results_layout.addWidget(self.results_text)

        results_group.setLayout(results_layout)
        layout.addWidget(results_group)

        layout.addStretch()
        self.setLayout(layout)

    def _connect_signals(self) -> None:
        """Connect widget signals to their handlers."""
        self.conf_slider.valueChanged.connect(self._update_confidence_label)
        self.model_combo.currentIndexChanged.connect(self._on_model_changed)
        self.single_image_btn.clicked.connect(self._detect_single_image)
        self.batch_btn.clicked.connect(self._detect_batch)

    def _base_model_name(self):
        """Return the configured base model, falling back to the default."""
        return self.config_manager.get(
            "models.default_base_model", self._DEFAULT_BASE_MODEL
        )

    def _start_directory(self):
        """Return the directory file dialogs open in ("" when unset)."""
        return self.config_manager.get_image_repository_path() or ""

    def _is_detection_running(self) -> bool:
        """Return True while a previously started worker is still running."""
        return self.worker is not None and self.worker.isRunning()

    def _load_models(self) -> None:
        """Repopulate the model combo box from the database."""
        try:
            models = self.db_manager.get_models()
            self.model_combo.clear()

            # NOTE(review): with an empty database the base model is not
            # offered either, so detection is unavailable until at least one
            # model row exists. Confirm this is intended.
            if not models:
                self.model_combo.addItem("No models available", None)
                self._set_buttons_enabled(False)
                return

            # Base model option; id 0 marks "not yet stored in the database".
            base_model = self._base_model_name()
            self.model_combo.addItem(
                f"{self._BASE_MODEL_NAME} ({base_model})",
                {"id": 0, "path": base_model},
            )

            # Stored models
            for model in models:
                display_name = f"{model['model_name']} v{model['model_version']}"
                self.model_combo.addItem(display_name, model)

            # refresh() may be called while a run is in progress; keep the
            # controls disabled then so a second worker cannot be started.
            self._set_buttons_enabled(not self._is_detection_running())

        except Exception as e:
            logger.error(f"Error loading models: {e}")
            QMessageBox.warning(self, "Error", f"Failed to load models:\n{str(e)}")

    def _on_model_changed(self, index: int) -> None:
        """Track the database id of the selected model (None for base/none)."""
        model_data = self.model_combo.itemData(index)
        if model_data and model_data["id"] != 0:
            self.current_model_id = model_data["id"]
        else:
            self.current_model_id = None

    def _update_confidence_label(self, value: int) -> None:
        """Show the slider percentage as a 0.00-1.00 confidence value."""
        self.conf_label.setText(f"{value / 100.0:.2f}")

    def _detect_single_image(self) -> None:
        """Ask for one image file and run detection on it."""
        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "Select Image",
            self._start_directory(),
            "Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
        )
        if not file_path:
            return  # dialog cancelled

        self._run_detection([file_path])

    def _detect_batch(self) -> None:
        """Ask for a folder and run detection on the images directly in it."""
        folder_path = QFileDialog.getExistingDirectory(
            self, "Select Folder", self._start_directory()
        )
        if not folder_path:
            return  # dialog cancelled

        allowed_ext = self.config_manager.get_allowed_extensions()
        image_files = get_image_files(folder_path, allowed_ext, recursive=False)

        if not image_files:
            QMessageBox.information(
                self, "No Images", "No image files found in selected folder."
            )
            return

        # Confirm before starting a potentially long run.
        reply = QMessageBox.question(
            self,
            "Confirm Batch Detection",
            f"Process {len(image_files)} images?",
            QMessageBox.Yes | QMessageBox.No,
        )
        if reply == QMessageBox.Yes:
            self._run_detection(image_files)

    def _resolve_model_id(self, model_data):
        """Return a database id for the selected model.

        Stored models already carry their id. For the base model (id 0) an
        existing row with the same name, version and path is reused so that
        repeated runs do not add a new row every time; only when none is
        found is a new row created with ``add_model``.
        """
        model_id = model_data["id"]
        if model_id != 0:
            return model_id

        base_model = self._base_model_name()
        # Assumes rows written by add_model come back from get_models() with
        # the same keys the combo box items already rely on; if the stored
        # path differs, no row matches and a new one is added as before.
        for model in self.db_manager.get_models() or []:
            if (
                model["model_name"] == self._BASE_MODEL_NAME
                and model["model_version"] == self._BASE_MODEL_VERSION
                and model["path"] == base_model
            ):
                return model["id"]

        return self.db_manager.add_model(
            model_name=self._BASE_MODEL_NAME,
            model_version=self._BASE_MODEL_VERSION,
            model_path=base_model,
            base_model=base_model,
        )

    def _run_detection(self, image_paths: list) -> None:
        """Start a background detection run over ``image_paths``.

        Args:
            image_paths: Paths of the images to process; an empty list is a
                no-op.
        """
        if not image_paths:
            return

        try:
            model_data = self.model_combo.currentData()
            if not model_data:
                QMessageBox.warning(self, "No Model", "Please select a model first.")
                return

            model_path = model_data["path"]
            model_id = self._resolve_model_id(model_data)

            self.inference_engine = InferenceEngine(
                model_path, self.db_manager, model_id
            )

            conf = self.conf_slider.value() / 100.0

            # Fall back to the first image's folder when no repository path
            # is configured.
            repo_root = self.config_manager.get_image_repository_path()
            if not repo_root:
                repo_root = str(Path(image_paths[0]).parent)

            # Reset the bar so it does not show the previous run's value.
            self.progress_bar.setMaximum(len(image_paths))
            self.progress_bar.setValue(0)
            self.progress_bar.setVisible(True)
            self._set_buttons_enabled(False)

            self.worker = DetectionWorker(
                self.inference_engine, image_paths, repo_root, conf
            )
            self.worker.progress.connect(self._on_progress)
            self.worker.finished.connect(self._on_detection_finished)
            self.worker.error.connect(self._on_detection_error)
            self.worker.start()

        except Exception as e:
            logger.error(f"Error starting detection: {e}")
            # Restore the idle UI state before showing the dialog.
            self.progress_bar.setVisible(False)
            self._set_buttons_enabled(True)
            QMessageBox.critical(self, "Error", f"Failed to start detection:\n{str(e)}")

    def _on_progress(self, current: int, total: int, message: str) -> None:
        """Advance the progress bar and log the per-image message."""
        self.progress_bar.setValue(current)
        self.results_text.append(f"[{current}/{total}] {message}")

    def _on_detection_finished(self, results: list) -> None:
        """Restore the UI and report a summary of the finished run.

        Args:
            results: List of per-image result dicts emitted by the worker.
                Entries without a "count" contribute 0 detections and
                entries without "success" count as unsuccessful.
        """
        self.progress_bar.setVisible(False)
        self._set_buttons_enabled(True)

        # A result without a count (e.g. a failed image) must not abort the
        # summary with a KeyError.
        total_detections = sum(r.get("count") or 0 for r in results)
        successful = sum(1 for r in results if r.get("success", False))

        summary = (
            "\n=== Detection Complete ===\n"
            f"Processed: {len(results)} images\n"
            f"Successful: {successful}\n"
            f"Total detections: {total_detections}\n"
        )
        self.results_text.append(summary)

        QMessageBox.information(
            self,
            "Detection Complete",
            f"Processed {len(results)} images\n{total_detections} objects detected",
        )

    def _on_detection_error(self, error_msg: str) -> None:
        """Restore the UI and show the error reported by the worker."""
        self.progress_bar.setVisible(False)
        self._set_buttons_enabled(True)

        self.results_text.append(f"\nERROR: {error_msg}")
        QMessageBox.critical(self, "Detection Error", error_msg)

    def _set_buttons_enabled(self, enabled: bool) -> None:
        """Enable/disable the action buttons and the model selector."""
        self.single_image_btn.setEnabled(enabled)
        self.batch_btn.setEnabled(enabled)
        self.model_combo.setEnabled(enabled)

    def refresh(self) -> None:
        """Reload the model list and clear the results log."""
        self._load_models()
        self.results_text.clear()
|