"""Ultralytics runtime patches for 16-bit TIFF training. Goals: - Use `tifffile` to decode `.tif/.tiff` reliably (OpenCV can silently drop bit-depth depending on codec). - Preserve 16-bit data through the dataloader as `uint16` tensors. - Fix Ultralytics trainer normalization (default divides by 255) to scale `uint16` correctly. - Avoid uint8-forcing augmentations by recommending/setting hyp values (handled by caller). This module is intended to be imported/called **before** instantiating/using YOLO. """ from __future__ import annotations from typing import Optional def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None: """Apply runtime monkey-patches to Ultralytics to better support 16-bit TIFFs. This function is safe to call multiple times. Args: force: If True, re-apply patches even if already applied. """ # Import inside function to ensure patching occurs before YOLO model/dataset is created. import os import cv2 import numpy as np import tifffile import torch from ultralytics.utils import patches as ul_patches already_patched = getattr(ul_patches.imread, "__name__", "") == "tifffile_imread" if already_patched and not force: return _original_imread = ul_patches.imread def tifffile_imread( filename: str, flags: int = cv2.IMREAD_COLOR ) -> Optional[np.ndarray]: """Replacement for [`ultralytics.utils.patches.imread()`](venv/lib/python3.12/site-packages/ultralytics/utils/patches.py:20). - For `.tif/.tiff`, uses `tifffile.imread()` and preserves dtype (e.g. uint16). - For other formats, falls back to Ultralytics' original implementation. - Always returns HWC (3 dims). For grayscale, returns (H, W, 1) or (H, W, 3) depending on requested flags. """ ext = os.path.splitext(filename)[1].lower() if ext in (".tif", ".tiff"): arr = tifffile.imread(filename) # Normalize common shapes: # - (H, W) -> (H, W, 1) # - (C, H, W) -> (H, W, C) (heuristic) if arr is None: return None if ( arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[0] < arr.shape[1] ): arr = np.transpose(arr, (1, 2, 0)) if arr.ndim == 2: arr = arr[..., None] # Ultralytics expects BGR ordering when `channels=3`. # For grayscale data we replicate channels (no scaling, no quantization). if flags != cv2.IMREAD_GRAYSCALE: if arr.shape[2] == 1: arr = np.repeat(arr, 3, axis=2) elif arr.shape[2] >= 3: arr = arr[:, :, :3] # Ensure contiguous array for downstream OpenCV ops. return np.ascontiguousarray(arr) return _original_imread(filename, flags) # Patch the canonical reference. ul_patches.imread = tifffile_imread # Patch common module-level imports (some Ultralytics modules do `from ... import imread`). # Importing these modules is safe and helps ensure the patched function is used. try: import ultralytics.data.base as _ul_base _ul_base.imread = tifffile_imread except Exception: pass try: import ultralytics.data.loaders as _ul_loaders _ul_loaders.imread = tifffile_imread except Exception: pass # Patch trainer normalization: default divides by 255 regardless of input dtype. from ultralytics.models.yolo.detect import train as detect_train _orig_preprocess_batch = detect_train.DetectionTrainer.preprocess_batch def preprocess_batch_16bit(self, batch: dict) -> dict: # type: ignore[override] # Start from upstream behavior to keep device placement + multiscale identical, # but replace the 255 division with dtype-aware scaling. for k, v in batch.items(): if isinstance(v, torch.Tensor): batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda") img = batch.get("img") if isinstance(img, torch.Tensor): # Decide scaling denom based on dtype (avoid expensive reductions if possible). if img.dtype == torch.uint8: denom = 255.0 elif img.dtype == torch.uint16: denom = 65535.0 elif img.dtype.is_floating_point: # Assume already in 0-1 range if float. denom = 1.0 else: # Generic integer fallback. try: denom = float(torch.iinfo(img.dtype).max) except Exception: denom = 255.0 batch["img"] = img.float() / denom # Multi-scale branch copied from upstream to avoid re-introducing `/255` scaling. if getattr(self.args, "multi_scale", False): import math import random import torch.nn as nn imgs = batch["img"] sz = ( random.randrange( int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride) ) // self.stride * self.stride ) sf = sz / max(imgs.shape[2:]) if sf != 1: ns = [ math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:] ] imgs = nn.functional.interpolate( imgs, size=ns, mode="bilinear", align_corners=False ) batch["img"] = imgs return batch detect_train.DetectionTrainer.preprocess_batch = preprocess_batch_16bit # Tag function to make it easier to detect patch state. setattr( detect_train.DetectionTrainer.preprocess_batch, "_ultralytics_16bit_patch", True )