"""Ultralytics runtime patches for 16-bit TIFF training.
Goals:
- Use `tifffile` to decode `.tif/.tiff` reliably (OpenCV can silently drop bit-depth depending on codec).
- Preserve 16-bit data through the dataloader as `uint16` tensors.
- Fix Ultralytics trainer normalization (default divides by 255) to scale `uint16` correctly.
- Avoid uint8-forcing augmentations by recommending/setting hyp values (handled by caller).
This module is intended to be imported/called **before** instantiating/using YOLO.
"""
from __future__ import annotations
from typing import Optional
from src.utils.logger import get_logger
logger = get_logger(__name__)
def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
"""Apply runtime monkey-patches to Ultralytics to better support 16-bit TIFFs.
This function is safe to call multiple times.
Args:
force: If True, re-apply patches even if already applied.
"""
# Import inside function to ensure patching occurs before YOLO model/dataset is created.
import os
import cv2
import numpy as np
import torch
from src.utils.image import Image
from ultralytics.utils import patches as ul_patches
already_patched = getattr(ul_patches.imread, "__name__", "") == "tifffile_imread"
if already_patched and not force:
return
_original_imread = ul_patches.imread
def tifffile_imread(filename: str, flags: int = cv2.IMREAD_COLOR, pseudo_rgb: bool = True) -> Optional[np.ndarray]:
"""Replacement for [`ultralytics.utils.patches.imread()`](venv/lib/python3.12/site-packages/ultralytics/utils/patches.py:20).
- For `.tif/.tiff`, uses `tifffile.imread()` and preserves dtype (e.g. uint16).
- For other formats, falls back to Ultralytics' original implementation.
- Always returns HWC (3 dims). For grayscale, returns (H, W, 1) or (H, W, 3) depending on requested flags.
"""
ext = os.path.splitext(filename)[1].lower()
if ext in (".tif", ".tiff"):
            arr = Image(filename).get_qt_rgb()
            if arr is None:
                return None
            # Normalize common shapes:
            # - (H, W) -> (H, W, 1)
            # - (C, H, W) -> (H, W, C) (heuristic)
            if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[0] < arr.shape[1]:
                arr = np.transpose(arr, (1, 2, 0))
            if arr.ndim == 2:
                arr = arr[..., None]
            # Keep at most the first three channels (drop alpha from the Qt RGB(A) output).
            arr = arr[:, :, :3]
# Ensure contiguous array for downstream OpenCV ops.
            # Rescale to the full uint16 range (per-image max normalization).
            arr = arr.astype(np.float32)
            arr_max = float(arr.max())
            if arr_max > 0:  # guard against all-zero frames
                arr /= arr_max
            arr *= 2**16 - 1
            arr = arr.astype(np.uint16)
return np.ascontiguousarray(arr)
return _original_imread(filename, flags)
# Patch the canonical reference.
ul_patches.imread = tifffile_imread
# Patch common module-level imports (some Ultralytics modules do `from ... import imread`).
# Importing these modules is safe and helps ensure the patched function is used.
try:
import ultralytics.data.base as _ul_base
_ul_base.imread = tifffile_imread
except Exception:
pass
try:
import ultralytics.data.loaders as _ul_loaders
_ul_loaders.imread = tifffile_imread
except Exception:
pass
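    # A more general (hedged) sketch of the same idea: rebind `imread` in every
    # already-imported ultralytics submodule that grabbed the original via
    # `from ... import imread`. Kept as a comment so the explicit, auditable
    # per-module patches above remain the actual behavior.
    #
    #     import sys
    #     for _name, _mod in list(sys.modules.items()):
    #         if _name.startswith("ultralytics") and getattr(_mod, "imread", None) is _original_imread:
    #             _mod.imread = tifffile_imread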
# Patch trainer normalization: default divides by 255 regardless of input dtype.
from ultralytics.models.yolo.detect import train as detect_train
_orig_preprocess_batch = detect_train.DetectionTrainer.preprocess_batch
def preprocess_batch_16bit(self, batch: dict) -> dict: # type: ignore[override]
# Start from upstream behavior to keep device placement + multiscale identical,
# but replace the 255 division with dtype-aware scaling.
        logger.debug("Preprocessing batch with monkey-patched preprocess_batch")
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
img = batch.get("img")
if isinstance(img, torch.Tensor):
# Decide scaling denom based on dtype (avoid expensive reductions if possible).
if img.dtype == torch.uint8:
denom = 255.0
elif img.dtype == torch.uint16:
denom = 65535.0
elif img.dtype.is_floating_point:
# Assume already in 0-1 range if float.
denom = 1.0
else:
# Generic integer fallback.
try:
denom = float(torch.iinfo(img.dtype).max)
except Exception:
denom = 255.0
batch["img"] = img.float() / denom
# Multi-scale branch copied from upstream to avoid re-introducing `/255` scaling.
if getattr(self.args, "multi_scale", False):
import math
import random
import torch.nn as nn
imgs = batch["img"]
sz = (
random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
// self.stride
* self.stride
)
sf = sz / max(imgs.shape[2:])
if sf != 1:
ns = [math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]]
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
batch["img"] = imgs
return batch
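    # Note (hedged): `torch.uint16` only exists as a limited dtype in newer PyTorch
    # releases (roughly 2.3+); on older versions a uint16 batch cannot reach this
    # method, so the uint16 branch above simply never triggers there.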
detect_train.DetectionTrainer.preprocess_batch = preprocess_batch_16bit
# Tag function to make it easier to detect patch state.
setattr(detect_train.DetectionTrainer.preprocess_batch, "_ultralytics_16bit_patch", True)
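

# A quick (hedged) way for callers to confirm the patch state, relying only on the
# markers set above (the function name tag and the `_ultralytics_16bit_patch` attribute):
#
#     apply_ultralytics_16bit_tiff_patches()
#     from ultralytics.utils import patches as ul_patches
#     from ultralytics.models.yolo.detect import train as detect_train
#     assert ul_patches.imread.__name__ == "tifffile_imread"
#     assert getattr(detect_train.DetectionTrainer.preprocess_batch, "_ultralytics_16bit_patch", False)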