385 lines
13 KiB
Python
385 lines
13 KiB
Python
"""
|
|
Inference engine for the microscopy object detection application.
|
|
Handles detection inference and result storage.
|
|
"""
|
|
|
|
from collections import Counter
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import cv2
import numpy as np

from src.database.db_manager import DatabaseManager
from src.model.yolo_wrapper import YOLOWrapper
from src.utils.file_utils import get_relative_path
from src.utils.image import Image
from src.utils.logger import get_logger
|
|
|
|
|
|
# Module-level logger obtained from the project's logging helper, keyed by this module's name.
logger = get_logger(__name__)
|
|
|
|
|
|
class InferenceEngine:
    """Run YOLO detection on images, store the results, and render previews.

    The engine owns a loaded ``YOLOWrapper`` plus the database ID of the model
    it wraps, so every stored detection is attributed to that model.
    """

    # Colour used when a class has no entry in ``bbox_colors`` or the hex
    # string supplied for it cannot be parsed.
    _DEFAULT_COLOR_HEX = "#00FF00"
    _DEFAULT_BGR = (0, 255, 0)
    _HEX_DIGITS = frozenset("0123456789abcdefABCDEF")

    def __init__(self, model_path: str, db_manager: "DatabaseManager", model_id: int):
        """
        Initialize inference engine.

        Args:
            model_path: Path to YOLO model weights
            db_manager: Database manager instance
            model_id: ID of the model in database
        """
        self.yolo = YOLOWrapper(model_path)
        # change_model() treats a falsy load_model() result as failure; apply
        # the same reading here so a broken model is reported, not ignored.
        if not self.yolo.load_model():
            logger.error(f"load_model() reported failure for {model_path}")
        self.db_manager = db_manager
        self.model_id = model_id
        logger.info(f"InferenceEngine initialized with model_id {model_id}")

    def detect_single(
        self,
        image_path: str,
        relative_path: str,
        conf: float = 0.25,
        save_to_db: bool = True,
        repository_root: Optional[str] = None,
    ) -> Dict:
        """
        Detect objects in a single image.

        The image row is looked up or created in the database even when
        ``save_to_db`` is False; only the detection rows are skipped.

        Args:
            image_path: Absolute path to image file
            relative_path: Relative path from repository root
            conf: Confidence threshold
            save_to_db: Whether to save results to database
            repository_root: Base directory used to compute relative_path (if known)

        Returns:
            Dictionary with keys ``success``, ``image_path``, ``detections``
            and ``count``; plus ``image_id`` on success or ``error`` on
            failure. Errors are logged and reported in the dictionary rather
            than raised.
        """
        try:
            # Resolved once: used as the storage path fallback and in the
            # metadata of every stored detection.
            source_path = str(Path(image_path).resolve())

            # Fall back to the absolute path when the repo root is unknown.
            stored_relative_path = relative_path if repository_root else source_path

            img = Image(image_path)
            width = img.width
            height = img.height

            detections = self.yolo.predict(image_path, conf=conf)

            image_id = self.db_manager.get_or_create_image(
                relative_path=stored_relative_path,
                filename=Path(image_path).name,
                width=width,
                height=height,
            )

            if save_to_db:
                self._store_detections(
                    image_id, detections, source_path, repository_root
                )

            return {
                "success": True,
                "image_path": image_path,
                "image_id": image_id,
                "detections": detections,
                "count": len(detections),
            }

        except Exception as e:
            logger.error(f"Error detecting objects in {image_path}: {e}")
            return {
                "success": False,
                "image_path": image_path,
                "error": str(e),
                "detections": [],
                "count": 0,
            }

    def _store_detections(
        self,
        image_id,
        detections: List[Dict],
        source_path: str,
        repository_root: Optional[str],
    ) -> None:
        """Replace this image/model's stored detections with ``detections``."""
        deleted_count = self.db_manager.delete_detections_for_image(
            image_id, self.model_id
        )

        if not detections:
            logger.info(
                f"Detection run removed {deleted_count} stale entries but produced no new detections"
            )
            return

        # Metadata fields that are identical for every detection of this image.
        shared_metadata = {"source_path": source_path}
        if repository_root:
            shared_metadata["repository_root"] = str(Path(repository_root).resolve())

        detection_records = [
            {
                "image_id": image_id,
                "model_id": self.model_id,
                "class_name": det["class_name"],
                # Normalized bbox: [x_min, y_min, x_max, y_max]
                "bbox": tuple(det["bbox_normalized"]),
                "confidence": det["confidence"],
                "segmentation_mask": det.get("segmentation_mask"),
                # Each record gets its own metadata dict.
                "metadata": {"class_id": det["class_id"], **shared_metadata},
            }
            for det in detections
        ]

        inserted_count = self.db_manager.add_detections_batch(detection_records)
        logger.info(
            f"Saved {inserted_count} detections to database (replaced {deleted_count})"
        )

    def detect_batch(
        self,
        image_paths: List[str],
        repository_root: str,
        conf: float = 0.25,
        progress_callback: Optional[Callable[[int, int, str], None]] = None,
    ) -> List[Dict]:
        """
        Detect objects in multiple images.

        Args:
            image_paths: List of absolute image paths
            repository_root: Root directory for relative paths
            conf: Confidence threshold
            progress_callback: Optional callback(current, total, message)

        Returns:
            List of detection result dictionaries, one per input path and in
            the same order (see ``detect_single`` for their layout)
        """
        results: List[Dict] = []
        total = len(image_paths)

        logger.info(f"Starting batch detection on {total} images")

        for index, image_path in enumerate(image_paths, 1):
            rel_path = get_relative_path(image_path, repository_root)

            results.append(
                self.detect_single(
                    image_path,
                    rel_path,
                    conf=conf,
                    repository_root=repository_root,
                )
            )

            if progress_callback:
                progress_callback(index, total, f"Processed {rel_path}")

            # Periodic heartbeat so long batches are visible in the log.
            if index % 10 == 0:
                logger.info(f"Processed {index}/{total} images")

        logger.info(f"Batch detection complete: {total} images processed")
        return results

    def detect_with_visualization(
        self,
        image_path: str,
        conf: float = 0.25,
        bbox_thickness: int = 2,
        bbox_colors: Optional[Dict[str, str]] = None,
        draw_masks: bool = True,
    ) -> tuple:
        """
        Detect objects and return annotated image.

        Args:
            image_path: Path to image
            conf: Confidence threshold
            bbox_thickness: Thickness of bounding boxes
            bbox_colors: Dictionary mapping class names to hex colors; the
                optional ``"default"`` key is used for classes without an entry
            draw_masks: Whether to draw segmentation masks (if available)

        Returns:
            Tuple of (detections, annotated_image_array). The image is in RGB
            channel order. On any error the detections list is empty and the
            image is the unannotated original, or a black 480x640 placeholder
            if the file cannot be read.
        """
        try:
            detections = self.yolo.predict(image_path, conf=conf)

            img = cv2.imread(image_path)
            if img is None:
                raise ValueError(f"Failed to load image: {image_path}")

            palette = bbox_colors or {}
            default_hex = palette.get("default", self._DEFAULT_COLOR_HEX)

            # Draw while the image is still BGR so the BGR tuples produced by
            # _hex_to_bgr land on the right channels; convert to RGB only once
            # all annotations are in place.
            for det in detections:
                color = self._hex_to_bgr(palette.get(det["class_name"], default_hex))
                self._draw_detection(img, det, color, bbox_thickness, draw_masks)

            return detections, cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        except Exception as e:
            logger.error(f"Error creating visualization: {e}")
            return [], self._load_fallback_image(image_path)

    @staticmethod
    def _draw_detection(img, det: Dict, color: tuple, thickness: int, draw_masks: bool) -> None:
        """Draw one detection's mask, box and label onto the BGR image in place."""
        height, width = img.shape[:2]

        mask_normalized = det.get("segmentation_mask") if draw_masks else None
        # Explicit None/length test so array-like masks do not raise on truthiness.
        if mask_normalized is not None and len(mask_normalized) > 0:
            # Convert normalized coordinates to absolute pixels
            mask_points = np.array(
                [[int(pt[0] * width), int(pt[1] * height)] for pt in mask_normalized],
                dtype=np.int32,
            )

            # Fill on a copy, then blend it back at 30% opacity.
            overlay = img.copy()
            cv2.fillPoly(overlay, [mask_points], color)
            cv2.addWeighted(overlay, 0.3, img, 0.7, 0, img)

            # Mask contour
            cv2.polylines(img, [mask_points], True, color, thickness)

        x1, y1, x2, y2 = (int(v) for v in det["bbox_absolute"])
        cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness)

        label = f"{det['class_name']} {det['confidence']:.2f}"
        (label_w, label_h), baseline = cv2.getTextSize(
            label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
        )

        # Filled label background sitting directly above the box.
        cv2.rectangle(
            img,
            (x1, y1 - label_h - baseline - 5),
            (x1 + label_w, y1),
            color,
            -1,
        )
        cv2.putText(
            img,
            label,
            (x1, y1 - baseline - 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (255, 255, 255),
            1,
        )

    @staticmethod
    def _load_fallback_image(image_path: str) -> np.ndarray:
        """Return the unannotated image as RGB, or a black placeholder if unreadable."""
        try:
            img = cv2.imread(image_path)
            if img is not None:
                return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        except Exception as e:
            logger.error(f"Could not reload {image_path} for fallback: {e}")
        return np.zeros((480, 640, 3), dtype=np.uint8)

    def get_detection_summary(self, detections: List[Dict]) -> Dict[str, Any]:
        """
        Generate summary statistics for detections.

        Args:
            detections: List of detection dictionaries; each must provide
                ``class_name`` and ``confidence``

        Returns:
            Dictionary with ``total_count``, ``class_counts`` (class name ->
            occurrences), ``avg_confidence`` and ``confidence_range`` as a
            (min, max) tuple. All values are zero/empty for an empty input.
        """
        if not detections:
            return {
                "total_count": 0,
                "class_counts": {},
                "avg_confidence": 0.0,
                "confidence_range": (0.0, 0.0),
            }

        confidences = [det["confidence"] for det in detections]

        return {
            "total_count": len(detections),
            # Converted back to a plain dict so callers get the same type as before.
            "class_counts": dict(Counter(det["class_name"] for det in detections)),
            "avg_confidence": sum(confidences) / len(confidences),
            "confidence_range": (min(confidences), max(confidences)),
        }

    @staticmethod
    def _hex_to_bgr(hex_color: str) -> tuple:
        """
        Convert hex color to BGR tuple.

        Args:
            hex_color: Hex color string (e.g., '#FF0000'); the leading '#'
                and surrounding whitespace are optional

        Returns:
            BGR tuple (B, G, R). Default green (0, 255, 0) is returned for
            anything that is not a string of exactly six hex digits.
        """
        default = InferenceEngine._DEFAULT_BGR
        if not isinstance(hex_color, str):
            return default

        digits = hex_color.strip().lstrip("#")
        # Validate characters explicitly: int(..., 16) alone would accept
        # signs and underscores and could yield out-of-range channels.
        if len(digits) != 6 or not InferenceEngine._HEX_DIGITS.issuperset(digits):
            return default

        r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
        return (b, g, r)  # OpenCV uses BGR

    def change_model(self, model_path: str, model_id: int) -> bool:
        """
        Change the current model.

        The new model is loaded before anything on the engine is replaced, so
        a failed switch leaves the previous model and model_id in effect.

        Args:
            model_path: Path to new model weights
            model_id: ID of new model in database

        Returns:
            True if successful, False otherwise
        """
        try:
            candidate = YOLOWrapper(model_path)
            if not candidate.load_model():
                logger.error(
                    f"Failed to load model from {model_path}; keeping current model"
                )
                return False
        except Exception as e:
            logger.error(f"Error changing model: {e}")
            return False

        # Swap wrapper and ID together so they can never disagree.
        self.yolo = candidate
        self.model_id = model_id
        logger.info(f"Model changed to {model_path}")
        return True
|