Files
object-segmentation/src/model/inference.py

385 lines
13 KiB
Python

"""
Inference engine for the microscopy object detection application.
Handles detection inference and result storage.
"""
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

import cv2
import numpy as np

from src.model.yolo_wrapper import YOLOWrapper
from src.database.db_manager import DatabaseManager
from src.utils.image import Image
from src.utils.logger import get_logger
from src.utils.file_utils import get_relative_path
# Module-level logger obtained from the project's get_logger helper, keyed by this module's name.
logger = get_logger(__name__)
class InferenceEngine:
    """Handles detection inference and result storage.

    Wraps a ``YOLOWrapper`` for prediction and a ``DatabaseManager`` for
    persisting images and their detections under ``self.model_id``.
    """

    def __init__(self, model_path: str, db_manager: DatabaseManager, model_id: int):
        """
        Initialize inference engine.

        Args:
            model_path: Path to YOLO model weights
            db_manager: Database manager instance
            model_id: ID of the model in database
        """
        self.yolo = YOLOWrapper(model_path)
        # change_model() treats a falsy load_model() result as failure; surface
        # that condition here as well instead of silently ignoring it.
        if not self.yolo.load_model():
            logger.warning(
                f"Model weights at {model_path} did not report a successful load"
            )
        self.db_manager = db_manager
        self.model_id = model_id
        logger.info(f"InferenceEngine initialized with model_id {model_id}")

    def detect_single(
        self,
        image_path: str,
        relative_path: str,
        conf: float = 0.25,
        save_to_db: bool = True,
        repository_root: Optional[str] = None,
    ) -> Dict:
        """
        Detect objects in a single image.

        The image row is always fetched/created via ``get_or_create_image``;
        ``save_to_db`` only controls whether detections are written. When they
        are, previous detections for this image/model pair are deleted first so
        the stored results reflect only the latest run.

        Args:
            image_path: Absolute path to image file
            relative_path: Relative path from repository root
            conf: Confidence threshold
            save_to_db: Whether to save results to database
            repository_root: Base directory used to compute relative_path (if known)

        Returns:
            Dictionary with detection results. On success: ``success=True``,
            ``image_path``, ``image_id``, ``detections`` and ``count``. On any
            exception: ``success=False``, ``image_path``, ``error``, an empty
            ``detections`` list and ``count=0``.
        """
        try:
            resolved_image_path = str(Path(image_path).resolve())
            # Without a known repository root the relative path is not anchored
            # to anything stable, so store the absolute path instead.
            stored_relative_path = (
                relative_path if repository_root else resolved_image_path
            )

            img = Image(image_path)
            width = img.width
            height = img.height

            detections = self.yolo.predict(image_path, conf=conf) or []

            image_id = self.db_manager.get_or_create_image(
                relative_path=stored_relative_path,
                filename=Path(image_path).name,
                width=width,
                height=height,
            )

            if save_to_db:
                self._replace_detections(
                    image_id, detections, resolved_image_path, repository_root
                )

            return {
                "success": True,
                "image_path": image_path,
                "image_id": image_id,
                "detections": detections,
                "count": len(detections),
            }
        except Exception as e:
            logger.error(f"Error detecting objects in {image_path}: {e}")
            return {
                "success": False,
                "image_path": image_path,
                "error": str(e),
                "detections": [],
                "count": 0,
            }

    def _replace_detections(
        self,
        image_id: int,
        detections: List[Dict],
        source_path: str,
        repository_root: Optional[str],
    ) -> None:
        """Delete stored detections for ``image_id``/``self.model_id`` and insert ``detections``."""
        deleted_count = self.db_manager.delete_detections_for_image(
            image_id, self.model_id
        )
        if not detections:
            logger.info(
                f"Detection run removed {deleted_count} stale entries but produced no new detections"
            )
            return

        # Per-image metadata is identical for every detection; compute it once.
        shared_metadata = {"source_path": source_path}
        if repository_root:
            shared_metadata["repository_root"] = str(Path(repository_root).resolve())

        detection_records = [
            {
                "image_id": image_id,
                "model_id": self.model_id,
                "class_name": det["class_name"],
                # bbox_normalized is [x_min, y_min, x_max, y_max]
                "bbox": tuple(det["bbox_normalized"]),
                "confidence": det["confidence"],
                "segmentation_mask": det.get("segmentation_mask"),
                "metadata": {"class_id": det["class_id"], **shared_metadata},
            }
            for det in detections
        ]
        inserted_count = self.db_manager.add_detections_batch(detection_records)
        logger.info(
            f"Saved {inserted_count} detections to database (replaced {deleted_count})"
        )

    def detect_batch(
        self,
        image_paths: List[str],
        repository_root: str,
        conf: float = 0.25,
        progress_callback: Optional[Callable[[int, int, str], None]] = None,
    ) -> List[Dict]:
        """
        Detect objects in multiple images.

        Per-image failures are reported in the returned dicts by
        ``detect_single``; an exception from ``get_relative_path`` or from
        ``progress_callback`` is not caught here and aborts the batch.

        Args:
            image_paths: List of absolute image paths
            repository_root: Root directory for relative paths
            conf: Confidence threshold
            progress_callback: Optional callback(current, total, message)

        Returns:
            List of detection result dictionaries, in input order
        """
        results = []
        total = len(image_paths)
        logger.info(f"Starting batch detection on {total} images")
        for i, image_path in enumerate(image_paths, 1):
            rel_path = get_relative_path(image_path, repository_root)
            result = self.detect_single(
                image_path,
                rel_path,
                conf=conf,
                repository_root=repository_root,
            )
            results.append(result)
            if progress_callback:
                progress_callback(i, total, f"Processed {rel_path}")
            if i % 10 == 0:
                logger.info(f"Processed {i}/{total} images")
        logger.info(f"Batch detection complete: {total} images processed")
        return results

    def detect_with_visualization(
        self,
        image_path: str,
        conf: float = 0.25,
        bbox_thickness: int = 2,
        bbox_colors: Optional[Dict[str, str]] = None,
        draw_masks: bool = True,
    ) -> Tuple[List[Dict], np.ndarray]:
        """
        Detect objects and return annotated image.

        Args:
            image_path: Path to image
            conf: Confidence threshold
            bbox_thickness: Thickness of bounding boxes
            bbox_colors: Dictionary mapping class names to hex colors
                (e.g. ``"#FF0000"``); the ``"default"`` key is used for
                classes without an entry, falling back to ``"#00FF00"``.
            draw_masks: Whether to draw segmentation masks (if available)

        Returns:
            Tuple of (detections, annotated_image_array). The image array is
            RGB. On error, ``([], original RGB image)`` is returned, or
            ``([], black 480x640 image)`` if the original cannot be read.
        """
        try:
            detections = self.yolo.predict(image_path, conf=conf) or []

            img = cv2.imread(image_path)
            if img is None:
                raise ValueError(f"Failed to load image: {image_path}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            colors = bbox_colors or {}
            fallback_hex = colors.get("default", "#00FF00")

            for det in detections:
                class_name = det["class_name"]
                # The canvas was converted to RGB above, so the drawing colour
                # must be RGB as well; using the BGR tuple directly would swap
                # red and blue relative to the configured hex colour.
                color = self._hex_to_bgr(colors.get(class_name, fallback_hex))[::-1]
                if draw_masks:
                    self._draw_mask(
                        img, det.get("segmentation_mask"), color, bbox_thickness
                    )
                self._draw_box_and_label(img, det, color, bbox_thickness)

            return detections, img
        except Exception as e:
            logger.error(f"Error creating visualization: {e}")
            return [], self._load_rgb_or_blank(image_path)

    @staticmethod
    def _draw_mask(
        img: np.ndarray,
        mask_normalized: Optional[Any],
        color: Tuple[int, int, int],
        thickness: int,
    ) -> None:
        """Blend a 30%-opacity filled polygon plus its outline onto ``img`` in place."""
        # ``is None`` / ``len`` instead of truthiness so NumPy arrays work too
        # (an array in a boolean context raises ValueError).
        if mask_normalized is None or len(mask_normalized) == 0:
            return
        height, width = img.shape[:2]
        # Normalized (x, y) pairs -> absolute pixel coordinates.
        mask_points = np.array(
            [[int(pt[0] * width), int(pt[1] * height)] for pt in mask_normalized],
            dtype=np.int32,
        )
        overlay = img.copy()
        cv2.fillPoly(overlay, [mask_points], color)
        # Passing img as dst writes the blend back into img.
        cv2.addWeighted(overlay, 0.3, img, 0.7, 0, img)
        cv2.polylines(img, [mask_points], True, color, thickness)

    @staticmethod
    def _draw_box_and_label(
        img: np.ndarray,
        det: Dict,
        color: Tuple[int, int, int],
        thickness: int,
    ) -> None:
        """Draw the bounding box and a filled "<class> <confidence>" label onto ``img``."""
        x1, y1, x2, y2 = (int(v) for v in det["bbox_absolute"])
        cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness)

        label = f"{det['class_name']} {det['confidence']:.2f}"
        (label_w, label_h), baseline = cv2.getTextSize(
            label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
        )
        label_box_h = label_h + baseline + 5
        # Place the label above the box; if that would run off the top of the
        # image, place it just inside the box's top edge instead.
        label_bottom = y1 if y1 - label_box_h >= 0 else y1 + label_box_h
        cv2.rectangle(
            img,
            (x1, label_bottom - label_box_h),
            (x1 + label_w, label_bottom),
            color,
            -1,
        )
        cv2.putText(
            img,
            label,
            (x1, label_bottom - baseline - 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (255, 255, 255),
            1,
        )

    @staticmethod
    def _load_rgb_or_blank(image_path: str) -> np.ndarray:
        """Best-effort fallback: the image as RGB, or a black 480x640 frame if unreadable."""
        try:
            img = cv2.imread(image_path)
            if img is not None:
                return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        except Exception as e:
            logger.error(f"Error loading fallback image {image_path}: {e}")
        return np.zeros((480, 640, 3), dtype=np.uint8)

    def get_detection_summary(self, detections: List[Dict]) -> Dict[str, Any]:
        """
        Generate summary statistics for detections.

        Args:
            detections: List of detection dictionaries; each must have
                ``class_name`` and ``confidence`` keys.

        Returns:
            Dictionary with ``total_count``, ``class_counts`` (class name ->
            count), ``avg_confidence`` and ``confidence_range`` (min, max).
            All zero/empty when ``detections`` is empty.
        """
        if not detections:
            return {
                "total_count": 0,
                "class_counts": {},
                "avg_confidence": 0.0,
                "confidence_range": (0.0, 0.0),
            }
        class_counts: Dict[str, int] = {}
        confidences = []
        for det in detections:
            class_name = det["class_name"]
            class_counts[class_name] = class_counts.get(class_name, 0) + 1
            confidences.append(det["confidence"])
        return {
            "total_count": len(detections),
            "class_counts": class_counts,
            "avg_confidence": sum(confidences) / len(confidences),
            "confidence_range": (min(confidences), max(confidences)),
        }

    @staticmethod
    def _hex_to_bgr(hex_color: str) -> Tuple[int, int, int]:
        """
        Convert hex color to BGR tuple.

        Args:
            hex_color: Hex color string (e.g., '#FF0000'); leading '#' optional.

        Returns:
            BGR tuple (B, G, R) as used by OpenCV, or (0, 255, 0) (green) when
            the input is not a string of exactly six hex digits.
        """
        default = (0, 255, 0)
        if not isinstance(hex_color, str):
            return default
        digits = hex_color.lstrip("#")
        # int(..., 16) on its own would also accept signs, whitespace and
        # non-ASCII digits (e.g. "#-1FFFF" -> a negative channel), so validate
        # the characters explicitly.
        if len(digits) != 6 or any(c not in "0123456789abcdefABCDEF" for c in digits):
            return default
        r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
        return (b, g, r)

    def change_model(self, model_path: str, model_id: int) -> bool:
        """
        Change the current model.

        The new wrapper is only installed after it loads successfully, so on
        failure the engine keeps using the previous model and model_id.

        Args:
            model_path: Path to new model weights
            model_id: ID of new model in database

        Returns:
            True if successful, False otherwise
        """
        try:
            new_yolo = YOLOWrapper(model_path)
            if not new_yolo.load_model():
                logger.error(
                    f"Failed to load model from {model_path}; keeping current model"
                )
                return False
        except Exception as e:
            logger.error(f"Error changing model: {e}")
            return False
        self.yolo = new_yolo
        self.model_id = model_id
        logger.info(f"Model changed to {model_path}")
        return True