# File: object-segmentation/src/model/yolo_wrapper.py
# Timestamp: 2025-12-10 15:46:26 +02:00
# Size: 398 lines, 13 KiB
# Language: Python

"""
YOLO model wrapper for the microscopy object detection application.
Provides a clean interface to YOLOv8 for training, validation, and inference.
"""
from ultralytics import YOLO
from pathlib import Path
from typing import Optional, List, Dict, Callable, Any
import torch
from src.utils.logger import get_logger
# Module-level logger, named after this module, used by YOLOWrapper below.
logger = get_logger(__name__)
class YOLOWrapper:
    """Wrapper for YOLOv8 model operations.

    The underlying Ultralytics model is loaded lazily: ``train``, ``validate``,
    ``predict`` and ``export`` call ``load_model`` on first use when no model
    has been loaded yet.
    """

    def __init__(self, model_path: str = "yolov8s-seg.pt"):
        """
        Initialize YOLO model.

        Only records the configuration; the weights are not loaded here.

        Args:
            model_path: Path to model weights (.pt file)
        """
        self.model_path = model_path
        # Set by load_model(); stays None until a load succeeds.
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"YOLOWrapper initialized with device: {self.device}")

    def load_model(self) -> bool:
        """
        Load YOLO model from path.

        Returns:
            True if loaded successfully, False otherwise
        """
        try:
            logger.info(f"Loading YOLO model from {self.model_path}")
            model = YOLO(self.model_path)
            model.to(self.device)
            # Publish the model only after the device move succeeded, so a
            # failed load never leaves a half-initialised model attached.
            self.model = model
            logger.info("Model loaded successfully")
            return True
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            return False

    def _ensure_model_loaded(self) -> None:
        """Load the model on first use; raise RuntimeError if loading fails."""
        if self.model is None and not self.load_model():
            # Without this the caller would only see an opaque
            # "'NoneType' object has no attribute ..." error later on.
            raise RuntimeError(f"Failed to load YOLO model from {self.model_path}")

    def train(
        self,
        data_yaml: str,
        epochs: int = 100,
        imgsz: int = 640,
        batch: int = 16,
        patience: int = 50,
        save_dir: str = "data/models",
        name: str = "custom_model",
        resume: bool = False,
        callbacks: Optional[Dict[str, Callable]] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Train the YOLO model.

        Args:
            data_yaml: Path to data.yaml configuration file
            epochs: Number of training epochs
            imgsz: Input image size
            batch: Batch size
            patience: Early stopping patience
            save_dir: Directory to save trained model
            name: Name for the training run
            resume: Resume training from last checkpoint
            callbacks: Optional mapping of Ultralytics callback event name
                (e.g. "on_train_epoch_end") to the function to register
            **kwargs: Additional training arguments

        Returns:
            Dictionary with training results

        Raises:
            RuntimeError: If the model could not be loaded.
            Exception: Any error raised during training is logged and re-raised.
        """
        self._ensure_model_loaded()
        try:
            logger.info(f"Starting training: {name}")
            logger.info(
                f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
            )

            if callbacks:
                # Skip functions that are already registered so repeated
                # train() calls on the same wrapper do not fire them twice.
                registered = getattr(self.model, "callbacks", None) or {}
                for event, func in callbacks.items():
                    if func not in registered.get(event, []):
                        self.model.add_callback(event, func)

            # Train the model
            results = self.model.train(
                data=data_yaml,
                epochs=epochs,
                imgsz=imgsz,
                batch=batch,
                patience=patience,
                project=save_dir,
                name=name,
                device=self.device,
                resume=resume,
                **kwargs,
            )
            logger.info("Training completed successfully")
            return self._format_training_results(results)
        except Exception as e:
            logger.error(f"Error during training: {e}")
            raise

    def validate(self, data_yaml: str, split: str = "val", **kwargs) -> Dict[str, Any]:
        """
        Validate the model.

        Args:
            data_yaml: Path to data.yaml configuration file
            split: Dataset split to validate on ('val' or 'test')
            **kwargs: Additional validation arguments

        Returns:
            Dictionary with validation metrics

        Raises:
            RuntimeError: If the model could not be loaded.
            Exception: Any error raised during validation is logged and re-raised.
        """
        self._ensure_model_loaded()
        try:
            logger.info(f"Starting validation on {split} split")
            results = self.model.val(
                data=data_yaml, split=split, device=self.device, **kwargs
            )
            logger.info("Validation completed successfully")
            return self._format_validation_results(results)
        except Exception as e:
            logger.error(f"Error during validation: {e}")
            raise

    def predict(
        self,
        source: str,
        conf: float = 0.25,
        iou: float = 0.45,
        save: bool = False,
        save_txt: bool = False,
        save_conf: bool = False,
        **kwargs,
    ) -> List[Dict]:
        """
        Perform inference on image(s).

        Args:
            source: Path to image or directory
            conf: Confidence threshold
            iou: IoU threshold for NMS
            save: Whether to save annotated images
            save_txt: Whether to save labels to .txt files
            save_conf: Whether to save confidence in labels
            **kwargs: Additional prediction arguments

        Returns:
            List of detection dictionaries

        Raises:
            RuntimeError: If the model could not be loaded.
            Exception: Any error raised during inference is logged and re-raised.
        """
        self._ensure_model_loaded()
        try:
            logger.info(f"Running inference on {source}")
            results = self.model.predict(
                source=source,
                conf=conf,
                iou=iou,
                save=save,
                save_txt=save_txt,
                save_conf=save_conf,
                device=self.device,
                **kwargs,
            )
            detections = self._format_prediction_results(results)
            logger.info(f"Inference complete: {len(detections)} detections")
            return detections
        except Exception as e:
            logger.error(f"Error during inference: {e}")
            raise

    def export(
        self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
    ) -> str:
        """
        Export model to different format.

        Args:
            format: Export format (onnx, torchscript, tflite, etc.)
            output_path: Path for exported model. When given, the exported
                artifact is moved there (parent directories are created);
                when omitted, the location chosen by the exporter is kept.
            **kwargs: Additional export arguments

        Returns:
            Path to exported model

        Raises:
            RuntimeError: If the model could not be loaded.
            Exception: Any error raised during export is logged and re-raised.
        """
        import shutil  # local import: only needed when relocating the export

        self._ensure_model_loaded()
        try:
            logger.info(f"Exporting model to {format} format")
            export_path = self.model.export(format=format, **kwargs)

            if output_path is not None:
                source = Path(str(export_path))
                destination = Path(output_path)
                # Nothing to do when the exporter already wrote to the target.
                if source.resolve() != destination.resolve():
                    destination.parent.mkdir(parents=True, exist_ok=True)
                    export_path = shutil.move(str(source), str(destination))

            logger.info(f"Model exported to {export_path}")
            return str(export_path)
        except Exception as e:
            logger.error(f"Error exporting model: {e}")
            raise

    def _format_training_results(self, results) -> Dict[str, Any]:
        """
        Format training results into dictionary.

        Args:
            results: Object returned by the model's train() call

        Returns:
            Dictionary with "success", "final_epoch", box "metrics" and the
            weight/save paths, or {"success": False, "error": ...} when the
            results object cannot be read.
        """
        try:
            # Get the results dict
            results_dict = getattr(results, "results_dict", {})
            weights_dir = Path(results.save_dir) / "weights"

            return {
                "success": True,
                # NOTE(review): falls back to 0 whenever the results object
                # has no "epoch" attribute - confirm the trainer exposes it.
                "final_epoch": getattr(results, "epoch", 0),
                "metrics": {
                    "mAP50": float(results_dict.get("metrics/mAP50(B)", 0)),
                    "mAP50-95": float(results_dict.get("metrics/mAP50-95(B)", 0)),
                    "precision": float(results_dict.get("metrics/precision(B)", 0)),
                    "recall": float(results_dict.get("metrics/recall(B)", 0)),
                },
                "best_model_path": str(weights_dir / "best.pt"),
                "last_model_path": str(weights_dir / "last.pt"),
                "save_dir": str(results.save_dir),
            }
        except Exception as e:
            logger.error(f"Error formatting training results: {e}")
            return {"success": False, "error": str(e)}

    def _format_validation_results(self, results) -> Dict[str, Any]:
        """
        Format validation results into dictionary.

        Args:
            results: Object returned by the model's val() call

        Returns:
            Dictionary with overall box metrics and, when available, a
            "class_metrics" mapping of class name to {"ap", "ap50"}; or
            {"success": False, "error": ...} when the results cannot be read.
        """
        try:
            box_metrics = results.box
            formatted = {
                "success": True,
                "mAP50": float(box_metrics.map50),
                "mAP50-95": float(box_metrics.map),
                "precision": float(box_metrics.mp),
                "recall": float(box_metrics.mr),
                "fitness": (
                    float(results.fitness) if hasattr(results, "fitness") else 0.0
                ),
            }

            # Add per-class metrics if available
            if hasattr(box_metrics, "ap") and hasattr(results, "names"):
                ap_values = box_metrics.ap
                ap50_values = getattr(box_metrics, "ap50", None)
                names = results.names

                # The per-class arrays hold one row per class that actually
                # has results; ap_class_index maps each row to its class id.
                # Indexing the arrays by class id mislabels metrics whenever
                # a class is missing from the evaluated split.
                class_index = getattr(box_metrics, "ap_class_index", None)
                if class_index is None:
                    # No mapping exposed: assume row position == class id.
                    class_index = range(len(ap_values))

                class_metrics = {}
                for row, class_id in enumerate(class_index):
                    class_id = int(class_id)
                    if row >= len(ap_values) or class_id not in names:
                        continue
                    class_metrics[names[class_id]] = {
                        "ap": float(ap_values[row]),
                        "ap50": (
                            float(ap50_values[row]) if ap50_values is not None else 0.0
                        ),
                    }
                formatted["class_metrics"] = class_metrics

            return formatted
        except Exception as e:
            logger.error(f"Error formatting validation results: {e}")
            return {"success": False, "error": str(e)}

    def _format_prediction_results(self, results) -> List[Dict]:
        """
        Format prediction results into list of dictionaries.

        Each detection holds the image path, class id/name, confidence, the
        bounding box in normalized and absolute xyxy form, and a
        "segmentation_mask" (list of normalized [x, y] polygon points, or
        None when no mask is available).

        Args:
            results: Iterable of per-image result objects from predict()

        Returns:
            Flat list of detection dictionaries across all images; an empty
            list if the results cannot be formatted (the error is logged).
        """
        detections = []
        try:
            for result in results:
                boxes = result.boxes
                if boxes is None:
                    # No box output for this image; nothing to report.
                    continue

                image_path = str(result.path)
                height, width = result.orig_shape  # (height, width)

                # Check if this is a segmentation model with masks
                masks = getattr(result, "masks", None)

                # Move each tensor to the host once per image instead of
                # once per detection.
                boxes_normalized = boxes.xyxyn.cpu().tolist()
                boxes_absolute = boxes.xyxy.cpu().tolist()
                class_ids = boxes.cls.cpu().tolist()
                confidences = boxes.conf.cpu().tolist()

                for i, class_value in enumerate(class_ids):
                    class_id = int(class_value)
                    detection = {
                        "image_path": image_path,
                        "class_id": class_id,
                        "class_name": result.names[class_id],
                        "confidence": float(confidences[i]),
                        # [x_min, y_min, x_max, y_max], normalized
                        "bbox_normalized": [float(v) for v in boxes_normalized[i]],
                        # [x_min, y_min, x_max, y_max], absolute pixels
                        "bbox_absolute": [float(v) for v in boxes_absolute[i]],
                        "segmentation_mask": None,
                    }

                    # Extract segmentation mask if available
                    if masks is not None:
                        try:
                            # Polygon coordinates in absolute pixels
                            polygon = masks.xy[i]
                            if len(polygon) > 0:
                                # Convert to normalized coordinates
                                detection["segmentation_mask"] = [
                                    [float(point[0]) / width, float(point[1]) / height]
                                    for point in polygon
                                ]
                        except Exception as mask_error:
                            # A bad mask must not cost us the detection itself.
                            logger.warning(
                                f"Error extracting mask for detection {i}: {mask_error}"
                            )
                            detection["segmentation_mask"] = None

                    detections.append(detection)
            return detections
        except Exception as e:
            logger.error(f"Error formatting prediction results: {e}")
            return []

    @staticmethod
    def convert_bbox_format(
        bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
    ) -> List[float]:
        """
        Convert bounding box between formats.

        Formats:
            - xywh: [x_center, y_center, width, height]
            - xyxy: [x_min, y_min, x_max, y_max]

        Args:
            bbox: Bounding box coordinates
            format_from: Source format
            format_to: Target format

        Returns:
            Converted bounding box (always a new list)

        Raises:
            ValueError: If the requested conversion is not supported.
        """
        if format_from == format_to:
            # Nothing to convert; return a copy so callers never alias input.
            return list(bbox)
        if format_from == "xywh" and format_to == "xyxy":
            x, y, w, h = bbox
            return [x - w / 2, y - h / 2, x + w / 2, y + h / 2]
        if format_from == "xyxy" and format_to == "xywh":
            x1, y1, x2, y2 = bbox
            return [(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1]
        # Returning the input unchanged here would hand back coordinates in
        # the wrong format without any indication.
        raise ValueError(
            f"Unsupported bbox conversion: {format_from!r} -> {format_to!r}"
        )

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about the loaded model.

        Does not trigger a model load.

        Returns:
            Dictionary with model information, or {"error": ...} when no
            model is loaded or the information cannot be read.
        """
        if self.model is None:
            return {"error": "Model not loaded"}
        try:
            info = {
                "model_path": self.model_path,
                "device": self.device,
                "task": getattr(self.model, "task", "unknown"),
            }
            # Try to get class names
            if hasattr(self.model, "names"):
                info["classes"] = self.model.names
                info["num_classes"] = len(self.model.names)
            return info
        except Exception as e:
            logger.error(f"Error getting model info: {e}")
            return {"error": str(e)}