""" YOLO model wrapper for the microscopy object detection application. Provides a clean interface to YOLOv8 for training, validation, and inference. """ from ultralytics import YOLO from pathlib import Path from typing import Optional, List, Dict, Callable, Any import torch from src.utils.logger import get_logger logger = get_logger(__name__) class YOLOWrapper: """Wrapper for YOLOv8 model operations.""" def __init__(self, model_path: str = "yolov8s-seg.pt"): """ Initialize YOLO model. Args: model_path: Path to model weights (.pt file) """ self.model_path = model_path self.model = None self.device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"YOLOWrapper initialized with device: {self.device}") def load_model(self) -> bool: """ Load YOLO model from path. Returns: True if loaded successfully, False otherwise """ try: logger.info(f"Loading YOLO model from {self.model_path}") self.model = YOLO(self.model_path) self.model.to(self.device) logger.info("Model loaded successfully") return True except Exception as e: logger.error(f"Error loading model: {e}") return False def train( self, data_yaml: str, epochs: int = 100, imgsz: int = 640, batch: int = 16, patience: int = 50, save_dir: str = "data/models", name: str = "custom_model", resume: bool = False, callbacks: Optional[Dict[str, Callable]] = None, **kwargs, ) -> Dict[str, Any]: """ Train the YOLO model. Args: data_yaml: Path to data.yaml configuration file epochs: Number of training epochs imgsz: Input image size batch: Batch size patience: Early stopping patience save_dir: Directory to save trained model name: Name for the training run resume: Resume training from last checkpoint callbacks: Optional Ultralytics callback dictionary **kwargs: Additional training arguments Returns: Dictionary with training results """ if self.model is None: self.load_model() try: logger.info(f"Starting training: {name}") logger.info( f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}" ) # Train the model results = self.model.train( data=data_yaml, epochs=epochs, imgsz=imgsz, batch=batch, patience=patience, project=save_dir, name=name, device=self.device, resume=resume, **kwargs, ) logger.info("Training completed successfully") return self._format_training_results(results) except Exception as e: logger.error(f"Error during training: {e}") raise def validate(self, data_yaml: str, split: str = "val", **kwargs) -> Dict[str, Any]: """ Validate the model. Args: data_yaml: Path to data.yaml configuration file split: Dataset split to validate on ('val' or 'test') **kwargs: Additional validation arguments Returns: Dictionary with validation metrics """ if self.model is None: self.load_model() try: logger.info(f"Starting validation on {split} split") results = self.model.val( data=data_yaml, split=split, device=self.device, **kwargs ) logger.info("Validation completed successfully") return self._format_validation_results(results) except Exception as e: logger.error(f"Error during validation: {e}") raise def predict( self, source: str, conf: float = 0.25, iou: float = 0.45, save: bool = False, save_txt: bool = False, save_conf: bool = False, **kwargs, ) -> List[Dict]: """ Perform inference on image(s). 

        Args:
            source: Path to image or directory
            conf: Confidence threshold
            iou: IoU threshold for NMS
            save: Whether to save annotated images
            save_txt: Whether to save labels to .txt files
            save_conf: Whether to save confidence in labels
            **kwargs: Additional prediction arguments

        Returns:
            List of detection dictionaries
        """
        if self.model is None:
            self.load_model()

        try:
            logger.info(f"Running inference on {source}")
            results = self.model.predict(
                source=source,
                conf=conf,
                iou=iou,
                save=save,
                save_txt=save_txt,
                save_conf=save_conf,
                device=self.device,
                **kwargs,
            )

            detections = self._format_prediction_results(results)
            logger.info(f"Inference complete: {len(detections)} detections")
            return detections

        except Exception as e:
            logger.error(f"Error during inference: {e}")
            raise

    def export(
        self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
    ) -> str:
        """
        Export model to a different format.

        Args:
            format: Export format (onnx, torchscript, tflite, etc.)
            output_path: Optional destination path for the exported model
            **kwargs: Additional export arguments

        Returns:
            Path to exported model
        """
        if self.model is None:
            self.load_model()

        try:
            logger.info(f"Exporting model to {format} format")
            export_path = self.model.export(format=format, **kwargs)

            # Move the exported artifact to the requested location if one was given
            if output_path:
                Path(output_path).parent.mkdir(parents=True, exist_ok=True)
                shutil.move(str(export_path), output_path)
                export_path = output_path

            logger.info(f"Model exported to {export_path}")
            return str(export_path)

        except Exception as e:
            logger.error(f"Error exporting model: {e}")
            raise

    def _format_training_results(self, results) -> Dict[str, Any]:
        """Format training results into dictionary."""
        try:
            # Get the results dict
            results_dict = (
                results.results_dict if hasattr(results, "results_dict") else {}
            )

            formatted = {
                "success": True,
                "final_epoch": getattr(results, "epoch", 0),
                "metrics": {
                    "mAP50": float(results_dict.get("metrics/mAP50(B)", 0)),
                    "mAP50-95": float(results_dict.get("metrics/mAP50-95(B)", 0)),
                    "precision": float(results_dict.get("metrics/precision(B)", 0)),
                    "recall": float(results_dict.get("metrics/recall(B)", 0)),
                },
                "best_model_path": str(Path(results.save_dir) / "weights" / "best.pt"),
                "last_model_path": str(Path(results.save_dir) / "weights" / "last.pt"),
                "save_dir": str(results.save_dir),
            }
            return formatted

        except Exception as e:
            logger.error(f"Error formatting training results: {e}")
            return {"success": False, "error": str(e)}

    def _format_validation_results(self, results) -> Dict[str, Any]:
        """Format validation results into dictionary."""
        try:
            box_metrics = results.box

            formatted = {
                "success": True,
                "mAP50": float(box_metrics.map50),
                "mAP50-95": float(box_metrics.map),
                "precision": float(box_metrics.mp),
                "recall": float(box_metrics.mr),
                "fitness": (
                    float(results.fitness) if hasattr(results, "fitness") else 0.0
                ),
            }

            # Add per-class metrics if available
            if hasattr(box_metrics, "ap") and hasattr(results, "names"):
                class_metrics = {}
                for idx, name in results.names.items():
                    if idx < len(box_metrics.ap):
                        class_metrics[name] = {
                            "ap": float(box_metrics.ap[idx]),
                            "ap50": (
                                float(box_metrics.ap50[idx])
                                if hasattr(box_metrics, "ap50")
                                else 0.0
                            ),
                        }
                formatted["class_metrics"] = class_metrics

            return formatted

        except Exception as e:
            logger.error(f"Error formatting validation results: {e}")
            return {"success": False, "error": str(e)}

    def _format_prediction_results(self, results) -> List[Dict]:
        """Format prediction results into list of dictionaries."""
        detections = []

        try:
            for result in results:
                boxes = result.boxes
                image_path = str(result.path)
                orig_shape = result.orig_shape  # (height, width)
                height, width = orig_shape

                # Check if this is a segmentation model with masks
                has_masks = hasattr(result, "masks") and result.masks is not None

                for i in range(len(boxes)):
                    # Get normalized coordinates
                    xyxyn = boxes.xyxyn[i].cpu().numpy()  # Normalized [x1, y1, x2, y2]

                    detection = {
                        "image_path": image_path,
                        "class_id": int(boxes.cls[i]),
                        "class_name": result.names[int(boxes.cls[i])],
                        "confidence": float(boxes.conf[i]),
                        "bbox_normalized": [
                            float(v) for v in xyxyn
                        ],  # [x_min, y_min, x_max, y_max]
                        "bbox_absolute": [
                            float(v) for v in boxes.xyxy[i].cpu().numpy()
                        ],  # Absolute pixels
                    }

                    # Extract segmentation mask if available
                    if has_masks:
                        try:
                            # Get the mask for this detection
                            mask_data = result.masks.xy[
                                i
                            ]  # Polygon coordinates in absolute pixels

                            # Convert to normalized coordinates
                            if len(mask_data) > 0:
                                mask_normalized = []
                                for point in mask_data:
                                    x_norm = float(point[0]) / width
                                    y_norm = float(point[1]) / height
                                    mask_normalized.append([x_norm, y_norm])
                                detection["segmentation_mask"] = mask_normalized
                            else:
                                detection["segmentation_mask"] = None
                        except Exception as mask_error:
                            logger.warning(
                                f"Error extracting mask for detection {i}: {mask_error}"
                            )
                            detection["segmentation_mask"] = None
                    else:
                        detection["segmentation_mask"] = None

                    detections.append(detection)

            return detections

        except Exception as e:
            logger.error(f"Error formatting prediction results: {e}")
            return []

    @staticmethod
    def convert_bbox_format(
        bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
    ) -> List[float]:
        """
        Convert bounding box between formats.

        Formats:
            - xywh: [x_center, y_center, width, height]
            - xyxy: [x_min, y_min, x_max, y_max]

        Args:
            bbox: Bounding box coordinates
            format_from: Source format
            format_to: Target format

        Returns:
            Converted bounding box
        """
        if format_from == "xywh" and format_to == "xyxy":
            x, y, w, h = bbox
            return [x - w / 2, y - h / 2, x + w / 2, y + h / 2]
        elif format_from == "xyxy" and format_to == "xywh":
            x1, y1, x2, y2 = bbox
            return [(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1]
        else:
            return bbox

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about the loaded model.

        Returns:
            Dictionary with model information
        """
        if self.model is None:
            return {"error": "Model not loaded"}

        try:
            info = {
                "model_path": self.model_path,
                "device": self.device,
                "task": getattr(self.model, "task", "unknown"),
            }

            # Try to get class names
            if hasattr(self.model, "names"):
                info["classes"] = self.model.names
                info["num_classes"] = len(self.model.names)

            return info

        except Exception as e:
            logger.error(f"Error getting model info: {e}")
            return {"error": str(e)}
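

# --- Usage sketch ---
# A minimal example of how the wrapper is expected to be driven, shown here as a
# hedged illustration rather than part of the application itself. The image path
# "data/images/sample.png" is a placeholder, not a file shipped with the project.
if __name__ == "__main__":
    wrapper = YOLOWrapper("yolov8s-seg.pt")

    if wrapper.load_model():
        # Inspect the loaded model (task, device, class names)
        print(wrapper.get_model_info())

        # Run inference on a single (placeholder) image and summarize detections
        detections = wrapper.predict("data/images/sample.png", conf=0.3)
        for det in detections:
            print(det["class_name"], det["confidence"], det["bbox_normalized"])

    # Bounding-box conversion works without a loaded model: a box centred at
    # (0.5, 0.5) with width 0.2 and height 0.4 becomes [0.4, 0.3, 0.6, 0.7]
    # in [x_min, y_min, x_max, y_max] form.
    print(YOLOWrapper.convert_bbox_format([0.5, 0.5, 0.2, 0.4], "xywh", "xyxy"))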