Implementing uint16 reading with tifffile

This commit is contained in:
2025-12-16 23:02:45 +02:00
parent e5036c10cf
commit e364d06217
4 changed files with 35 additions and 62 deletions

View File

@@ -1,9 +1,13 @@
"""
YOLO model wrapper for the microscopy object detection application.
Provides a clean interface to YOLOv8 for training, validation, and inference.
"""YOLO model wrapper for the microscopy object detection application.
Notes on 16-bit TIFF support:
- Ultralytics training defaults assume 8-bit images and normalize by dividing by 255.
- This project can patch Ultralytics at runtime to decode TIFFs via `tifffile` and
normalize `uint16` correctly.
See [`apply_ultralytics_16bit_tiff_patches()`](src/utils/ultralytics_16bit_patch.py:1).
"""
from ultralytics import YOLO
from pathlib import Path
from typing import Optional, List, Dict, Callable, Any
import torch
@@ -11,6 +15,7 @@ import tempfile
import os
from src.utils.image import Image
from src.utils.logger import get_logger
from src.utils.ultralytics_16bit_patch import apply_ultralytics_16bit_tiff_patches
logger = get_logger(__name__)
@@ -31,6 +36,9 @@ class YOLOWrapper:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"YOLOWrapper initialized with device: {self.device}")
# Apply Ultralytics runtime patches early (before first import/instantiation of YOLO datasets/trainers).
apply_ultralytics_16bit_tiff_patches()
def load_model(self) -> bool:
"""
Load YOLO model from path.
@@ -40,6 +48,9 @@ class YOLOWrapper:
"""
try:
logger.info(f"Loading YOLO model from {self.model_path}")
# Import YOLO lazily to ensure runtime patches are applied first.
from ultralytics import YOLO
self.model = YOLO(self.model_path)
self.model.to(self.device)
logger.info("Model loaded successfully")
@@ -89,6 +100,16 @@ class YOLOWrapper:
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
)
# Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255.
# Users can override by passing explicit kwargs.
kwargs.setdefault("mosaic", 0.0)
kwargs.setdefault("mixup", 0.0)
kwargs.setdefault("cutmix", 0.0)
kwargs.setdefault("copy_paste", 0.0)
kwargs.setdefault("hsv_h", 0.0)
kwargs.setdefault("hsv_s", 0.0)
kwargs.setdefault("hsv_v", 0.0)
# Train the model
results = self.model.train(
data=data_yaml,