Adding python files
This commit is contained in:
0
src/model/__init__.py
Normal file
0
src/model/__init__.py
Normal file
323
src/model/inference.py
Normal file
323
src/model/inference.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""
|
||||
Inference engine for the microscopy object detection application.
|
||||
Handles detection inference and result storage.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Optional, Callable
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from src.model.yolo_wrapper import YOLOWrapper
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.file_utils import get_relative_path
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class InferenceEngine:
|
||||
"""Handles detection inference and result storage."""
|
||||
|
||||
def __init__(self, model_path: str, db_manager: DatabaseManager, model_id: int):
|
||||
"""
|
||||
Initialize inference engine.
|
||||
|
||||
Args:
|
||||
model_path: Path to YOLO model weights
|
||||
db_manager: Database manager instance
|
||||
model_id: ID of the model in database
|
||||
"""
|
||||
self.yolo = YOLOWrapper(model_path)
|
||||
self.yolo.load_model()
|
||||
self.db_manager = db_manager
|
||||
self.model_id = model_id
|
||||
logger.info(f"InferenceEngine initialized with model_id {model_id}")
|
||||
|
||||
def detect_single(
|
||||
self,
|
||||
image_path: str,
|
||||
relative_path: str,
|
||||
conf: float = 0.25,
|
||||
save_to_db: bool = True,
|
||||
) -> Dict:
|
||||
"""
|
||||
Detect objects in a single image.
|
||||
|
||||
Args:
|
||||
image_path: Absolute path to image file
|
||||
relative_path: Relative path from repository root
|
||||
conf: Confidence threshold
|
||||
save_to_db: Whether to save results to database
|
||||
|
||||
Returns:
|
||||
Dictionary with detection results
|
||||
"""
|
||||
try:
|
||||
# Get image dimensions
|
||||
img = Image.open(image_path)
|
||||
width, height = img.size
|
||||
img.close()
|
||||
|
||||
# Perform detection
|
||||
detections = self.yolo.predict(image_path, conf=conf)
|
||||
|
||||
# Add/get image in database
|
||||
image_id = self.db_manager.get_or_create_image(
|
||||
relative_path=relative_path,
|
||||
filename=Path(image_path).name,
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
|
||||
# Save detections to database
|
||||
if save_to_db and detections:
|
||||
detection_records = []
|
||||
for det in detections:
|
||||
# Use normalized bbox from detection
|
||||
bbox_normalized = det[
|
||||
"bbox_normalized"
|
||||
] # [x_min, y_min, x_max, y_max]
|
||||
|
||||
record = {
|
||||
"image_id": image_id,
|
||||
"model_id": self.model_id,
|
||||
"class_name": det["class_name"],
|
||||
"bbox": tuple(bbox_normalized),
|
||||
"confidence": det["confidence"],
|
||||
"metadata": {"class_id": det["class_id"]},
|
||||
}
|
||||
detection_records.append(record)
|
||||
|
||||
self.db_manager.add_detections_batch(detection_records)
|
||||
logger.info(f"Saved {len(detection_records)} detections to database")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"image_path": image_path,
|
||||
"image_id": image_id,
|
||||
"detections": detections,
|
||||
"count": len(detections),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error detecting objects in {image_path}: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"image_path": image_path,
|
||||
"error": str(e),
|
||||
"detections": [],
|
||||
"count": 0,
|
||||
}
|
||||
|
||||
def detect_batch(
|
||||
self,
|
||||
image_paths: List[str],
|
||||
repository_root: str,
|
||||
conf: float = 0.25,
|
||||
progress_callback: Optional[Callable[[int, int, str], None]] = None,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Detect objects in multiple images.
|
||||
|
||||
Args:
|
||||
image_paths: List of absolute image paths
|
||||
repository_root: Root directory for relative paths
|
||||
conf: Confidence threshold
|
||||
progress_callback: Optional callback(current, total, message)
|
||||
|
||||
Returns:
|
||||
List of detection result dictionaries
|
||||
"""
|
||||
results = []
|
||||
total = len(image_paths)
|
||||
|
||||
logger.info(f"Starting batch detection on {total} images")
|
||||
|
||||
for i, image_path in enumerate(image_paths, 1):
|
||||
# Calculate relative path
|
||||
rel_path = get_relative_path(image_path, repository_root)
|
||||
|
||||
# Perform detection
|
||||
result = self.detect_single(image_path, rel_path, conf)
|
||||
results.append(result)
|
||||
|
||||
# Update progress
|
||||
if progress_callback:
|
||||
progress_callback(i, total, f"Processed {rel_path}")
|
||||
|
||||
if i % 10 == 0:
|
||||
logger.info(f"Processed {i}/{total} images")
|
||||
|
||||
logger.info(f"Batch detection complete: {total} images processed")
|
||||
return results
|
||||
|
||||
def detect_with_visualization(
|
||||
self,
|
||||
image_path: str,
|
||||
conf: float = 0.25,
|
||||
bbox_thickness: int = 2,
|
||||
bbox_colors: Optional[Dict[str, str]] = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
Detect objects and return annotated image.
|
||||
|
||||
Args:
|
||||
image_path: Path to image
|
||||
conf: Confidence threshold
|
||||
bbox_thickness: Thickness of bounding boxes
|
||||
bbox_colors: Dictionary mapping class names to hex colors
|
||||
|
||||
Returns:
|
||||
Tuple of (detections, annotated_image_array)
|
||||
"""
|
||||
try:
|
||||
detections = self.yolo.predict(image_path, conf=conf)
|
||||
|
||||
# Load image
|
||||
img = cv2.imread(image_path)
|
||||
if img is None:
|
||||
raise ValueError(f"Failed to load image: {image_path}")
|
||||
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
height, width = img.shape[:2]
|
||||
|
||||
# Default colors if not provided
|
||||
if bbox_colors is None:
|
||||
bbox_colors = {}
|
||||
default_color = self._hex_to_bgr(bbox_colors.get("default", "#00FF00"))
|
||||
|
||||
# Draw bounding boxes
|
||||
for det in detections:
|
||||
# Get absolute coordinates
|
||||
bbox_abs = det["bbox_absolute"]
|
||||
x1, y1, x2, y2 = [int(v) for v in bbox_abs]
|
||||
|
||||
# Get color for this class
|
||||
class_name = det["class_name"]
|
||||
color_hex = bbox_colors.get(
|
||||
class_name, bbox_colors.get("default", "#00FF00")
|
||||
)
|
||||
color = self._hex_to_bgr(color_hex)
|
||||
|
||||
# Draw box
|
||||
cv2.rectangle(img, (x1, y1), (x2, y2), color, bbox_thickness)
|
||||
|
||||
# Prepare label
|
||||
label = f"{class_name} {det['confidence']:.2f}"
|
||||
|
||||
# Draw label background
|
||||
(label_w, label_h), baseline = cv2.getTextSize(
|
||||
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
|
||||
)
|
||||
cv2.rectangle(
|
||||
img,
|
||||
(x1, y1 - label_h - baseline - 5),
|
||||
(x1 + label_w, y1),
|
||||
color,
|
||||
-1,
|
||||
)
|
||||
|
||||
# Draw label text
|
||||
cv2.putText(
|
||||
img,
|
||||
label,
|
||||
(x1, y1 - baseline - 5),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5,
|
||||
(255, 255, 255),
|
||||
1,
|
||||
)
|
||||
|
||||
return detections, img
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating visualization: {e}")
|
||||
# Return empty detections and original image if possible
|
||||
try:
|
||||
img = cv2.imread(image_path)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return [], img
|
||||
except:
|
||||
return [], np.zeros((480, 640, 3), dtype=np.uint8)
|
||||
|
||||
def get_detection_summary(self, detections: List[Dict]) -> Dict[str, any]:
|
||||
"""
|
||||
Generate summary statistics for detections.
|
||||
|
||||
Args:
|
||||
detections: List of detection dictionaries
|
||||
|
||||
Returns:
|
||||
Dictionary with summary statistics
|
||||
"""
|
||||
if not detections:
|
||||
return {
|
||||
"total_count": 0,
|
||||
"class_counts": {},
|
||||
"avg_confidence": 0.0,
|
||||
"confidence_range": (0.0, 0.0),
|
||||
}
|
||||
|
||||
# Count by class
|
||||
class_counts = {}
|
||||
confidences = []
|
||||
|
||||
for det in detections:
|
||||
class_name = det["class_name"]
|
||||
class_counts[class_name] = class_counts.get(class_name, 0) + 1
|
||||
confidences.append(det["confidence"])
|
||||
|
||||
return {
|
||||
"total_count": len(detections),
|
||||
"class_counts": class_counts,
|
||||
"avg_confidence": sum(confidences) / len(confidences),
|
||||
"confidence_range": (min(confidences), max(confidences)),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _hex_to_bgr(hex_color: str) -> tuple:
|
||||
"""
|
||||
Convert hex color to BGR tuple.
|
||||
|
||||
Args:
|
||||
hex_color: Hex color string (e.g., '#FF0000')
|
||||
|
||||
Returns:
|
||||
BGR tuple (B, G, R)
|
||||
"""
|
||||
hex_color = hex_color.lstrip("#")
|
||||
if len(hex_color) != 6:
|
||||
return (0, 255, 0) # Default green
|
||||
|
||||
try:
|
||||
r = int(hex_color[0:2], 16)
|
||||
g = int(hex_color[2:4], 16)
|
||||
b = int(hex_color[4:6], 16)
|
||||
return (b, g, r) # OpenCV uses BGR
|
||||
except ValueError:
|
||||
return (0, 255, 0) # Default green
|
||||
|
||||
def change_model(self, model_path: str, model_id: int) -> bool:
|
||||
"""
|
||||
Change the current model.
|
||||
|
||||
Args:
|
||||
model_path: Path to new model weights
|
||||
model_id: ID of new model in database
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
self.yolo = YOLOWrapper(model_path)
|
||||
if self.yolo.load_model():
|
||||
self.model_id = model_id
|
||||
logger.info(f"Model changed to {model_path}")
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error changing model: {e}")
|
||||
return False
|
||||
364
src/model/yolo_wrapper.py
Normal file
364
src/model/yolo_wrapper.py
Normal file
@@ -0,0 +1,364 @@
|
||||
"""
|
||||
YOLO model wrapper for the microscopy object detection application.
|
||||
Provides a clean interface to YOLOv8 for training, validation, and inference.
|
||||
"""
|
||||
|
||||
from ultralytics import YOLO
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Callable, Any
|
||||
import torch
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class YOLOWrapper:
|
||||
"""Wrapper for YOLOv8 model operations."""
|
||||
|
||||
def __init__(self, model_path: str = "yolov8s.pt"):
|
||||
"""
|
||||
Initialize YOLO model.
|
||||
|
||||
Args:
|
||||
model_path: Path to model weights (.pt file)
|
||||
"""
|
||||
self.model_path = model_path
|
||||
self.model = None
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"YOLOWrapper initialized with device: {self.device}")
|
||||
|
||||
def load_model(self) -> bool:
|
||||
"""
|
||||
Load YOLO model from path.
|
||||
|
||||
Returns:
|
||||
True if loaded successfully, False otherwise
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Loading YOLO model from {self.model_path}")
|
||||
self.model = YOLO(self.model_path)
|
||||
self.model.to(self.device)
|
||||
logger.info("Model loaded successfully")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading model: {e}")
|
||||
return False
|
||||
|
||||
def train(
|
||||
self,
|
||||
data_yaml: str,
|
||||
epochs: int = 100,
|
||||
imgsz: int = 640,
|
||||
batch: int = 16,
|
||||
patience: int = 50,
|
||||
save_dir: str = "data/models",
|
||||
name: str = "custom_model",
|
||||
resume: bool = False,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Train the YOLO model.
|
||||
|
||||
Args:
|
||||
data_yaml: Path to data.yaml configuration file
|
||||
epochs: Number of training epochs
|
||||
imgsz: Input image size
|
||||
batch: Batch size
|
||||
patience: Early stopping patience
|
||||
save_dir: Directory to save trained model
|
||||
name: Name for the training run
|
||||
resume: Resume training from last checkpoint
|
||||
**kwargs: Additional training arguments
|
||||
|
||||
Returns:
|
||||
Dictionary with training results
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
|
||||
try:
|
||||
logger.info(f"Starting training: {name}")
|
||||
logger.info(
|
||||
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
|
||||
)
|
||||
|
||||
# Train the model
|
||||
results = self.model.train(
|
||||
data=data_yaml,
|
||||
epochs=epochs,
|
||||
imgsz=imgsz,
|
||||
batch=batch,
|
||||
patience=patience,
|
||||
project=save_dir,
|
||||
name=name,
|
||||
device=self.device,
|
||||
resume=resume,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
logger.info("Training completed successfully")
|
||||
return self._format_training_results(results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during training: {e}")
|
||||
raise
|
||||
|
||||
def validate(self, data_yaml: str, split: str = "val", **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate the model.
|
||||
|
||||
Args:
|
||||
data_yaml: Path to data.yaml configuration file
|
||||
split: Dataset split to validate on ('val' or 'test')
|
||||
**kwargs: Additional validation arguments
|
||||
|
||||
Returns:
|
||||
Dictionary with validation metrics
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
|
||||
try:
|
||||
logger.info(f"Starting validation on {split} split")
|
||||
results = self.model.val(
|
||||
data=data_yaml, split=split, device=self.device, **kwargs
|
||||
)
|
||||
|
||||
logger.info("Validation completed successfully")
|
||||
return self._format_validation_results(results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during validation: {e}")
|
||||
raise
|
||||
|
||||
def predict(
|
||||
self,
|
||||
source: str,
|
||||
conf: float = 0.25,
|
||||
iou: float = 0.45,
|
||||
save: bool = False,
|
||||
save_txt: bool = False,
|
||||
save_conf: bool = False,
|
||||
**kwargs,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Perform inference on image(s).
|
||||
|
||||
Args:
|
||||
source: Path to image or directory
|
||||
conf: Confidence threshold
|
||||
iou: IoU threshold for NMS
|
||||
save: Whether to save annotated images
|
||||
save_txt: Whether to save labels to .txt files
|
||||
save_conf: Whether to save confidence in labels
|
||||
**kwargs: Additional prediction arguments
|
||||
|
||||
Returns:
|
||||
List of detection dictionaries
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
|
||||
try:
|
||||
logger.info(f"Running inference on {source}")
|
||||
results = self.model.predict(
|
||||
source=source,
|
||||
conf=conf,
|
||||
iou=iou,
|
||||
save=save,
|
||||
save_txt=save_txt,
|
||||
save_conf=save_conf,
|
||||
device=self.device,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
detections = self._format_prediction_results(results)
|
||||
logger.info(f"Inference complete: {len(detections)} detections")
|
||||
return detections
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during inference: {e}")
|
||||
raise
|
||||
|
||||
def export(
|
||||
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
|
||||
) -> str:
|
||||
"""
|
||||
Export model to different format.
|
||||
|
||||
Args:
|
||||
format: Export format (onnx, torchscript, tflite, etc.)
|
||||
output_path: Path for exported model
|
||||
**kwargs: Additional export arguments
|
||||
|
||||
Returns:
|
||||
Path to exported model
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
|
||||
try:
|
||||
logger.info(f"Exporting model to {format} format")
|
||||
export_path = self.model.export(format=format, **kwargs)
|
||||
logger.info(f"Model exported to {export_path}")
|
||||
return str(export_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error exporting model: {e}")
|
||||
raise
|
||||
|
||||
def _format_training_results(self, results) -> Dict[str, Any]:
|
||||
"""Format training results into dictionary."""
|
||||
try:
|
||||
# Get the results dict
|
||||
results_dict = (
|
||||
results.results_dict if hasattr(results, "results_dict") else {}
|
||||
)
|
||||
|
||||
formatted = {
|
||||
"success": True,
|
||||
"final_epoch": getattr(results, "epoch", 0),
|
||||
"metrics": {
|
||||
"mAP50": float(results_dict.get("metrics/mAP50(B)", 0)),
|
||||
"mAP50-95": float(results_dict.get("metrics/mAP50-95(B)", 0)),
|
||||
"precision": float(results_dict.get("metrics/precision(B)", 0)),
|
||||
"recall": float(results_dict.get("metrics/recall(B)", 0)),
|
||||
},
|
||||
"best_model_path": str(Path(results.save_dir) / "weights" / "best.pt"),
|
||||
"last_model_path": str(Path(results.save_dir) / "weights" / "last.pt"),
|
||||
"save_dir": str(results.save_dir),
|
||||
}
|
||||
|
||||
return formatted
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting training results: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
def _format_validation_results(self, results) -> Dict[str, Any]:
|
||||
"""Format validation results into dictionary."""
|
||||
try:
|
||||
box_metrics = results.box
|
||||
|
||||
formatted = {
|
||||
"success": True,
|
||||
"mAP50": float(box_metrics.map50),
|
||||
"mAP50-95": float(box_metrics.map),
|
||||
"precision": float(box_metrics.mp),
|
||||
"recall": float(box_metrics.mr),
|
||||
"fitness": (
|
||||
float(results.fitness) if hasattr(results, "fitness") else 0.0
|
||||
),
|
||||
}
|
||||
|
||||
# Add per-class metrics if available
|
||||
if hasattr(box_metrics, "ap") and hasattr(results, "names"):
|
||||
class_metrics = {}
|
||||
for idx, name in results.names.items():
|
||||
if idx < len(box_metrics.ap):
|
||||
class_metrics[name] = {
|
||||
"ap": float(box_metrics.ap[idx]),
|
||||
"ap50": (
|
||||
float(box_metrics.ap50[idx])
|
||||
if hasattr(box_metrics, "ap50")
|
||||
else 0.0
|
||||
),
|
||||
}
|
||||
formatted["class_metrics"] = class_metrics
|
||||
|
||||
return formatted
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting validation results: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
def _format_prediction_results(self, results) -> List[Dict]:
|
||||
"""Format prediction results into list of dictionaries."""
|
||||
detections = []
|
||||
|
||||
try:
|
||||
for result in results:
|
||||
boxes = result.boxes
|
||||
image_path = str(result.path)
|
||||
orig_shape = result.orig_shape # (height, width)
|
||||
|
||||
for i in range(len(boxes)):
|
||||
# Get normalized coordinates
|
||||
xyxyn = boxes.xyxyn[i].cpu().numpy() # Normalized [x1, y1, x2, y2]
|
||||
|
||||
detection = {
|
||||
"image_path": image_path,
|
||||
"class_id": int(boxes.cls[i]),
|
||||
"class_name": result.names[int(boxes.cls[i])],
|
||||
"confidence": float(boxes.conf[i]),
|
||||
"bbox_normalized": [
|
||||
float(v) for v in xyxyn
|
||||
], # [x_min, y_min, x_max, y_max]
|
||||
"bbox_absolute": [
|
||||
float(v) for v in boxes.xyxy[i].cpu().numpy()
|
||||
], # Absolute pixels
|
||||
}
|
||||
detections.append(detection)
|
||||
|
||||
return detections
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting prediction results: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def convert_bbox_format(
|
||||
bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
|
||||
) -> List[float]:
|
||||
"""
|
||||
Convert bounding box between formats.
|
||||
|
||||
Formats:
|
||||
- xywh: [x_center, y_center, width, height]
|
||||
- xyxy: [x_min, y_min, x_max, y_max]
|
||||
|
||||
Args:
|
||||
bbox: Bounding box coordinates
|
||||
format_from: Source format
|
||||
format_to: Target format
|
||||
|
||||
Returns:
|
||||
Converted bounding box
|
||||
"""
|
||||
if format_from == "xywh" and format_to == "xyxy":
|
||||
x, y, w, h = bbox
|
||||
return [x - w / 2, y - h / 2, x + w / 2, y + h / 2]
|
||||
elif format_from == "xyxy" and format_to == "xywh":
|
||||
x1, y1, x2, y2 = bbox
|
||||
return [(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1]
|
||||
else:
|
||||
return bbox
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about the loaded model.
|
||||
|
||||
Returns:
|
||||
Dictionary with model information
|
||||
"""
|
||||
if self.model is None:
|
||||
return {"error": "Model not loaded"}
|
||||
|
||||
try:
|
||||
info = {
|
||||
"model_path": self.model_path,
|
||||
"device": self.device,
|
||||
"task": getattr(self.model, "task", "unknown"),
|
||||
}
|
||||
|
||||
# Try to get class names
|
||||
if hasattr(self.model, "names"):
|
||||
info["classes"] = self.model.names
|
||||
info["num_classes"] = len(self.model.names)
|
||||
|
||||
return info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting model info: {e}")
|
||||
return {"error": str(e)}
|
||||
Reference in New Issue
Block a user