"""
Inference engine for the microscopy object detection application.
Handles detection inference and result storage.
"""

import string
from collections import Counter
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

import cv2
import numpy as np

from src.model.yolo_wrapper import YOLOWrapper
from src.database.db_manager import DatabaseManager
from src.utils.image import Image
from src.utils.logger import get_logger
from src.utils.file_utils import get_relative_path

logger = get_logger(__name__)

# Hex colour used for classes missing from ``bbox_colors`` when no "default" key is given.
_DEFAULT_HEX_COLOR = "#00FF00"
# BGR colour returned by ``_hex_to_bgr`` for unparseable input (green).
_FALLBACK_BGR = (0, 255, 0)
# Opacity of the filled segmentation-mask overlay (the image keeps 1 - alpha).
_MASK_ALPHA = 0.3
# Vertical padding in pixels between the label text and its background box.
_LABEL_PADDING = 5
# Blank canvas returned by detect_with_visualization when the image cannot be read.
_FALLBACK_CANVAS_SHAPE = (480, 640, 3)


class InferenceEngine:
    """Handles detection inference and result storage."""

    def __init__(self, model_path: str, db_manager: DatabaseManager, model_id: int):
        """
        Initialize inference engine.

        Args:
            model_path: Path to YOLO model weights
            db_manager: Database manager instance
            model_id: ID of the model in database
        """
        self.yolo = YOLOWrapper(model_path)
        # NOTE(review): load_model()'s return value is ignored here, whereas
        # change_model() treats a falsy result as failure; confirm whether a
        # failed load should abort construction.
        self.yolo.load_model()
        self.db_manager = db_manager
        self.model_id = model_id
        logger.info(f"InferenceEngine initialized with model_id {model_id}")

    def detect_single(
        self,
        image_path: str,
        relative_path: str,
        conf: float = 0.25,
        save_to_db: bool = True,
        repository_root: Optional[str] = None,
    ) -> Dict:
        """
        Detect objects in a single image.

        Args:
            image_path: Absolute path to image file
            relative_path: Relative path from repository root
            conf: Confidence threshold
            save_to_db: Whether to replace this image's stored detections for
                the current model. The image record itself is always looked
                up or created, because its id is part of the result.
            repository_root: Base directory used to compute relative_path (if known).
                When omitted, the image is stored under its resolved absolute
                path instead of ``relative_path``.

        Returns:
            On success: ``success``, ``image_path``, ``image_id``,
            ``detections``, ``count``.
            On failure (the exception is logged, not raised): ``success``,
            ``image_path``, ``error``, ``detections`` (empty), ``count`` (0).
        """
        try:
            source_path = str(Path(image_path).resolve())

            # Normalize storage path (fall back to absolute path when repo root is unknown)
            stored_relative_path = relative_path if repository_root else source_path

            img = Image(image_path)
            width, height = img.width, img.height

            detections = self.yolo.predict(image_path, conf=conf)

            image_id = self.db_manager.get_or_create_image(
                relative_path=stored_relative_path,
                filename=Path(image_path).name,
                width=width,
                height=height,
            )

            if save_to_db:
                self._save_detections(image_id, detections, source_path, repository_root)

            return {
                "success": True,
                "image_path": image_path,
                "image_id": image_id,
                "detections": detections,
                "count": len(detections),
            }
        except Exception as e:
            logger.error(f"Error detecting objects in {image_path}: {e}")
            return {
                "success": False,
                "image_path": image_path,
                "error": str(e),
                "detections": [],
                "count": 0,
            }

    def _save_detections(
        self,
        image_id: int,
        detections: List[Dict],
        source_path: str,
        repository_root: Optional[str],
    ) -> None:
        """Replace stored detections for (image_id, self.model_id) with ``detections``."""
        deleted_count = self.db_manager.delete_detections_for_image(image_id, self.model_id)
        if not detections:
            logger.info(
                f"Detection run removed {deleted_count} stale entries but produced no new detections"
            )
            return

        # Provenance shared by every record of this run; resolved once, not per detection.
        provenance = {"source_path": source_path}
        if repository_root:
            provenance["repository_root"] = str(Path(repository_root).resolve())

        records = [
            {
                "image_id": image_id,
                "model_id": self.model_id,
                "class_name": det["class_name"],
                # Normalized bbox: [x_min, y_min, x_max, y_max]
                "bbox": tuple(det["bbox_normalized"]),
                "confidence": det["confidence"],
                "segmentation_mask": det.get("segmentation_mask"),
                # class_id first, then the shared provenance keys (same order as before).
                "metadata": {"class_id": det["class_id"], **provenance},
            }
            for det in detections
        ]
        inserted_count = self.db_manager.add_detections_batch(records)
        logger.info(
            f"Saved {inserted_count} detections to database (replaced {deleted_count})"
        )

    def detect_batch(
        self,
        image_paths: List[str],
        repository_root: str,
        conf: float = 0.25,
        progress_callback: Optional[Callable[[int, int, str], None]] = None,
    ) -> List[Dict]:
        """
        Detect objects in multiple images, saving each result to the database.

        Args:
            image_paths: List of absolute image paths
            repository_root: Root directory for relative paths
            conf: Confidence threshold
            progress_callback: Optional callback(current, total, message),
                called after each image with a 1-based ``current``.

        Returns:
            One result dictionary per input path, in input order (see
            ``detect_single``). Failures inside ``detect_single`` appear as
            ``success=False`` entries rather than being raised.
        """
        results: List[Dict] = []
        total = len(image_paths)
        logger.info(f"Starting batch detection on {total} images")

        for i, image_path in enumerate(image_paths, 1):
            rel_path = get_relative_path(image_path, repository_root)
            result = self.detect_single(
                image_path,
                rel_path,
                conf=conf,
                repository_root=repository_root,
            )
            results.append(result)

            if progress_callback:
                progress_callback(i, total, f"Processed {rel_path}")
            if i % 10 == 0:
                logger.info(f"Processed {i}/{total} images")

        logger.info(f"Batch detection complete: {total} images processed")
        failed = sum(1 for r in results if not r["success"])
        if failed:
            logger.warning(f"{failed}/{total} images failed during batch detection")
        return results

    def detect_with_visualization(
        self,
        image_path: str,
        conf: float = 0.25,
        bbox_thickness: int = 2,
        bbox_colors: Optional[Dict[str, str]] = None,
        draw_masks: bool = True,
    ) -> Tuple[List[Dict], np.ndarray]:
        """
        Detect objects and return annotated image.

        Args:
            image_path: Path to image
            conf: Confidence threshold
            bbox_thickness: Thickness of bounding boxes and mask outlines
            bbox_colors: Dictionary mapping class names to hex colors; the key
                "default" applies to classes not listed (green if absent)
            draw_masks: Whether to draw segmentation masks (if available)

        Returns:
            Tuple of (detections, annotated_image_array) with the image in RGB
            order. On any error the error is logged and ``([], image)`` is
            returned, where ``image`` is the unannotated RGB image or a black
            640x480 canvas if the file cannot be read.
        """
        try:
            detections = self.yolo.predict(image_path, conf=conf)

            # Draw in OpenCV's native BGR order so the BGR colours produced by
            # _hex_to_bgr land on the right channels; convert to RGB once at the end.
            canvas = cv2.imread(image_path)
            if canvas is None:
                raise ValueError(f"Failed to load image: {image_path}")
            height, width = canvas.shape[:2]

            colors = bbox_colors or {}
            fallback_hex = colors.get("default", _DEFAULT_HEX_COLOR)

            for det in detections:
                class_name = det["class_name"]
                color = self._hex_to_bgr(colors.get(class_name, fallback_hex))

                if draw_masks:
                    self._draw_mask(
                        canvas, det.get("segmentation_mask"), color, width, height, bbox_thickness
                    )

                x1, y1, x2, y2 = (int(v) for v in det["bbox_absolute"])
                cv2.rectangle(canvas, (x1, y1), (x2, y2), color, bbox_thickness)
                self._draw_label(canvas, f"{class_name} {det['confidence']:.2f}", x1, y1, color)

            return detections, cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
        except Exception as e:
            logger.error(f"Error creating visualization: {e}")
            return [], self._load_rgb_or_blank(image_path)

    @staticmethod
    def _draw_mask(
        canvas: np.ndarray,
        mask_normalized: Any,
        color: Tuple[int, int, int],
        width: int,
        height: int,
        thickness: int,
    ) -> None:
        """Blend a filled, outlined polygon for a normalized mask onto ``canvas`` in place."""
        # Explicit None/len checks also work for numpy arrays, whose truth value is ambiguous.
        if mask_normalized is None or len(mask_normalized) == 0:
            return
        # Normalized [0, 1] coordinates -> absolute pixel coordinates.
        points = np.array(
            [[int(pt[0] * width), int(pt[1] * height)] for pt in mask_normalized],
            dtype=np.int32,
        )
        overlay = canvas.copy()
        cv2.fillPoly(overlay, [points], color)
        # Semi-transparent fill; pixels outside the polygon are identical in both inputs.
        cv2.addWeighted(overlay, _MASK_ALPHA, canvas, 1 - _MASK_ALPHA, 0, canvas)
        cv2.polylines(canvas, [points], True, color, thickness)

    @staticmethod
    def _draw_label(
        canvas: np.ndarray, label: str, x: int, y: int, color: Tuple[int, int, int]
    ) -> None:
        """Draw ``label`` in white on a filled ``color`` box anchored at (x, y)."""
        (label_w, label_h), baseline = cv2.getTextSize(
            label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
        )
        box_h = label_h + baseline + _LABEL_PADDING
        # Prefer placing the label above the bounding box; if that would run
        # off the top of the image, place it just inside the box instead.
        bottom = y if y - box_h >= 0 else y + box_h
        cv2.rectangle(canvas, (x, bottom - box_h), (x + label_w, bottom), color, -1)
        cv2.putText(
            canvas,
            label,
            (x, bottom - baseline - _LABEL_PADDING),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (255, 255, 255),
            1,
        )

    @staticmethod
    def _load_rgb_or_blank(image_path: str) -> np.ndarray:
        """Best-effort fallback: the unannotated RGB image, or a black 640x480 canvas."""
        try:
            img = cv2.imread(image_path)
            if img is not None:
                return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        except Exception as fallback_err:
            # Deliberately best-effort: the primary error has already been logged.
            logger.debug(f"Fallback image load failed for {image_path}: {fallback_err}")
        return np.zeros(_FALLBACK_CANVAS_SHAPE, dtype=np.uint8)

    def get_detection_summary(self, detections: List[Dict]) -> Dict[str, Any]:
        """
        Generate summary statistics for detections.

        Args:
            detections: List of detection dictionaries (each needs
                ``class_name`` and ``confidence``)

        Returns:
            Dictionary with ``total_count``, ``class_counts`` (class name ->
            count, in first-seen order), ``avg_confidence`` and
            ``confidence_range`` as (min, max). Zeros for empty input.
        """
        if not detections:
            return {
                "total_count": 0,
                "class_counts": {},
                "avg_confidence": 0.0,
                "confidence_range": (0.0, 0.0),
            }

        class_counts = dict(Counter(det["class_name"] for det in detections))
        confidences = [det["confidence"] for det in detections]

        return {
            "total_count": len(detections),
            "class_counts": class_counts,
            "avg_confidence": sum(confidences) / len(confidences),
            "confidence_range": (min(confidences), max(confidences)),
        }

    @staticmethod
    def _hex_to_bgr(hex_color: str) -> Tuple[int, int, int]:
        """
        Convert hex color to BGR tuple.

        Accepts ``#RRGGBB`` / ``RRGGBB`` and the CSS shorthand ``#RGB``.
        Anything else (wrong length, non-hex characters, non-string input)
        yields green ``(0, 255, 0)``.

        Args:
            hex_color: Hex color string (e.g., '#FF0000')

        Returns:
            BGR tuple (B, G, R), the channel order OpenCV drawing functions expect
        """
        if not isinstance(hex_color, str):
            return _FALLBACK_BGR
        digits = hex_color.strip().lstrip("#")
        if len(digits) == 3:
            # "#F0A" -> "FF00AA"
            digits = "".join(ch * 2 for ch in digits)
        # Reject e.g. "+F00FF", which int(..., 16) would otherwise partially accept.
        if len(digits) != 6 or any(ch not in string.hexdigits for ch in digits):
            return _FALLBACK_BGR
        r, g, b = (int(digits[i : i + 2], 16) for i in (0, 2, 4))
        return (b, g, r)

    def change_model(self, model_path: str, model_id: int) -> bool:
        """
        Change the current model.

        The new weights are loaded into a separate wrapper first. The engine
        switches only after a successful load, so a failed change leaves the
        previous model and ``model_id`` in place and usable.

        Args:
            model_path: Path to new model weights
            model_id: ID of new model in database

        Returns:
            True if successful, False otherwise
        """
        try:
            candidate = YOLOWrapper(model_path)
            if not candidate.load_model():
                logger.error(f"Failed to load model from {model_path}; keeping current model")
                return False
        except Exception as e:
            logger.error(f"Error changing model: {e}")
            return False

        self.yolo = candidate
        self.model_id = model_id
        logger.info(f"Model changed to {model_path}")
        return True