diff --git a/src/utils/ultralytics_16bit_patch.py b/src/utils/ultralytics_16bit_patch.py new file mode 100644 index 0000000..adff32d --- /dev/null +++ b/src/utils/ultralytics_16bit_patch.py @@ -0,0 +1,165 @@ +"""Ultralytics runtime patches for 16-bit TIFF training. + +Goals: +- Use `tifffile` to decode `.tif/.tiff` reliably (OpenCV can silently drop bit-depth depending on codec). +- Preserve 16-bit data through the dataloader as `uint16` tensors. +- Fix Ultralytics trainer normalization (default divides by 255) to scale `uint16` correctly. +- Avoid uint8-forcing augmentations by recommending/setting hyp values (handled by caller). + +This module is intended to be imported/called **before** instantiating/using YOLO. +""" + +from __future__ import annotations + +from typing import Optional + + +def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None: + """Apply runtime monkey-patches to Ultralytics to better support 16-bit TIFFs. + + This function is safe to call multiple times. + + Args: + force: If True, re-apply patches even if already applied. + """ + + # Import inside function to ensure patching occurs before YOLO model/dataset is created. + import os + + import cv2 + import numpy as np + import tifffile + import torch + + from ultralytics.utils import patches as ul_patches + + already_patched = getattr(ul_patches.imread, "__name__", "") == "tifffile_imread" + if already_patched and not force: + return + + _original_imread = ul_patches.imread + + def tifffile_imread( + filename: str, flags: int = cv2.IMREAD_COLOR + ) -> Optional[np.ndarray]: + """Replacement for [`ultralytics.utils.patches.imread()`](venv/lib/python3.12/site-packages/ultralytics/utils/patches.py:20). + + - For `.tif/.tiff`, uses `tifffile.imread()` and preserves dtype (e.g. uint16). + - For other formats, falls back to Ultralytics' original implementation. + - Always returns HWC (3 dims). For grayscale, returns (H, W, 1) or (H, W, 3) depending on requested flags. + """ + + ext = os.path.splitext(filename)[1].lower() + if ext in (".tif", ".tiff"): + arr = tifffile.imread(filename) + + # Normalize common shapes: + # - (H, W) -> (H, W, 1) + # - (C, H, W) -> (H, W, C) (heuristic) + if arr is None: + return None + if ( + arr.ndim == 3 + and arr.shape[0] in (1, 3, 4) + and arr.shape[0] < arr.shape[1] + ): + arr = np.transpose(arr, (1, 2, 0)) + if arr.ndim == 2: + arr = arr[..., None] + + # Ultralytics expects BGR ordering when `channels=3`. + # For grayscale data we replicate channels (no scaling, no quantization). + if flags != cv2.IMREAD_GRAYSCALE: + if arr.shape[2] == 1: + arr = np.repeat(arr, 3, axis=2) + elif arr.shape[2] >= 3: + arr = arr[:, :, :3] + + # Ensure contiguous array for downstream OpenCV ops. + return np.ascontiguousarray(arr) + + return _original_imread(filename, flags) + + # Patch the canonical reference. + ul_patches.imread = tifffile_imread + + # Patch common module-level imports (some Ultralytics modules do `from ... import imread`). + # Importing these modules is safe and helps ensure the patched function is used. + try: + import ultralytics.data.base as _ul_base + + _ul_base.imread = tifffile_imread + except Exception: + pass + try: + import ultralytics.data.loaders as _ul_loaders + + _ul_loaders.imread = tifffile_imread + except Exception: + pass + + # Patch trainer normalization: default divides by 255 regardless of input dtype. + from ultralytics.models.yolo.detect import train as detect_train + + _orig_preprocess_batch = detect_train.DetectionTrainer.preprocess_batch + + def preprocess_batch_16bit(self, batch: dict) -> dict: # type: ignore[override] + # Start from upstream behavior to keep device placement + multiscale identical, + # but replace the 255 division with dtype-aware scaling. + for k, v in batch.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda") + + img = batch.get("img") + if isinstance(img, torch.Tensor): + # Decide scaling denom based on dtype (avoid expensive reductions if possible). + if img.dtype == torch.uint8: + denom = 255.0 + elif img.dtype == torch.uint16: + denom = 65535.0 + elif img.dtype.is_floating_point: + # Assume already in 0-1 range if float. + denom = 1.0 + else: + # Generic integer fallback. + try: + denom = float(torch.iinfo(img.dtype).max) + except Exception: + denom = 255.0 + + batch["img"] = img.float() / denom + + # Multi-scale branch copied from upstream to avoid re-introducing `/255` scaling. + if getattr(self.args, "multi_scale", False): + import math + import random + + import torch.nn as nn + + imgs = batch["img"] + sz = ( + random.randrange( + int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride) + ) + // self.stride + * self.stride + ) + sf = sz / max(imgs.shape[2:]) + if sf != 1: + ns = [ + math.ceil(x * sf / self.stride) * self.stride + for x in imgs.shape[2:] + ] + imgs = nn.functional.interpolate( + imgs, size=ns, mode="bilinear", align_corners=False + ) + batch["img"] = imgs + + return batch + + detect_train.DetectionTrainer.preprocess_batch = preprocess_batch_16bit + + # Tag function to make it easier to detect patch state. + setattr( + detect_train.DetectionTrainer.preprocess_batch, "_ultralytics_16bit_patch", True + )