398 lines
13 KiB
Python
398 lines
13 KiB
Python
"""
|
|
YOLO model wrapper for the microscopy object detection application.
|
|
Provides a clean interface to YOLOv8 for training, validation, and inference.
|
|
"""
|
|
|
|
import shutil
from pathlib import Path
from typing import Optional, List, Dict, Callable, Any

import torch
from ultralytics import YOLO

from src.utils.logger import get_logger
|
|
|
|
|
|
# Module-level logger, created through the project's get_logger helper and
# keyed by this module's name.
logger = get_logger(__name__)
|
|
|
|
|
|
class YOLOWrapper:
    """Wrapper for YOLOv8 model operations.

    Holds a lazily loaded Ultralytics ``YOLO`` model together with the device
    it runs on, and converts the objects returned by training, validation and
    prediction into plain dictionaries and lists.
    """

    def __init__(self, model_path: str = "yolov8s-seg.pt"):
        """
        Initialize YOLO model.

        The weights are not read here; they are loaded by ``load_model`` or on
        the first call to ``train``/``validate``/``predict``/``export``.

        Args:
            model_path: Path to model weights (.pt file)
        """
        self.model_path = model_path
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"YOLOWrapper initialized with device: {self.device}")

    def load_model(self) -> bool:
        """
        Load YOLO model from path.

        ``self.model`` is only replaced once both construction and the move to
        ``self.device`` have succeeded, so a failed load never leaves a
        half-initialised model behind.

        Returns:
            True if loaded successfully, False otherwise
        """
        try:
            logger.info(f"Loading YOLO model from {self.model_path}")
            model = YOLO(self.model_path)
            model.to(self.device)
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            return False

        self.model = model
        logger.info("Model loaded successfully")
        return True

    def _ensure_model_loaded(self) -> None:
        """Load the model on first use and fail loudly if that is impossible.

        Raises:
            RuntimeError: If no model is loaded and ``load_model`` fails.
        """
        if self.model is None and not self.load_model():
            raise RuntimeError(f"Failed to load YOLO model from {self.model_path}")

    def train(
        self,
        data_yaml: str,
        epochs: int = 100,
        imgsz: int = 640,
        batch: int = 16,
        patience: int = 50,
        save_dir: str = "data/models",
        name: str = "custom_model",
        resume: bool = False,
        callbacks: Optional[Dict[str, Callable]] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Train the YOLO model.

        Args:
            data_yaml: Path to data.yaml configuration file
            epochs: Number of training epochs
            imgsz: Input image size
            batch: Batch size
            patience: Early stopping patience
            save_dir: Directory to save trained model
            name: Name for the training run
            resume: Resume training from last checkpoint
            callbacks: Optional mapping of Ultralytics event name to callback;
                each entry is registered on the model before training starts
            **kwargs: Additional training arguments

        Returns:
            Dictionary with training results

        Raises:
            RuntimeError: If the model cannot be loaded.
            Exception: Any error raised while registering callbacks or during
                training is logged and re-raised.
        """
        self._ensure_model_loaded()

        try:
            logger.info(f"Starting training: {name}")
            logger.info(
                f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
            )

            # Register user callbacks. They are added to the model object, so
            # they stay registered for later train() calls on this wrapper.
            if callbacks:
                for event, func in callbacks.items():
                    self.model.add_callback(event, func)

            # Train the model
            results = self.model.train(
                data=data_yaml,
                epochs=epochs,
                imgsz=imgsz,
                batch=batch,
                patience=patience,
                project=save_dir,
                name=name,
                device=self.device,
                resume=resume,
                **kwargs,
            )

            logger.info("Training completed successfully")
            return self._format_training_results(results)

        except Exception as e:
            logger.error(f"Error during training: {e}")
            raise

    def validate(self, data_yaml: str, split: str = "val", **kwargs) -> Dict[str, Any]:
        """
        Validate the model.

        Args:
            data_yaml: Path to data.yaml configuration file
            split: Dataset split to validate on ('val' or 'test')
            **kwargs: Additional validation arguments

        Returns:
            Dictionary with validation metrics

        Raises:
            RuntimeError: If the model cannot be loaded.
            Exception: Any error raised during validation is logged and
                re-raised.
        """
        self._ensure_model_loaded()

        try:
            logger.info(f"Starting validation on {split} split")
            results = self.model.val(
                data=data_yaml, split=split, device=self.device, **kwargs
            )

            logger.info("Validation completed successfully")
            return self._format_validation_results(results)

        except Exception as e:
            logger.error(f"Error during validation: {e}")
            raise

    def predict(
        self,
        source: str,
        conf: float = 0.25,
        iou: float = 0.45,
        save: bool = False,
        save_txt: bool = False,
        save_conf: bool = False,
        **kwargs,
    ) -> List[Dict]:
        """
        Perform inference on image(s).

        Args:
            source: Path to image or directory
            conf: Confidence threshold
            iou: IoU threshold for NMS
            save: Whether to save annotated images
            save_txt: Whether to save labels to .txt files
            save_conf: Whether to save confidence in labels
            **kwargs: Additional prediction arguments

        Returns:
            List of detection dictionaries

        Raises:
            RuntimeError: If the model cannot be loaded.
            Exception: Any error raised during inference is logged and
                re-raised.
        """
        self._ensure_model_loaded()

        try:
            logger.info(f"Running inference on {source}")
            results = self.model.predict(
                source=source,
                conf=conf,
                iou=iou,
                save=save,
                save_txt=save_txt,
                save_conf=save_conf,
                device=self.device,
                **kwargs,
            )

            detections = self._format_prediction_results(results)
            logger.info(f"Inference complete: {len(detections)} detections")
            return detections

        except Exception as e:
            logger.error(f"Error during inference: {e}")
            raise

    def export(
        self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
    ) -> str:
        """
        Export model to different format.

        Args:
            format: Export format (onnx, torchscript, tflite, etc.)
            output_path: Path for exported model. When given, the exported
                artifact is moved there (parent directories are created) and
                the new location is returned.
            **kwargs: Additional export arguments

        Returns:
            Path to exported model

        Raises:
            RuntimeError: If the model cannot be loaded.
            Exception: Any error raised during export or while moving the
                artifact is logged and re-raised.
        """
        self._ensure_model_loaded()

        try:
            logger.info(f"Exporting model to {format} format")
            export_path = self.model.export(format=format, **kwargs)

            if output_path:
                destination = Path(output_path)
                destination.parent.mkdir(parents=True, exist_ok=True)
                # shutil.move returns the final location of the artifact.
                export_path = shutil.move(str(export_path), str(destination))

            logger.info(f"Model exported to {export_path}")
            return str(export_path)

        except Exception as e:
            logger.error(f"Error exporting model: {e}")
            raise

    def _format_training_results(self, results) -> Dict[str, Any]:
        """Format training results into dictionary.

        Reads ``results_dict``, ``epoch`` and ``save_dir`` from ``results``.
        Missing metrics default to 0. Any failure (for example a missing
        ``save_dir``) is logged and reported as
        ``{"success": False, "error": ...}`` instead of being raised.
        """
        try:
            # Get the results dict
            results_dict = getattr(results, "results_dict", None) or {}
            save_dir = Path(results.save_dir)
            weights_dir = save_dir / "weights"

            return {
                "success": True,
                "final_epoch": getattr(results, "epoch", 0),
                "metrics": {
                    "mAP50": float(results_dict.get("metrics/mAP50(B)", 0)),
                    "mAP50-95": float(results_dict.get("metrics/mAP50-95(B)", 0)),
                    "precision": float(results_dict.get("metrics/precision(B)", 0)),
                    "recall": float(results_dict.get("metrics/recall(B)", 0)),
                },
                "best_model_path": str(weights_dir / "best.pt"),
                "last_model_path": str(weights_dir / "last.pt"),
                "save_dir": str(results.save_dir),
            }

        except Exception as e:
            logger.error(f"Error formatting training results: {e}")
            return {"success": False, "error": str(e)}

    def _format_validation_results(self, results) -> Dict[str, Any]:
        """Format validation results into dictionary.

        Overall box metrics are read from ``results.box``. Per-class metrics
        are added under ``"class_metrics"`` when ``results.box.ap`` and
        ``results.names`` are available. Any failure is logged and reported as
        ``{"success": False, "error": ...}`` instead of being raised.
        """
        try:
            box_metrics = results.box

            formatted = {
                "success": True,
                "mAP50": float(box_metrics.map50),
                "mAP50-95": float(box_metrics.map),
                "precision": float(box_metrics.mp),
                "recall": float(box_metrics.mr),
                "fitness": (
                    float(results.fitness) if hasattr(results, "fitness") else 0.0
                ),
            }

            # Add per-class metrics if available
            if hasattr(box_metrics, "ap") and hasattr(results, "names"):
                ap_values = box_metrics.ap
                has_ap50 = hasattr(box_metrics, "ap50")

                # The per-class AP arrays are positional: entry i belongs to
                # class ap_class_index[i], which is not the same as class id i
                # when some classes have no instances in the evaluated split.
                class_index = getattr(box_metrics, "ap_class_index", None)
                if class_index is not None:
                    position_of = {
                        int(class_id): pos for pos, class_id in enumerate(class_index)
                    }
                else:
                    # No index available: fall back to treating class id as
                    # the array position.
                    position_of = {class_id: class_id for class_id in results.names}

                class_metrics = {}
                for class_id, name in results.names.items():
                    pos = position_of.get(class_id)
                    if pos is None or pos >= len(ap_values):
                        continue
                    class_metrics[name] = {
                        "ap": float(ap_values[pos]),
                        "ap50": float(box_metrics.ap50[pos]) if has_ap50 else 0.0,
                    }
                formatted["class_metrics"] = class_metrics

            return formatted

        except Exception as e:
            logger.error(f"Error formatting validation results: {e}")
            return {"success": False, "error": str(e)}

    def _format_prediction_results(self, results) -> List[Dict]:
        """Format prediction results into list of dictionaries.

        Produces one dictionary per detected box with the image path, class
        id/name, confidence, normalized and absolute ``[x_min, y_min, x_max,
        y_max]`` boxes and, for results that carry masks, the polygon
        normalized by the original image width and height
        (``"segmentation_mask"`` is ``None`` otherwise).

        A failure while reading one mask only drops that mask. Any other
        failure is logged and an empty list is returned.
        """
        detections = []

        try:
            for result in results:
                boxes = result.boxes
                if boxes is None:
                    # Nothing box-shaped to report for this result.
                    continue

                image_path = str(result.path)
                height, width = result.orig_shape  # (height, width)

                # Check if this is a segmentation model with masks
                has_masks = getattr(result, "masks", None) is not None

                # Convert each coordinate tensor once per image rather than
                # once per detection.
                boxes_normalized = boxes.xyxyn.cpu().numpy()
                boxes_absolute = boxes.xyxy.cpu().numpy()

                for i, (norm_box, abs_box) in enumerate(
                    zip(boxes_normalized, boxes_absolute)
                ):
                    class_id = int(boxes.cls[i])

                    detection = {
                        "image_path": image_path,
                        "class_id": class_id,
                        "class_name": result.names[class_id],
                        "confidence": float(boxes.conf[i]),
                        # [x_min, y_min, x_max, y_max]
                        "bbox_normalized": [float(v) for v in norm_box],
                        # Absolute pixels
                        "bbox_absolute": [float(v) for v in abs_box],
                        "segmentation_mask": None,
                    }

                    # Extract segmentation mask if available
                    if has_masks:
                        try:
                            # Polygon coordinates in absolute pixels
                            polygon = result.masks.xy[i]

                            # Convert to normalized coordinates
                            if len(polygon) > 0:
                                detection["segmentation_mask"] = [
                                    [float(point[0]) / width, float(point[1]) / height]
                                    for point in polygon
                                ]
                        except Exception as mask_error:
                            logger.warning(
                                f"Error extracting mask for detection {i}: {mask_error}"
                            )
                            detection["segmentation_mask"] = None

                    detections.append(detection)

            return detections

        except Exception as e:
            logger.error(f"Error formatting prediction results: {e}")
            return []

    @staticmethod
    def convert_bbox_format(
        bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
    ) -> List[float]:
        """
        Convert bounding box between formats.

        Formats:
            - xywh: [x_center, y_center, width, height]
            - xyxy: [x_min, y_min, x_max, y_max]

        Args:
            bbox: Bounding box coordinates
            format_from: Source format
            format_to: Target format

        Returns:
            Converted bounding box. For any other combination of formats
            (including identical source and target) ``bbox`` is returned
            unchanged.
        """
        if format_from == "xywh" and format_to == "xyxy":
            x, y, w, h = bbox
            half_w, half_h = w / 2, h / 2
            return [x - half_w, y - half_h, x + half_w, y + half_h]

        if format_from == "xyxy" and format_to == "xywh":
            x1, y1, x2, y2 = bbox
            return [(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1]

        return bbox

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about the loaded model.

        Does not trigger a load: when no model is loaded an error dictionary
        is returned instead.

        Returns:
            Dictionary with model information, or ``{"error": ...}`` when the
            model is not loaded or its attributes cannot be read
        """
        if self.model is None:
            return {"error": "Model not loaded"}

        try:
            info = {
                "model_path": self.model_path,
                "device": self.device,
                "task": getattr(self.model, "task", "unknown"),
            }

            # Try to get class names
            if hasattr(self.model, "names"):
                names = self.model.names
                info["classes"] = names
                info["num_classes"] = len(names)

            return info

        except Exception as e:
            logger.error(f"Error getting model info: {e}")
            return {"error": str(e)}
|