diff --git a/src/database/db_manager.py b/src/database/db_manager.py index 0100798..039786d 100644 --- a/src/database/db_manager.py +++ b/src/database/db_manager.py @@ -450,6 +450,25 @@ class DatabaseManager: filters["model_id"] = model_id return self.get_detections(filters) + def delete_detections_for_image( + self, image_id: int, model_id: Optional[int] = None + ) -> int: + """Delete detections tied to a specific image and optional model.""" + conn = self.get_connection() + try: + cursor = conn.cursor() + if model_id is not None: + cursor.execute( + "DELETE FROM detections WHERE image_id = ? AND model_id = ?", + (image_id, model_id), + ) + else: + cursor.execute("DELETE FROM detections WHERE image_id = ?", (image_id,)) + conn.commit() + return cursor.rowcount + finally: + conn.close() + def delete_detections_for_model(self, model_id: int) -> int: """Delete all detections for a specific model.""" conn = self.get_connection() diff --git a/src/gui/tabs/detection_tab.py b/src/gui/tabs/detection_tab.py index 01a3861..364783f 100644 --- a/src/gui/tabs/detection_tab.py +++ b/src/gui/tabs/detection_tab.py @@ -20,6 +20,7 @@ from PySide6.QtWidgets import ( ) from PySide6.QtCore import Qt, QThread, Signal from pathlib import Path +from typing import Optional from src.database.db_manager import DatabaseManager from src.utils.config_manager import ConfigManager @@ -147,30 +148,66 @@ class DetectionTab(QWidget): self.model_combo.currentIndexChanged.connect(self._on_model_changed) def _load_models(self): - """Load available models from database.""" + """Load available models from database and local storage.""" try: - models = self.db_manager.get_models() self.model_combo.clear() + models = self.db_manager.get_models() + has_models = False - if not models: - self.model_combo.addItem("No models available", None) - self._set_buttons_enabled(False) - return + known_paths = set() - # Add base model option + # Add base model option first (always available) base_model = self.config_manager.get( "models.default_base_model", "yolov8s-seg.pt" ) - self.model_combo.addItem( - f"Base Model ({base_model})", {"id": 0, "path": base_model} - ) + if base_model: + base_data = { + "id": 0, + "path": base_model, + "model_name": Path(base_model).stem or "Base Model", + "model_version": "pretrained", + "base_model": base_model, + "source": "base", + } + self.model_combo.addItem(f"Base Model ({base_model})", base_data) + known_paths.add(self._normalize_model_path(base_model)) + has_models = True - # Add trained models + # Add trained models from database for model in models: display_name = f"{model['model_name']} v{model['model_version']}" - self.model_combo.addItem(display_name, model) + model_data = {**model, "path": model.get("model_path")} + normalized = self._normalize_model_path(model_data.get("path")) + if normalized: + known_paths.add(normalized) + self.model_combo.addItem(display_name, model_data) + has_models = True - self._set_buttons_enabled(True) + # Discover local model files not yet in the database + local_models = self._discover_local_models() + for model_path in local_models: + normalized = self._normalize_model_path(model_path) + if normalized in known_paths: + continue + + display_name = f"Local Model ({Path(model_path).stem})" + model_data = { + "id": None, + "path": str(model_path), + "model_name": Path(model_path).stem, + "model_version": "local", + "base_model": Path(model_path).stem, + "source": "local", + } + self.model_combo.addItem(display_name, model_data) + known_paths.add(normalized) + has_models = True + + if not has_models: + self.model_combo.addItem("No models available", None) + self._set_buttons_enabled(False) + else: + self._set_buttons_enabled(True) except Exception as e: logger.error(f"Error loading models: {e}") @@ -249,25 +286,39 @@ class DetectionTab(QWidget): QMessageBox.warning(self, "No Model", "Please select a model first.") return - model_path = model_data["path"] - model_id = model_data["id"] + model_path = model_data.get("path") + if not model_path: + QMessageBox.warning( + self, "Invalid Model", "Selected model is missing a file path." + ) + return - # Ensure we have a valid model ID (create entry for base model if needed) - if model_id == 0: - # Create database entry for base model - base_model = self.config_manager.get( - "models.default_base_model", "yolov8s-seg.pt" - ) - model_id = self.db_manager.add_model( - model_name="Base Model", - model_version="pretrained", - model_path=base_model, - base_model=base_model, + if not Path(model_path).exists(): + QMessageBox.critical( + self, + "Model Not Found", + f"The selected model file could not be found:\n{model_path}", ) + return + + model_id = model_data.get("id") + + # Ensure we have a database entry for the selected model + if model_id in (None, 0): + model_id = self._ensure_model_record(model_data) + if not model_id: + QMessageBox.critical( + self, + "Model Registration Failed", + "Unable to register the selected model in the database.", + ) + return + + normalized_model_path = self._normalize_model_path(model_path) or model_path # Create inference engine self.inference_engine = InferenceEngine( - model_path, self.db_manager, model_id + normalized_model_path, self.db_manager, model_id ) # Get confidence threshold @@ -338,6 +389,76 @@ class DetectionTab(QWidget): self.batch_btn.setEnabled(enabled) self.model_combo.setEnabled(enabled) + def _discover_local_models(self) -> list: + """Scan the models directory for standalone .pt files.""" + models_dir = self.config_manager.get_models_directory() + if not models_dir: + return [] + + models_path = Path(models_dir) + if not models_path.exists(): + return [] + + try: + return sorted( + [p for p in models_path.rglob("*.pt") if p.is_file()], + key=lambda p: str(p).lower(), + ) + except Exception as e: + logger.warning(f"Error discovering local models: {e}") + return [] + + def _normalize_model_path(self, path_value) -> str: + """Return a normalized absolute path string for comparison.""" + if not path_value: + return "" + try: + return str(Path(path_value).resolve()) + except Exception: + return str(path_value) + + def _ensure_model_record(self, model_data: dict) -> Optional[int]: + """Ensure a database record exists for the selected model.""" + model_path = model_data.get("path") + if not model_path: + return None + + normalized_target = self._normalize_model_path(model_path) + + try: + existing_models = self.db_manager.get_models() + for model in existing_models: + existing_path = model.get("model_path") + if not existing_path: + continue + normalized_existing = self._normalize_model_path(existing_path) + if ( + normalized_existing == normalized_target + or existing_path == model_path + ): + return model["id"] + + model_name = ( + model_data.get("model_name") or Path(model_path).stem or "Custom Model" + ) + model_version = ( + model_data.get("model_version") or model_data.get("source") or "local" + ) + base_model = model_data.get( + "base_model", + self.config_manager.get("models.default_base_model", "yolov8s-seg.pt"), + ) + + return self.db_manager.add_model( + model_name=model_name, + model_version=model_version, + model_path=normalized_target, + base_model=base_model, + ) + except Exception as e: + logger.error(f"Failed to ensure model record for {model_path}: {e}") + return None + def refresh(self): """Refresh the tab.""" self._load_models() diff --git a/src/gui/tabs/results_tab.py b/src/gui/tabs/results_tab.py index 71e523c..530ff53 100644 --- a/src/gui/tabs/results_tab.py +++ b/src/gui/tabs/results_tab.py @@ -1,15 +1,39 @@ """ -Results tab for the microscopy object detection application. +Results tab for browsing stored detections and visualizing overlays. """ -from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox +from pathlib import Path +from typing import Dict, List, Optional + +from PySide6.QtWidgets import ( + QWidget, + QVBoxLayout, + QHBoxLayout, + QLabel, + QGroupBox, + QPushButton, + QSplitter, + QTableWidget, + QTableWidgetItem, + QHeaderView, + QAbstractItemView, + QMessageBox, + QCheckBox, +) +from PySide6.QtCore import Qt from src.database.db_manager import DatabaseManager from src.utils.config_manager import ConfigManager +from src.utils.logger import get_logger +from src.utils.image import Image, ImageLoadError +from src.gui.widgets import AnnotationCanvasWidget + + +logger = get_logger(__name__) class ResultsTab(QWidget): - """Results tab placeholder.""" + """Results tab showing detection history and preview overlays.""" def __init__( self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None @@ -18,29 +42,387 @@ class ResultsTab(QWidget): self.db_manager = db_manager self.config_manager = config_manager + self.detection_summary: List[Dict] = [] + self.current_selection: Optional[Dict] = None + self.current_image: Optional[Image] = None + self.current_detections: List[Dict] = [] + self._image_path_cache: Dict[str, str] = {} + self._setup_ui() + self.refresh() def _setup_ui(self): """Setup user interface.""" layout = QVBoxLayout() - group = QGroupBox("Results") - group_layout = QVBoxLayout() - label = QLabel( - "Results viewer will be implemented here.\n\n" - "Features:\n" - "- Detection history browser\n" - "- Advanced filtering\n" - "- Statistics dashboard\n" - "- Export functionality" - ) - group_layout.addWidget(label) - group.setLayout(group_layout) + # Splitter for list + preview + splitter = QSplitter(Qt.Horizontal) - layout.addWidget(group) - layout.addStretch() + # Left pane: detection list + left_container = QWidget() + left_layout = QVBoxLayout() + left_layout.setContentsMargins(0, 0, 0, 0) + + controls_layout = QHBoxLayout() + self.refresh_btn = QPushButton("Refresh") + self.refresh_btn.clicked.connect(self.refresh) + controls_layout.addWidget(self.refresh_btn) + controls_layout.addStretch() + left_layout.addLayout(controls_layout) + + self.results_table = QTableWidget(0, 5) + self.results_table.setHorizontalHeaderLabels( + ["Image", "Model", "Detections", "Classes", "Last Updated"] + ) + self.results_table.horizontalHeader().setSectionResizeMode( + 0, QHeaderView.Stretch + ) + self.results_table.horizontalHeader().setSectionResizeMode( + 1, QHeaderView.Stretch + ) + self.results_table.horizontalHeader().setSectionResizeMode( + 2, QHeaderView.ResizeToContents + ) + self.results_table.horizontalHeader().setSectionResizeMode( + 3, QHeaderView.Stretch + ) + self.results_table.horizontalHeader().setSectionResizeMode( + 4, QHeaderView.ResizeToContents + ) + self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows) + self.results_table.setSelectionMode(QAbstractItemView.SingleSelection) + self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers) + self.results_table.itemSelectionChanged.connect(self._on_result_selected) + + left_layout.addWidget(self.results_table) + left_container.setLayout(left_layout) + + # Right pane: preview canvas and controls + right_container = QWidget() + right_layout = QVBoxLayout() + right_layout.setContentsMargins(0, 0, 0, 0) + + preview_group = QGroupBox("Detection Preview") + preview_layout = QVBoxLayout() + + self.preview_canvas = AnnotationCanvasWidget() + self.preview_canvas.set_polyline_enabled(False) + self.preview_canvas.set_show_bboxes(True) + preview_layout.addWidget(self.preview_canvas) + + toggles_layout = QHBoxLayout() + self.show_masks_checkbox = QCheckBox("Show Masks") + self.show_masks_checkbox.setChecked(False) + self.show_masks_checkbox.stateChanged.connect(self._apply_detection_overlays) + self.show_bboxes_checkbox = QCheckBox("Show Bounding Boxes") + self.show_bboxes_checkbox.setChecked(True) + self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes) + toggles_layout.addWidget(self.show_masks_checkbox) + toggles_layout.addWidget(self.show_bboxes_checkbox) + toggles_layout.addStretch() + preview_layout.addLayout(toggles_layout) + + self.summary_label = QLabel("Select a detection result to preview.") + self.summary_label.setWordWrap(True) + preview_layout.addWidget(self.summary_label) + + preview_group.setLayout(preview_layout) + right_layout.addWidget(preview_group) + right_container.setLayout(right_layout) + + splitter.addWidget(left_container) + splitter.addWidget(right_container) + splitter.setStretchFactor(0, 1) + splitter.setStretchFactor(1, 2) + + layout.addWidget(splitter) self.setLayout(layout) def refresh(self): - """Refresh the tab.""" - pass + """Refresh the detection list and preview.""" + self._load_detection_summary() + self._populate_results_table() + self.current_selection = None + self.current_image = None + self.current_detections = [] + self.preview_canvas.clear() + self.summary_label.setText("Select a detection result to preview.") + + def _load_detection_summary(self): + """Load latest detection summaries grouped by image + model.""" + try: + detections = self.db_manager.get_detections(limit=500) + summary_map: Dict[tuple, Dict] = {} + + for det in detections: + key = (det["image_id"], det["model_id"]) + metadata = det.get("metadata") or {} + entry = summary_map.setdefault( + key, + { + "image_id": det["image_id"], + "model_id": det["model_id"], + "image_path": det.get("image_path"), + "image_filename": det.get("image_filename") + or det.get("image_path"), + "model_name": det.get("model_name", ""), + "model_version": det.get("model_version", ""), + "last_detected": det.get("detected_at"), + "count": 0, + "classes": set(), + "source_path": metadata.get("source_path"), + "repository_root": metadata.get("repository_root"), + }, + ) + + entry["count"] += 1 + if det.get("detected_at") and ( + not entry.get("last_detected") + or str(det.get("detected_at")) > str(entry.get("last_detected")) + ): + entry["last_detected"] = det.get("detected_at") + if det.get("class_name"): + entry["classes"].add(det["class_name"]) + if metadata.get("source_path") and not entry.get("source_path"): + entry["source_path"] = metadata.get("source_path") + if metadata.get("repository_root") and not entry.get("repository_root"): + entry["repository_root"] = metadata.get("repository_root") + + self.detection_summary = sorted( + summary_map.values(), + key=lambda x: str(x.get("last_detected") or ""), + reverse=True, + ) + except Exception as e: + logger.error(f"Failed to load detection summary: {e}") + QMessageBox.critical( + self, + "Error", + f"Failed to load detection results:\n{str(e)}", + ) + self.detection_summary = [] + + def _populate_results_table(self): + """Populate the table widget with detection summaries.""" + self.results_table.setRowCount(len(self.detection_summary)) + + for row, entry in enumerate(self.detection_summary): + model_label = f"{entry['model_name']} {entry['model_version']}".strip() + class_list = ( + ", ".join(sorted(entry["classes"])) if entry["classes"] else "-" + ) + + items = [ + QTableWidgetItem(entry.get("image_filename", "")), + QTableWidgetItem(model_label), + QTableWidgetItem(str(entry.get("count", 0))), + QTableWidgetItem(class_list), + QTableWidgetItem(str(entry.get("last_detected") or "")), + ] + + for col, item in enumerate(items): + item.setData(Qt.UserRole, row) + self.results_table.setItem(row, col, item) + + self.results_table.clearSelection() + + def _on_result_selected(self): + """Handle selection changes in the detection table.""" + selected_items = self.results_table.selectedItems() + if not selected_items: + return + + row = selected_items[0].data(Qt.UserRole) + if row is None or row >= len(self.detection_summary): + return + + entry = self.detection_summary[row] + if ( + self.current_selection + and self.current_selection.get("image_id") == entry["image_id"] + and self.current_selection.get("model_id") == entry["model_id"] + ): + return + + self.current_selection = entry + + image_path = self._resolve_image_path(entry) + if not image_path: + QMessageBox.warning( + self, + "Image Not Found", + "Unable to locate the image file for this detection.", + ) + return + + try: + self.current_image = Image(image_path) + self.preview_canvas.load_image(self.current_image) + except ImageLoadError as e: + logger.error(f"Failed to load image '{image_path}': {e}") + QMessageBox.critical( + self, + "Image Error", + f"Failed to load image for preview:\n{str(e)}", + ) + return + + self._load_detections_for_selection(entry) + self._apply_detection_overlays() + self._update_summary_label(entry) + + def _load_detections_for_selection(self, entry: Dict): + """Load detection records for the selected image/model pair.""" + self.current_detections = [] + if not entry: + return + + try: + filters = {"image_id": entry["image_id"], "model_id": entry["model_id"]} + self.current_detections = self.db_manager.get_detections(filters) + except Exception as e: + logger.error(f"Failed to load detections for preview: {e}") + QMessageBox.critical( + self, + "Error", + f"Failed to load detections for this image:\n{str(e)}", + ) + self.current_detections = [] + + def _apply_detection_overlays(self): + """Draw detections onto the preview canvas based on current toggles.""" + self.preview_canvas.clear_annotations() + self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked()) + + if not self.current_detections or not self.current_image: + return + + for det in self.current_detections: + color = self._get_class_color(det.get("class_name")) + + if self.show_masks_checkbox.isChecked() and det.get("segmentation_mask"): + mask_points = self._convert_mask(det["segmentation_mask"]) + if mask_points: + self.preview_canvas.draw_saved_polyline(mask_points, color) + + bbox = [ + det.get("x_min"), + det.get("y_min"), + det.get("x_max"), + det.get("y_max"), + ] + if all(v is not None for v in bbox): + self.preview_canvas.draw_saved_bbox(bbox, color) + + def _convert_mask(self, mask_points: List[List[float]]) -> List[List[float]]: + """Convert stored [x, y] masks to [y, x] format for the canvas.""" + converted = [] + for point in mask_points: + if len(point) >= 2: + x, y = point[0], point[1] + converted.append([y, x]) + return converted + + def _toggle_bboxes(self): + """Update bounding box visibility on the canvas.""" + self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked()) + # Re-render to respect show/hide when toggled + self._apply_detection_overlays() + + def _update_summary_label(self, entry: Dict): + """Display textual summary for the selected detection run.""" + classes = ", ".join(sorted(entry.get("classes", []))) or "-" + summary_text = ( + f"Image: {entry.get('image_filename', 'unknown')}\n" + f"Model: {entry.get('model_name', '')} {entry.get('model_version', '')}\n" + f"Detections: {entry.get('count', 0)}\n" + f"Classes: {classes}\n" + f"Last Updated: {entry.get('last_detected', 'n/a')}" + ) + self.summary_label.setText(summary_text) + + def _resolve_image_path(self, entry: Dict) -> Optional[str]: + """Resolve an image path using metadata, cache, and repository hints.""" + relative_path = entry.get("image_path") if entry else None + cache_key = relative_path or entry.get("source_path") + if cache_key and cache_key in self._image_path_cache: + cached = Path(self._image_path_cache[cache_key]) + if cached.exists(): + return self._image_path_cache[cache_key] + del self._image_path_cache[cache_key] + + candidates = [] + source_path = entry.get("source_path") if entry else None + if source_path: + candidates.append(Path(source_path)) + + repo_roots = [] + if entry.get("repository_root"): + repo_roots.append(entry["repository_root"]) + config_repo = self.config_manager.get_image_repository_path() + if config_repo: + repo_roots.append(config_repo) + + for root in repo_roots: + if relative_path: + candidates.append(Path(root) / relative_path) + + if relative_path: + candidates.append(Path(relative_path)) + + for candidate in candidates: + try: + if candidate and candidate.exists(): + resolved = str(candidate.resolve()) + if cache_key: + self._image_path_cache[cache_key] = resolved + return resolved + except Exception: + continue + + # Fallback: search by filename in known roots + filename = Path(relative_path).name if relative_path else None + if filename: + search_roots = [Path(root) for root in repo_roots if root] + if not search_roots: + search_roots = [Path("data")] + match = self._search_in_roots(filename, search_roots) + if match: + resolved = str(match.resolve()) + if cache_key: + self._image_path_cache[cache_key] = resolved + return resolved + + return None + + def _search_in_roots(self, filename: str, roots: List[Path]) -> Optional[Path]: + """Search for a file name within a list of root directories.""" + for root in roots: + try: + if not root.exists(): + continue + for candidate in root.rglob(filename): + return candidate + except Exception as e: + logger.debug(f"Error searching for {filename} in {root}: {e}") + return None + + def _get_class_color(self, class_name: Optional[str]) -> str: + """Return consistent color hex for a class name.""" + if not class_name: + return "#FF6B6B" + + color_map = self.config_manager.get_bbox_colors() + if class_name in color_map: + return color_map[class_name] + + # Deterministic fallback color based on hash + palette = [ + "#FF6B6B", + "#4ECDC4", + "#FFD166", + "#1D3557", + "#F4A261", + "#E76F51", + ] + return palette[hash(class_name) % len(palette)] diff --git a/src/gui/widgets/annotation_canvas_widget.py b/src/gui/widgets/annotation_canvas_widget.py index 37523e9..880143e 100644 --- a/src/gui/widgets/annotation_canvas_widget.py +++ b/src/gui/widgets/annotation_canvas_widget.py @@ -263,7 +263,7 @@ class AnnotationCanvasWidget(QWidget): height, bytes_per_line, self.current_image.qtimage_format, - ) + ).copy() # Copy so Qt owns the buffer even after numpy array goes out of scope self.original_pixmap = QPixmap.fromImage(qimage) diff --git a/src/gui/widgets/image_display_widget.py b/src/gui/widgets/image_display_widget.py index 52d2ce2..bf5ffd1 100644 --- a/src/gui/widgets/image_display_widget.py +++ b/src/gui/widgets/image_display_widget.py @@ -137,7 +137,7 @@ class ImageDisplayWidget(QWidget): height, bytes_per_line, self.current_image.qtimage_format, - ) + ).copy() # Copy to ensure Qt owns its memory after this scope # Convert to pixmap pixmap = QPixmap.fromImage(qimage) diff --git a/src/model/inference.py b/src/model/inference.py index 2a3780b..99a2de7 100644 --- a/src/model/inference.py +++ b/src/model/inference.py @@ -42,6 +42,7 @@ class InferenceEngine: relative_path: str, conf: float = 0.25, save_to_db: bool = True, + repository_root: Optional[str] = None, ) -> Dict: """ Detect objects in a single image. @@ -51,11 +52,17 @@ class InferenceEngine: relative_path: Relative path from repository root conf: Confidence threshold save_to_db: Whether to save results to database + repository_root: Base directory used to compute relative_path (if known) Returns: Dictionary with detection results """ try: + # Normalize storage path (fall back to absolute path when repo root is unknown) + stored_relative_path = relative_path + if not repository_root: + stored_relative_path = str(Path(image_path).resolve()) + # Get image dimensions img = Image.open(image_path) width, height = img.size @@ -66,34 +73,58 @@ class InferenceEngine: # Add/get image in database image_id = self.db_manager.get_or_create_image( - relative_path=relative_path, + relative_path=stored_relative_path, filename=Path(image_path).name, width=width, height=height, ) - # Save detections to database - if save_to_db and detections: - detection_records = [] - for det in detections: - # Use normalized bbox from detection - bbox_normalized = det[ - "bbox_normalized" - ] # [x_min, y_min, x_max, y_max] + inserted_count = 0 + deleted_count = 0 - record = { - "image_id": image_id, - "model_id": self.model_id, - "class_name": det["class_name"], - "bbox": tuple(bbox_normalized), - "confidence": det["confidence"], - "segmentation_mask": det.get("segmentation_mask"), - "metadata": {"class_id": det["class_id"]}, - } - detection_records.append(record) + # Save detections to database, replacing any previous results for this image/model + if save_to_db: + deleted_count = self.db_manager.delete_detections_for_image( + image_id, self.model_id + ) + if detections: + detection_records = [] + for det in detections: + # Use normalized bbox from detection + bbox_normalized = det[ + "bbox_normalized" + ] # [x_min, y_min, x_max, y_max] - self.db_manager.add_detections_batch(detection_records) - logger.info(f"Saved {len(detection_records)} detections to database") + metadata = { + "class_id": det["class_id"], + "source_path": str(Path(image_path).resolve()), + } + if repository_root: + metadata["repository_root"] = str( + Path(repository_root).resolve() + ) + + record = { + "image_id": image_id, + "model_id": self.model_id, + "class_name": det["class_name"], + "bbox": tuple(bbox_normalized), + "confidence": det["confidence"], + "segmentation_mask": det.get("segmentation_mask"), + "metadata": metadata, + } + detection_records.append(record) + + inserted_count = self.db_manager.add_detections_batch( + detection_records + ) + logger.info( + f"Saved {inserted_count} detections to database (replaced {deleted_count})" + ) + else: + logger.info( + f"Detection run removed {deleted_count} stale entries but produced no new detections" + ) return { "success": True, @@ -142,7 +173,12 @@ class InferenceEngine: rel_path = get_relative_path(image_path, repository_root) # Perform detection - result = self.detect_single(image_path, rel_path, conf) + result = self.detect_single( + image_path, + rel_path, + conf=conf, + repository_root=repository_root, + ) results.append(result) # Update progress diff --git a/src/model/yolo_wrapper.py b/src/model/yolo_wrapper.py index cbcf01e..ce573d7 100644 --- a/src/model/yolo_wrapper.py +++ b/src/model/yolo_wrapper.py @@ -7,6 +7,9 @@ from ultralytics import YOLO from pathlib import Path from typing import Optional, List, Dict, Callable, Any import torch +from PIL import Image +import tempfile +import os from src.utils.logger import get_logger @@ -162,10 +165,12 @@ class YOLOWrapper: if self.model is None: self.load_model() + prepared_source, cleanup_path = self._prepare_source(source) + try: logger.info(f"Running inference on {source}") results = self.model.predict( - source=source, + source=prepared_source, conf=conf, iou=iou, save=save, @@ -182,6 +187,14 @@ class YOLOWrapper: except Exception as e: logger.error(f"Error during inference: {e}") raise + finally: + if cleanup_path: + try: + os.remove(cleanup_path) + except OSError as cleanup_error: + logger.warning( + f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}" + ) def export( self, format: str = "onnx", output_path: Optional[str] = None, **kwargs @@ -210,6 +223,36 @@ class YOLOWrapper: logger.error(f"Error exporting model: {e}") raise + def _prepare_source(self, source): + """Convert single-channel images to RGB temporarily for inference.""" + cleanup_path = None + + if isinstance(source, (str, Path)): + source_path = Path(source) + if source_path.is_file(): + try: + with Image.open(source_path) as img: + if len(img.getbands()) == 1: + rgb_img = img.convert("RGB") + suffix = source_path.suffix or ".png" + tmp = tempfile.NamedTemporaryFile( + suffix=suffix, delete=False + ) + tmp_path = tmp.name + tmp.close() + rgb_img.save(tmp_path) + cleanup_path = tmp_path + logger.info( + f"Converted single-channel image {source_path} to RGB for inference" + ) + return tmp_path, cleanup_path + except Exception as convert_error: + logger.warning( + f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}" + ) + + return source, cleanup_path + def _format_training_results(self, results) -> Dict[str, Any]: """Format training results into dictionary.""" try: