Adding python files
This commit is contained in:
323
src/model/inference.py
Normal file
323
src/model/inference.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""
|
||||
Inference engine for the microscopy object detection application.
|
||||
Handles detection inference and result storage.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Optional, Callable
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from src.model.yolo_wrapper import YOLOWrapper
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.file_utils import get_relative_path
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class InferenceEngine:
|
||||
"""Handles detection inference and result storage."""
|
||||
|
||||
def __init__(self, model_path: str, db_manager: DatabaseManager, model_id: int):
|
||||
"""
|
||||
Initialize inference engine.
|
||||
|
||||
Args:
|
||||
model_path: Path to YOLO model weights
|
||||
db_manager: Database manager instance
|
||||
model_id: ID of the model in database
|
||||
"""
|
||||
self.yolo = YOLOWrapper(model_path)
|
||||
self.yolo.load_model()
|
||||
self.db_manager = db_manager
|
||||
self.model_id = model_id
|
||||
logger.info(f"InferenceEngine initialized with model_id {model_id}")
|
||||
|
||||
def detect_single(
|
||||
self,
|
||||
image_path: str,
|
||||
relative_path: str,
|
||||
conf: float = 0.25,
|
||||
save_to_db: bool = True,
|
||||
) -> Dict:
|
||||
"""
|
||||
Detect objects in a single image.
|
||||
|
||||
Args:
|
||||
image_path: Absolute path to image file
|
||||
relative_path: Relative path from repository root
|
||||
conf: Confidence threshold
|
||||
save_to_db: Whether to save results to database
|
||||
|
||||
Returns:
|
||||
Dictionary with detection results
|
||||
"""
|
||||
try:
|
||||
# Get image dimensions
|
||||
img = Image.open(image_path)
|
||||
width, height = img.size
|
||||
img.close()
|
||||
|
||||
# Perform detection
|
||||
detections = self.yolo.predict(image_path, conf=conf)
|
||||
|
||||
# Add/get image in database
|
||||
image_id = self.db_manager.get_or_create_image(
|
||||
relative_path=relative_path,
|
||||
filename=Path(image_path).name,
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
|
||||
# Save detections to database
|
||||
if save_to_db and detections:
|
||||
detection_records = []
|
||||
for det in detections:
|
||||
# Use normalized bbox from detection
|
||||
bbox_normalized = det[
|
||||
"bbox_normalized"
|
||||
] # [x_min, y_min, x_max, y_max]
|
||||
|
||||
record = {
|
||||
"image_id": image_id,
|
||||
"model_id": self.model_id,
|
||||
"class_name": det["class_name"],
|
||||
"bbox": tuple(bbox_normalized),
|
||||
"confidence": det["confidence"],
|
||||
"metadata": {"class_id": det["class_id"]},
|
||||
}
|
||||
detection_records.append(record)
|
||||
|
||||
self.db_manager.add_detections_batch(detection_records)
|
||||
logger.info(f"Saved {len(detection_records)} detections to database")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"image_path": image_path,
|
||||
"image_id": image_id,
|
||||
"detections": detections,
|
||||
"count": len(detections),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error detecting objects in {image_path}: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"image_path": image_path,
|
||||
"error": str(e),
|
||||
"detections": [],
|
||||
"count": 0,
|
||||
}
|
||||
|
||||
def detect_batch(
|
||||
self,
|
||||
image_paths: List[str],
|
||||
repository_root: str,
|
||||
conf: float = 0.25,
|
||||
progress_callback: Optional[Callable[[int, int, str], None]] = None,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Detect objects in multiple images.
|
||||
|
||||
Args:
|
||||
image_paths: List of absolute image paths
|
||||
repository_root: Root directory for relative paths
|
||||
conf: Confidence threshold
|
||||
progress_callback: Optional callback(current, total, message)
|
||||
|
||||
Returns:
|
||||
List of detection result dictionaries
|
||||
"""
|
||||
results = []
|
||||
total = len(image_paths)
|
||||
|
||||
logger.info(f"Starting batch detection on {total} images")
|
||||
|
||||
for i, image_path in enumerate(image_paths, 1):
|
||||
# Calculate relative path
|
||||
rel_path = get_relative_path(image_path, repository_root)
|
||||
|
||||
# Perform detection
|
||||
result = self.detect_single(image_path, rel_path, conf)
|
||||
results.append(result)
|
||||
|
||||
# Update progress
|
||||
if progress_callback:
|
||||
progress_callback(i, total, f"Processed {rel_path}")
|
||||
|
||||
if i % 10 == 0:
|
||||
logger.info(f"Processed {i}/{total} images")
|
||||
|
||||
logger.info(f"Batch detection complete: {total} images processed")
|
||||
return results
|
||||
|
||||
def detect_with_visualization(
|
||||
self,
|
||||
image_path: str,
|
||||
conf: float = 0.25,
|
||||
bbox_thickness: int = 2,
|
||||
bbox_colors: Optional[Dict[str, str]] = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
Detect objects and return annotated image.
|
||||
|
||||
Args:
|
||||
image_path: Path to image
|
||||
conf: Confidence threshold
|
||||
bbox_thickness: Thickness of bounding boxes
|
||||
bbox_colors: Dictionary mapping class names to hex colors
|
||||
|
||||
Returns:
|
||||
Tuple of (detections, annotated_image_array)
|
||||
"""
|
||||
try:
|
||||
detections = self.yolo.predict(image_path, conf=conf)
|
||||
|
||||
# Load image
|
||||
img = cv2.imread(image_path)
|
||||
if img is None:
|
||||
raise ValueError(f"Failed to load image: {image_path}")
|
||||
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
height, width = img.shape[:2]
|
||||
|
||||
# Default colors if not provided
|
||||
if bbox_colors is None:
|
||||
bbox_colors = {}
|
||||
default_color = self._hex_to_bgr(bbox_colors.get("default", "#00FF00"))
|
||||
|
||||
# Draw bounding boxes
|
||||
for det in detections:
|
||||
# Get absolute coordinates
|
||||
bbox_abs = det["bbox_absolute"]
|
||||
x1, y1, x2, y2 = [int(v) for v in bbox_abs]
|
||||
|
||||
# Get color for this class
|
||||
class_name = det["class_name"]
|
||||
color_hex = bbox_colors.get(
|
||||
class_name, bbox_colors.get("default", "#00FF00")
|
||||
)
|
||||
color = self._hex_to_bgr(color_hex)
|
||||
|
||||
# Draw box
|
||||
cv2.rectangle(img, (x1, y1), (x2, y2), color, bbox_thickness)
|
||||
|
||||
# Prepare label
|
||||
label = f"{class_name} {det['confidence']:.2f}"
|
||||
|
||||
# Draw label background
|
||||
(label_w, label_h), baseline = cv2.getTextSize(
|
||||
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
|
||||
)
|
||||
cv2.rectangle(
|
||||
img,
|
||||
(x1, y1 - label_h - baseline - 5),
|
||||
(x1 + label_w, y1),
|
||||
color,
|
||||
-1,
|
||||
)
|
||||
|
||||
# Draw label text
|
||||
cv2.putText(
|
||||
img,
|
||||
label,
|
||||
(x1, y1 - baseline - 5),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5,
|
||||
(255, 255, 255),
|
||||
1,
|
||||
)
|
||||
|
||||
return detections, img
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating visualization: {e}")
|
||||
# Return empty detections and original image if possible
|
||||
try:
|
||||
img = cv2.imread(image_path)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return [], img
|
||||
except:
|
||||
return [], np.zeros((480, 640, 3), dtype=np.uint8)
|
||||
|
||||
def get_detection_summary(self, detections: List[Dict]) -> Dict[str, any]:
|
||||
"""
|
||||
Generate summary statistics for detections.
|
||||
|
||||
Args:
|
||||
detections: List of detection dictionaries
|
||||
|
||||
Returns:
|
||||
Dictionary with summary statistics
|
||||
"""
|
||||
if not detections:
|
||||
return {
|
||||
"total_count": 0,
|
||||
"class_counts": {},
|
||||
"avg_confidence": 0.0,
|
||||
"confidence_range": (0.0, 0.0),
|
||||
}
|
||||
|
||||
# Count by class
|
||||
class_counts = {}
|
||||
confidences = []
|
||||
|
||||
for det in detections:
|
||||
class_name = det["class_name"]
|
||||
class_counts[class_name] = class_counts.get(class_name, 0) + 1
|
||||
confidences.append(det["confidence"])
|
||||
|
||||
return {
|
||||
"total_count": len(detections),
|
||||
"class_counts": class_counts,
|
||||
"avg_confidence": sum(confidences) / len(confidences),
|
||||
"confidence_range": (min(confidences), max(confidences)),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _hex_to_bgr(hex_color: str) -> tuple:
|
||||
"""
|
||||
Convert hex color to BGR tuple.
|
||||
|
||||
Args:
|
||||
hex_color: Hex color string (e.g., '#FF0000')
|
||||
|
||||
Returns:
|
||||
BGR tuple (B, G, R)
|
||||
"""
|
||||
hex_color = hex_color.lstrip("#")
|
||||
if len(hex_color) != 6:
|
||||
return (0, 255, 0) # Default green
|
||||
|
||||
try:
|
||||
r = int(hex_color[0:2], 16)
|
||||
g = int(hex_color[2:4], 16)
|
||||
b = int(hex_color[4:6], 16)
|
||||
return (b, g, r) # OpenCV uses BGR
|
||||
except ValueError:
|
||||
return (0, 255, 0) # Default green
|
||||
|
||||
def change_model(self, model_path: str, model_id: int) -> bool:
|
||||
"""
|
||||
Change the current model.
|
||||
|
||||
Args:
|
||||
model_path: Path to new model weights
|
||||
model_id: ID of new model in database
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
self.yolo = YOLOWrapper(model_path)
|
||||
if self.yolo.load_model():
|
||||
self.model_id = model_id
|
||||
logger.info(f"Model changed to {model_path}")
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error changing model: {e}")
|
||||
return False
|
||||
Reference in New Issue
Block a user