Files
object-segmentation/src/model/yolo_wrapper.py

441 lines
15 KiB
Python
Raw Normal View History

2025-12-05 09:50:50 +02:00
"""
YOLO model wrapper for the microscopy object detection application.
Provides a clean interface to YOLOv8 for training, validation, and inference.
"""
from ultralytics import YOLO
from pathlib import Path
from typing import Optional, List, Dict, Callable, Any
import torch
2025-12-10 16:55:28 +02:00
from PIL import Image
import tempfile
import os
2025-12-05 09:50:50 +02:00
from src.utils.logger import get_logger
# Module-level logger used throughout this module; get_logger is the project's
# helper imported from src.utils.logger.
logger = get_logger(__name__)
class YOLOWrapper:
    """Wrapper for YOLOv8 model operations.

    Holds an Ultralytics ``YOLO`` model that is loaded lazily on first use and
    exposes training, validation, inference and export helpers that return
    plain Python dictionaries instead of Ultralytics result objects.
    """

    def __init__(self, model_path: str = "yolov8s-seg.pt"):
        """
        Initialize the wrapper without loading weights.

        Args:
            model_path: Path to model weights (.pt file). The weights are
                loaded by load_model(), either explicitly or on the first
                train/validate/predict/export call.
        """
        self.model_path = model_path
        self.model = None
        # Prefer the GPU whenever PyTorch can see one.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"YOLOWrapper initialized with device: {self.device}")

    def load_model(self) -> bool:
        """
        Load YOLO model from path and move it to ``self.device``.

        Returns:
            True if loaded successfully, False otherwise. On failure the error
            is logged and ``self.model`` is not replaced.
        """
        try:
            logger.info(f"Loading YOLO model from {self.model_path}")
            model = YOLO(self.model_path)
            model.to(self.device)
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            return False
        # Publish the model only once it is fully placed on its device, so a
        # failed .to() does not leave a half-initialised model behind.
        self.model = model
        logger.info("Model loaded successfully")
        return True

    def _ensure_model(self) -> None:
        """Load the model on first use, failing loudly instead of using None.

        Raises:
            RuntimeError: if the weights at ``self.model_path`` cannot be loaded.
        """
        if self.model is None and not self.load_model():
            raise RuntimeError(f"Could not load YOLO model from {self.model_path}")

    def train(
        self,
        data_yaml: str,
        epochs: int = 100,
        imgsz: int = 640,
        batch: int = 16,
        patience: int = 50,
        save_dir: str = "data/models",
        name: str = "custom_model",
        resume: bool = False,
        callbacks: Optional[Dict[str, Callable]] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Train the YOLO model.

        Args:
            data_yaml: Path to data.yaml configuration file
            epochs: Number of training epochs
            imgsz: Input image size
            batch: Batch size
            patience: Early stopping patience
            save_dir: Directory to save trained model (Ultralytics ``project``)
            name: Name for the training run
            resume: Resume training from last checkpoint
            callbacks: Optional Ultralytics callback dictionary mapping an event
                name (e.g. ``"on_train_epoch_end"``) to a callable. The
                callables are attached for this run only and detached again
                afterwards.
            **kwargs: Additional training arguments

        Returns:
            Dictionary with training results (see _format_training_results)

        Raises:
            RuntimeError: if the model weights cannot be loaded.
            Exception: any error raised during training is logged and re-raised.
        """
        self._ensure_model()
        added_callbacks = []
        try:
            # Attach the caller's hooks so they fire during this training run.
            for event, func in (callbacks or {}).items():
                self.model.add_callback(event, func)
                added_callbacks.append((event, func))

            logger.info(f"Starting training: {name}")
            logger.info(
                f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
            )
            results = self.model.train(
                data=data_yaml,
                epochs=epochs,
                imgsz=imgsz,
                batch=batch,
                patience=patience,
                project=save_dir,
                name=name,
                device=self.device,
                resume=resume,
                **kwargs,
            )
            logger.info("Training completed successfully")
            return self._format_training_results(results)
        except Exception as e:
            logger.error(f"Error during training: {e}")
            raise
        finally:
            self._remove_callbacks(added_callbacks)

    def _remove_callbacks(self, added_callbacks) -> None:
        """Detach callbacks attached by train() so repeated runs do not stack them."""
        registry = getattr(self.model, "callbacks", None)
        if registry is None:
            return
        for event, func in added_callbacks:
            handlers = registry.get(event, [])
            if func in handlers:
                handlers.remove(func)

    def validate(self, data_yaml: str, split: str = "val", **kwargs) -> Dict[str, Any]:
        """
        Validate the model.

        Args:
            data_yaml: Path to data.yaml configuration file
            split: Dataset split to validate on ('val' or 'test')
            **kwargs: Additional validation arguments

        Returns:
            Dictionary with validation metrics (see _format_validation_results)

        Raises:
            RuntimeError: if the model weights cannot be loaded.
            Exception: any error raised during validation is logged and re-raised.
        """
        self._ensure_model()
        try:
            logger.info(f"Starting validation on {split} split")
            results = self.model.val(
                data=data_yaml, split=split, device=self.device, **kwargs
            )
            logger.info("Validation completed successfully")
            return self._format_validation_results(results)
        except Exception as e:
            logger.error(f"Error during validation: {e}")
            raise

    def predict(
        self,
        source: str,
        conf: float = 0.25,
        iou: float = 0.45,
        save: bool = False,
        save_txt: bool = False,
        save_conf: bool = False,
        **kwargs,
    ) -> List[Dict]:
        """
        Perform inference on image(s).

        Single-channel image files are converted to a temporary RGB copy
        before inference; the copy is deleted afterwards, and the returned
        detections still report the original ``source`` as ``image_path``.

        Args:
            source: Path to image or directory
            conf: Confidence threshold
            iou: IoU threshold for NMS
            save: Whether to save annotated images
            save_txt: Whether to save labels to .txt files
            save_conf: Whether to save confidence in labels
            **kwargs: Additional prediction arguments

        Returns:
            List of detection dictionaries (see _format_prediction_results)

        Raises:
            RuntimeError: if the model weights cannot be loaded.
            Exception: any error raised during inference is logged and re-raised.
        """
        self._ensure_model()
        prepared_source, cleanup_path = self._prepare_source(source)
        try:
            logger.info(f"Running inference on {source}")
            results = self.model.predict(
                source=prepared_source,
                conf=conf,
                iou=iou,
                save=save,
                save_txt=save_txt,
                save_conf=save_conf,
                device=self.device,
                **kwargs,
            )
            detections = self._format_prediction_results(results)
            if cleanup_path:
                # Results point at the temporary RGB copy, which is deleted
                # below; report the file the caller actually asked about.
                original_path = str(source)
                for detection in detections:
                    detection["image_path"] = original_path
            logger.info(f"Inference complete: {len(detections)} detections")
            return detections
        except Exception as e:
            logger.error(f"Error during inference: {e}")
            raise
        finally:
            if cleanup_path:
                try:
                    os.remove(cleanup_path)
                except OSError as cleanup_error:
                    logger.warning(
                        f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
                    )

    def export(
        self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
    ) -> str:
        """
        Export model to different format.

        Args:
            format: Export format (onnx, torchscript, tflite, etc.)
            output_path: Optional destination for the exported artifact. If it
                is an existing directory the artifact is moved into it;
                otherwise the artifact is moved to exactly this path. When
                omitted, the artifact stays where Ultralytics wrote it.
            **kwargs: Additional export arguments

        Returns:
            Path to exported model

        Raises:
            RuntimeError: if the model weights cannot be loaded.
            Exception: any error raised during export is logged and re-raised.
        """
        self._ensure_model()
        try:
            logger.info(f"Exporting model to {format} format")
            export_path = self.model.export(format=format, **kwargs)
            if output_path and export_path:
                export_path = self._move_export(export_path, output_path)
            logger.info(f"Model exported to {export_path}")
            return str(export_path)
        except Exception as e:
            logger.error(f"Error exporting model: {e}")
            raise

    @staticmethod
    def _move_export(export_path, output_path) -> Path:
        """Move an exported artifact (file or directory) to ``output_path``."""
        import shutil

        source = Path(export_path)
        destination = Path(output_path)
        if destination.is_dir():
            destination = destination / source.name
        destination.parent.mkdir(parents=True, exist_ok=True)
        shutil.move(str(source), str(destination))
        return destination

    def _prepare_source(self, source):
        """Convert single-channel images to RGB temporarily for inference.

        Args:
            source: Whatever the caller passed to predict(). Only a path to an
                existing file is inspected; anything else is passed through.

        Returns:
            Tuple ``(prepared_source, cleanup_path)``. ``cleanup_path`` is the
            temporary file the caller must delete, or None if no copy was made.
            On any conversion error the original source is returned and a
            warning is logged.
        """
        if not isinstance(source, (str, Path)):
            return source, None
        source_path = Path(source)
        if not source_path.is_file():
            return source, None

        tmp_path = None
        try:
            with Image.open(source_path) as img:
                if len(img.getbands()) != 1:
                    return source, None
                rgb_img = img.convert("RGB")
            # PNG is lossless and always writable by Pillow, so the temporary
            # copy never adds compression artifacts (e.g. for JPEG sources)
            # and never fails because the original format cannot be encoded.
            fd, tmp_path = tempfile.mkstemp(suffix=".png")
            os.close(fd)
            rgb_img.save(tmp_path)
            logger.info(
                f"Converted single-channel image {source_path} to RGB for inference"
            )
            return tmp_path, tmp_path
        except Exception as convert_error:
            # Don't leak the temporary file if conversion/saving failed midway.
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass
            logger.warning(
                f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"
            )
            return source, None

    def _format_training_results(self, results) -> Dict[str, Any]:
        """Format training results into dictionary.

        Args:
            results: The value returned by ``YOLO.train`` (a metrics object or
                possibly None). The run directory and epoch fall back to the
                model's ``trainer`` when the metrics object does not carry them.

        Returns:
            ``{"success": True, "final_epoch", "metrics", "best_model_path",
            "last_model_path", "save_dir"}`` on success, or
            ``{"success": False, "error": str}`` if formatting fails.
        """
        try:
            results_dict = getattr(results, "results_dict", None) or {}
            trainer = getattr(self.model, "trainer", None)

            save_dir = getattr(results, "save_dir", None) or getattr(
                trainer, "save_dir", None
            )
            if save_dir is None:
                raise AttributeError(
                    "training results and trainer expose no save_dir"
                )

            final_epoch = getattr(results, "epoch", None)
            if final_epoch is None:
                # NOTE(review): assumes trainer.epoch is the 0-based index of
                # the last epoch that ran (as in Ultralytics' BaseTrainer), so
                # +1 yields the epoch count; without a trainer this gives 0.
                final_epoch = getattr(trainer, "epoch", -1) + 1

            weights_dir = Path(save_dir) / "weights"
            return {
                "success": True,
                "final_epoch": final_epoch,
                "metrics": {
                    "mAP50": float(results_dict.get("metrics/mAP50(B)", 0)),
                    "mAP50-95": float(results_dict.get("metrics/mAP50-95(B)", 0)),
                    "precision": float(results_dict.get("metrics/precision(B)", 0)),
                    "recall": float(results_dict.get("metrics/recall(B)", 0)),
                },
                "best_model_path": str(weights_dir / "best.pt"),
                "last_model_path": str(weights_dir / "last.pt"),
                "save_dir": str(save_dir),
            }
        except Exception as e:
            logger.error(f"Error formatting training results: {e}")
            return {"success": False, "error": str(e)}

    def _format_validation_results(self, results) -> Dict[str, Any]:
        """Format validation results into dictionary.

        Per-class ``ap``/``ap50`` arrays are ordered by
        ``results.ap_class_index`` (only classes present in the split), so
        class names are looked up through that index. If the index is missing,
        the array position is treated as the class id.

        Returns:
            Metric dictionary with optional ``class_metrics``, or
            ``{"success": False, "error": str}`` if formatting fails.
        """
        try:
            box_metrics = results.box
            formatted = {
                "success": True,
                "mAP50": float(box_metrics.map50),
                "mAP50-95": float(box_metrics.map),
                "precision": float(box_metrics.mp),
                "recall": float(box_metrics.mr),
                "fitness": (
                    float(results.fitness) if hasattr(results, "fitness") else 0.0
                ),
            }
            if hasattr(box_metrics, "ap") and hasattr(results, "names"):
                per_class_ap = box_metrics.ap
                has_ap50 = hasattr(box_metrics, "ap50")
                class_ids = getattr(results, "ap_class_index", None)
                if class_ids is None:
                    class_ids = range(len(per_class_ap))
                class_metrics = {}
                for position, class_id in enumerate(class_ids):
                    name = results.names.get(int(class_id))
                    if name is None or position >= len(per_class_ap):
                        continue
                    class_metrics[name] = {
                        "ap": float(per_class_ap[position]),
                        "ap50": (
                            float(box_metrics.ap50[position]) if has_ap50 else 0.0
                        ),
                    }
                formatted["class_metrics"] = class_metrics
            return formatted
        except Exception as e:
            logger.error(f"Error formatting validation results: {e}")
            return {"success": False, "error": str(e)}

    def _format_prediction_results(self, results) -> List[Dict]:
        """Format prediction results into list of dictionaries.

        Each detection has ``image_path``, ``class_id``, ``class_name``,
        ``confidence``, ``bbox_normalized`` / ``bbox_absolute`` as
        ``[x_min, y_min, x_max, y_max]``, and ``segmentation_mask`` (polygon of
        ``[x, y]`` pairs normalized by image width/height, or None). Results
        without boxes are skipped. On an unexpected error, an empty list is
        returned.
        """
        detections = []
        try:
            for result in results:
                boxes = result.boxes
                if boxes is None:
                    # e.g. classification results carry no boxes.
                    continue
                image_path = str(result.path)
                height, width = result.orig_shape
                masks = getattr(result, "masks", None)

                # Copy tensors to host memory once per image, not once per box.
                xyxyn_all = boxes.xyxyn.cpu().numpy()
                xyxy_all = boxes.xyxy.cpu().numpy()
                class_ids = boxes.cls.cpu().numpy()
                confidences = boxes.conf.cpu().numpy()

                for i, (cls_value, score, box_norm, box_abs) in enumerate(
                    zip(class_ids, confidences, xyxyn_all, xyxy_all)
                ):
                    class_id = int(cls_value)
                    detection = {
                        "image_path": image_path,
                        "class_id": class_id,
                        "class_name": result.names[class_id],
                        "confidence": float(score),
                        "bbox_normalized": [float(v) for v in box_norm],
                        "bbox_absolute": [float(v) for v in box_abs],
                        "segmentation_mask": None,
                    }
                    if masks is not None:
                        try:
                            # Polygon in absolute pixel coordinates.
                            polygon = masks.xy[i]
                            if len(polygon) > 0:
                                detection["segmentation_mask"] = [
                                    [float(x) / width, float(y) / height]
                                    for x, y in polygon
                                ]
                        except Exception as mask_error:
                            logger.warning(
                                f"Error extracting mask for detection {i}: {mask_error}"
                            )
                    detections.append(detection)
            return detections
        except Exception as e:
            logger.error(f"Error formatting prediction results: {e}")
            return []

    @staticmethod
    def convert_bbox_format(
        bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
    ) -> List[float]:
        """
        Convert bounding box between formats.

        Formats:
            - xywh: [x_center, y_center, width, height]
            - xyxy: [x_min, y_min, x_max, y_max]

        Args:
            bbox: Bounding box coordinates
            format_from: Source format
            format_to: Target format

        Returns:
            Converted bounding box. Any other format pair (including identical
            formats) returns ``bbox`` itself unchanged.
        """
        if format_from == "xywh" and format_to == "xyxy":
            x, y, w, h = bbox
            return [x - w / 2, y - h / 2, x + w / 2, y + h / 2]
        if format_from == "xyxy" and format_to == "xywh":
            x1, y1, x2, y2 = bbox
            return [(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1]
        return bbox

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about the loaded model (does not load it).

        Returns:
            Dictionary with model path, device, task and, when available,
            class names and count; ``{"error": ...}`` if the model is not
            loaded or the lookup fails.
        """
        if self.model is None:
            return {"error": "Model not loaded"}
        try:
            info = {
                "model_path": self.model_path,
                "device": self.device,
                "task": getattr(self.model, "task", "unknown"),
            }
            if hasattr(self.model, "names"):
                info["classes"] = self.model.names
                info["num_classes"] = len(self.model.names)
            return info
        except Exception as e:
            logger.error(f"Error getting model info: {e}")
            return {"error": str(e)}