"""YOLO model wrapper for the microscopy object detection application.

Notes on 16-bit TIFF support:
- Ultralytics training defaults assume 8-bit images and normalize by dividing by 255.
- This project can patch Ultralytics at runtime to decode TIFFs via `tifffile` and
  normalize `uint16` correctly.

See [`apply_ultralytics_16bit_tiff_patches()`](src/utils/ultralytics_16bit_patch.py:1).
"""

import os
import shutil
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch

from src.utils.image import Image
from src.utils.logger import get_logger
from src.utils.ultralytics_16bit_patch import apply_ultralytics_16bit_tiff_patches

logger = get_logger(__name__)


class YOLOWrapper:
    """Wrapper for YOLOv8 model operations (load, train, validate, predict, export)."""

    def __init__(self, model_path: str = "yolov8s-seg.pt"):
        """
        Initialize YOLO model.

        The model itself is loaded lazily by `load_model()`; only the device is
        selected here.

        Args:
            model_path: Path to model weights (.pt file)
        """
        self.model_path = model_path
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"YOLOWrapper initialized with device: {self.device}")

        # Apply Ultralytics runtime patches early (before first
        # import/instantiation of YOLO datasets/trainers).
        apply_ultralytics_16bit_tiff_patches()

    def load_model(self) -> bool:
        """
        Load YOLO model from path.

        Returns:
            True if loaded successfully, False otherwise
        """
        try:
            logger.info(f"Loading YOLO model from {self.model_path}")

            # Import YOLO lazily to ensure runtime patches are applied first.
            from ultralytics import YOLO

            self.model = YOLO(self.model_path)
            self.model.to(self.device)
            logger.info("Model loaded successfully")
            return True
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            return False

    def _ensure_model_loaded(self) -> None:
        """Load the model on first use; raise RuntimeError if loading fails."""
        if self.model is None and not self.load_model():
            raise RuntimeError(f"Failed to load model from {self.model_path}")

    def train(
        self,
        data_yaml: str,
        epochs: int = 100,
        imgsz: int = 640,
        batch: int = 16,
        patience: int = 50,
        save_dir: str = "data/models",
        name: str = "custom_model",
        resume: bool = False,
        callbacks: Optional[Dict[str, Callable]] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Train the YOLO model.

        Args:
            data_yaml: Path to data.yaml configuration file
            epochs: Number of training epochs
            imgsz: Input image size
            batch: Batch size
            patience: Early stopping patience
            save_dir: Directory to save trained model
            name: Name for the training run
            resume: Resume training from last checkpoint
            callbacks: Optional Ultralytics callback dictionary mapping an
                event name to a callable. Each entry is registered on the
                model before training starts and stays registered on this
                model instance afterwards.
            **kwargs: Additional training arguments

        Returns:
            Dictionary with training results

        Raises:
            RuntimeError: If the model cannot be loaded.
            Exception: Any error raised during training is logged and re-raised.
        """
        self._ensure_model_loaded()

        try:
            logger.info(f"Starting training: {name}")
            logger.info(
                f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
            )

            # Defaults for 16-bit safety: disable augmentations that force uint8
            # and HSV ops that assume 0..255.
            # Users can override by passing explicit kwargs.
            for option in (
                "mosaic",
                "mixup",
                "cutmix",
                "copy_paste",
                "hsv_h",
                "hsv_s",
                "hsv_v",
            ):
                kwargs.setdefault(option, 0.0)

            # Register user callbacks; previously this argument was accepted
            # but silently ignored.
            if callbacks:
                for event, callback in callbacks.items():
                    self.model.add_callback(event, callback)

            # Train the model
            results = self.model.train(
                data=data_yaml,
                epochs=epochs,
                imgsz=imgsz,
                batch=batch,
                patience=patience,
                project=save_dir,
                name=name,
                device=self.device,
                resume=resume,
                **kwargs,
            )

            logger.info("Training completed successfully")
            return self._format_training_results(results)
        except Exception as e:
            logger.error(f"Error during training: {e}")
            raise

    def validate(self, data_yaml: str, split: str = "val", **kwargs) -> Dict[str, Any]:
        """
        Validate the model.

        Args:
            data_yaml: Path to data.yaml configuration file
            split: Dataset split to validate on ('val' or 'test')
            **kwargs: Additional validation arguments

        Returns:
            Dictionary with validation metrics

        Raises:
            RuntimeError: If the model cannot be loaded.
            Exception: Any error raised during validation is logged and re-raised.
        """
        self._ensure_model_loaded()

        try:
            logger.info(f"Starting validation on {split} split")
            results = self.model.val(
                data=data_yaml, split=split, device=self.device, **kwargs
            )
            logger.info("Validation completed successfully")
            return self._format_validation_results(results)
        except Exception as e:
            logger.error(f"Error during validation: {e}")
            raise

    def predict(
        self,
        source: str,
        conf: float = 0.25,
        iou: float = 0.45,
        save: bool = False,
        save_txt: bool = False,
        save_conf: bool = False,
        **kwargs,
    ) -> List[Dict]:
        """
        Perform inference on image(s).

        Args:
            source: Path to image or directory
            conf: Confidence threshold
            iou: IoU threshold for NMS
            save: Whether to save annotated images
            save_txt: Whether to save labels to .txt files
            save_conf: Whether to save confidence in labels
            **kwargs: Additional prediction arguments

        Returns:
            List of detection dictionaries

        Raises:
            RuntimeError: If the model cannot be loaded.
            Exception: Any error raised during inference is logged and re-raised.
        """
        self._ensure_model_loaded()

        prepared_source, cleanup_path = self._prepare_source(source)

        try:
            logger.info(
                f"Running inference on {source} -> prepared_source {prepared_source}"
            )
            # NOTE(review): inference is run on the original `source`, not on
            # `prepared_source`, so `image_path` in the returned detections
            # refers to the caller's file. Confirm whether the temporary RGB
            # copy is still needed at all.
            results = self.model.predict(
                source=source,
                conf=conf,
                iou=iou,
                save=save,
                save_txt=save_txt,
                save_conf=save_conf,
                device=self.device,
                **kwargs,
            )

            detections = self._format_prediction_results(results)
            logger.info(f"Inference complete: {len(detections)} detections")
            return detections
        except Exception as e:
            logger.error(f"Error during inference: {e}")
            raise
        finally:
            # Always remove the temporary copy; leaving this disabled leaked
            # one temp file per call.
            if cleanup_path:
                self._remove_temp_file(cleanup_path)

    def export(
        self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
    ) -> str:
        """
        Export model to different format.

        Args:
            format: Export format (onnx, torchscript, tflite, etc.)
            output_path: Path for exported model. When given, the exported
                artifact is moved there (parent directories are created);
                when None, it stays where Ultralytics wrote it.
            **kwargs: Additional export arguments

        Returns:
            Path to exported model

        Raises:
            RuntimeError: If the model cannot be loaded.
            Exception: Any error raised during export is logged and re-raised.
        """
        self._ensure_model_loaded()

        try:
            logger.info(f"Exporting model to {format} format")
            export_path = self.model.export(format=format, **kwargs)

            if output_path is not None:
                destination = Path(output_path)
                destination.parent.mkdir(parents=True, exist_ok=True)
                # shutil.move returns the final location, which differs from
                # `destination` when that is an existing directory.
                export_path = shutil.move(str(export_path), str(destination))

            logger.info(f"Model exported to {export_path}")
            return str(export_path)
        except Exception as e:
            logger.error(f"Error exporting model: {e}")
            raise

    @staticmethod
    def _remove_temp_file(path: str) -> None:
        """Delete a temporary file, logging (not raising) if removal fails."""
        try:
            os.remove(path)
        except OSError as cleanup_error:
            logger.warning(
                f"Failed to delete temporary RGB image {path}: {cleanup_error}"
            )

    def _prepare_source(
        self, source: Union[str, Path, Any]
    ) -> Tuple[Any, Optional[str]]:
        """
        Convert single-channel images to RGB temporarily for inference.

        When `source` is a path to an existing file, it is loaded through
        `Image` and saved to a new temporary file with the same suffix
        (".png" if the source has none).

        Args:
            source: Image path, directory path, or any other source object

        Returns:
            Tuple `(prepared_source, cleanup_path)`. On successful conversion
            both are the temporary file path and the caller is responsible
            for deleting it. Otherwise `(source, None)` is returned.
        """
        if not isinstance(source, (str, Path)):
            return source, None

        source_path = Path(source)
        if not source_path.is_file():
            return source, None

        tmp_path = None
        try:
            img_obj = Image(source_path)
            suffix = source_path.suffix or ".png"
            # Only the unique name is needed; the file is written by
            # `img_obj.save()` after the handle is closed.
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
                tmp_path = tmp.name
            img_obj.save(tmp_path)
        except Exception as convert_error:
            logger.warning(
                f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"
            )
            # Do not orphan the temp file when saving failed after creation.
            if tmp_path is not None:
                self._remove_temp_file(tmp_path)
            return source, None

        logger.info(
            f"Converted image {source_path} to RGB for inference at {tmp_path}"
        )
        return tmp_path, tmp_path

    def _format_training_results(self, results) -> Dict[str, Any]:
        """
        Format training results into dictionary.

        Args:
            results: Object returned by `model.train()`; expected to expose
                `save_dir` and, optionally, `results_dict` and `epoch`.

        Returns:
            Dictionary with `success`, `final_epoch`, `metrics`, model paths
            and `save_dir`; or `{"success": False, "error": ...}` if the
            results object could not be read.
        """
        try:
            # Get the results dict
            results_dict = (
                results.results_dict if hasattr(results, "results_dict") else {}
            )
            weights_dir = Path(results.save_dir) / "weights"

            return {
                "success": True,
                "final_epoch": getattr(results, "epoch", 0),
                "metrics": {
                    "mAP50": float(results_dict.get("metrics/mAP50(B)", 0)),
                    "mAP50-95": float(results_dict.get("metrics/mAP50-95(B)", 0)),
                    "precision": float(results_dict.get("metrics/precision(B)", 0)),
                    "recall": float(results_dict.get("metrics/recall(B)", 0)),
                },
                "best_model_path": str(weights_dir / "best.pt"),
                "last_model_path": str(weights_dir / "last.pt"),
                "save_dir": str(results.save_dir),
            }
        except Exception as e:
            logger.error(f"Error formatting training results: {e}")
            return {"success": False, "error": str(e)}

    def _format_validation_results(self, results) -> Dict[str, Any]:
        """
        Format validation results into dictionary.

        Args:
            results: Object returned by `model.val()`; expected to expose
                `box` metrics and, optionally, `fitness` and `names`.

        Returns:
            Dictionary with overall metrics and, when available,
            `class_metrics` keyed by class name; or
            `{"success": False, "error": ...}` on failure.
        """
        try:
            box_metrics = results.box

            formatted = {
                "success": True,
                "mAP50": float(box_metrics.map50),
                "mAP50-95": float(box_metrics.map),
                "precision": float(box_metrics.mp),
                "recall": float(box_metrics.mr),
                "fitness": (
                    float(results.fitness) if hasattr(results, "fitness") else 0.0
                ),
            }

            # Add per-class metrics if available
            if hasattr(box_metrics, "ap") and hasattr(results, "names"):
                ap_values = box_metrics.ap
                has_ap50 = hasattr(box_metrics, "ap50")

                # `ap`/`ap50` are ordered by position in `ap_class_index`
                # (only classes that received an AP score), not by class id.
                # Fall back to position == class id when the index is absent.
                class_ids = getattr(box_metrics, "ap_class_index", None)
                if class_ids is None:
                    class_ids = range(len(ap_values))

                class_metrics = {}
                for position, raw_class_id in enumerate(class_ids):
                    class_id = int(raw_class_id)
                    if position >= len(ap_values) or class_id not in results.names:
                        continue
                    class_metrics[results.names[class_id]] = {
                        "ap": float(ap_values[position]),
                        "ap50": (
                            float(box_metrics.ap50[position]) if has_ap50 else 0.0
                        ),
                    }

                formatted["class_metrics"] = class_metrics

            return formatted
        except Exception as e:
            logger.error(f"Error formatting validation results: {e}")
            return {"success": False, "error": str(e)}

    def _format_prediction_results(self, results) -> List[Dict]:
        """
        Format prediction results into list of dictionaries.

        Each detection carries `image_path`, `class_id`, `class_name`,
        `confidence`, `bbox_normalized`, `bbox_absolute` and
        `segmentation_mask` (normalized polygon points, or None).

        Args:
            results: Iterable of per-image result objects from `model.predict()`

        Returns:
            Flat list of detection dictionaries; an empty list if formatting
            fails.
        """
        detections = []

        try:
            for result in results:
                boxes = result.boxes
                if boxes is None:
                    # No box output for this image; without this guard the
                    # resulting error would discard every image's detections.
                    continue

                image_path = str(result.path)
                height, width = result.orig_shape  # (height, width)

                # Check if this is a segmentation model with masks
                masks = getattr(result, "masks", None)

                # Transfer all boxes to host memory once instead of per box.
                normalized_boxes = boxes.xyxyn.cpu().numpy()  # [x1, y1, x2, y2]
                absolute_boxes = boxes.xyxy.cpu().numpy()  # Absolute pixels

                for i, (norm_box, abs_box) in enumerate(
                    zip(normalized_boxes, absolute_boxes)
                ):
                    class_id = int(boxes.cls[i])
                    detection = {
                        "image_path": image_path,
                        "class_id": class_id,
                        "class_name": result.names[class_id],
                        "confidence": float(boxes.conf[i]),
                        # [x_min, y_min, x_max, y_max]
                        "bbox_normalized": [float(v) for v in norm_box],
                        "bbox_absolute": [float(v) for v in abs_box],
                        "segmentation_mask": None,
                    }

                    # Extract segmentation mask if available
                    if masks is not None:
                        try:
                            # Polygon coordinates in absolute pixels
                            polygon = masks.xy[i]
                            if len(polygon) > 0:
                                # Convert to normalized coordinates
                                detection["segmentation_mask"] = [
                                    [
                                        float(point[0]) / width,
                                        float(point[1]) / height,
                                    ]
                                    for point in polygon
                                ]
                        except Exception as mask_error:
                            logger.warning(
                                f"Error extracting mask for detection {i}: {mask_error}"
                            )
                            detection["segmentation_mask"] = None

                    detections.append(detection)

            return detections
        except Exception as e:
            logger.error(f"Error formatting prediction results: {e}")
            return []

    @staticmethod
    def convert_bbox_format(
        bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
    ) -> List[float]:
        """
        Convert bounding box between formats.

        Formats:
        - xywh: [x_center, y_center, width, height]
        - xyxy: [x_min, y_min, x_max, y_max]

        Any other combination of formats returns `bbox` unchanged.

        Args:
            bbox: Bounding box coordinates
            format_from: Source format
            format_to: Target format

        Returns:
            Converted bounding box
        """
        if format_from == "xywh" and format_to == "xyxy":
            x, y, w, h = bbox
            return [x - w / 2, y - h / 2, x + w / 2, y + h / 2]
        elif format_from == "xyxy" and format_to == "xywh":
            x1, y1, x2, y2 = bbox
            return [(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1]
        else:
            return bbox

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about the loaded model.

        Returns:
            Dictionary with model information, or `{"error": ...}` when no
            model is loaded or the information cannot be read.
        """
        if self.model is None:
            return {"error": "Model not loaded"}

        try:
            info = {
                "model_path": self.model_path,
                "device": self.device,
                "task": getattr(self.model, "task", "unknown"),
            }

            # Try to get class names
            if hasattr(self.model, "names"):
                info["classes"] = self.model.names
                info["num_classes"] = len(self.model.names)

            return info
        except Exception as e:
            logger.error(f"Error getting model info: {e}")
            return {"error": str(e)}