Adding important file
This commit is contained in:
165
src/utils/ultralytics_16bit_patch.py
Normal file
165
src/utils/ultralytics_16bit_patch.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
"""Ultralytics runtime patches for 16-bit TIFF training.
|
||||||
|
|
||||||
|
Goals:
|
||||||
|
- Use `tifffile` to decode `.tif/.tiff` reliably (OpenCV can silently drop bit-depth depending on codec).
|
||||||
|
- Preserve 16-bit data through the dataloader as `uint16` tensors.
|
||||||
|
- Fix Ultralytics trainer normalization (default divides by 255) to scale `uint16` correctly.
|
||||||
|
- Avoid uint8-forcing augmentations by recommending/setting hyp values (handled by caller).
|
||||||
|
|
||||||
|
This module is intended to be imported/called **before** instantiating/using YOLO.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
|
||||||
|
"""Apply runtime monkey-patches to Ultralytics to better support 16-bit TIFFs.
|
||||||
|
|
||||||
|
This function is safe to call multiple times.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
force: If True, re-apply patches even if already applied.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Import inside function to ensure patching occurs before YOLO model/dataset is created.
|
||||||
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import tifffile
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ultralytics.utils import patches as ul_patches
|
||||||
|
|
||||||
|
already_patched = getattr(ul_patches.imread, "__name__", "") == "tifffile_imread"
|
||||||
|
if already_patched and not force:
|
||||||
|
return
|
||||||
|
|
||||||
|
_original_imread = ul_patches.imread
|
||||||
|
|
||||||
|
def tifffile_imread(
|
||||||
|
filename: str, flags: int = cv2.IMREAD_COLOR
|
||||||
|
) -> Optional[np.ndarray]:
|
||||||
|
"""Replacement for [`ultralytics.utils.patches.imread()`](venv/lib/python3.12/site-packages/ultralytics/utils/patches.py:20).
|
||||||
|
|
||||||
|
- For `.tif/.tiff`, uses `tifffile.imread()` and preserves dtype (e.g. uint16).
|
||||||
|
- For other formats, falls back to Ultralytics' original implementation.
|
||||||
|
- Always returns HWC (3 dims). For grayscale, returns (H, W, 1) or (H, W, 3) depending on requested flags.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ext = os.path.splitext(filename)[1].lower()
|
||||||
|
if ext in (".tif", ".tiff"):
|
||||||
|
arr = tifffile.imread(filename)
|
||||||
|
|
||||||
|
# Normalize common shapes:
|
||||||
|
# - (H, W) -> (H, W, 1)
|
||||||
|
# - (C, H, W) -> (H, W, C) (heuristic)
|
||||||
|
if arr is None:
|
||||||
|
return None
|
||||||
|
if (
|
||||||
|
arr.ndim == 3
|
||||||
|
and arr.shape[0] in (1, 3, 4)
|
||||||
|
and arr.shape[0] < arr.shape[1]
|
||||||
|
):
|
||||||
|
arr = np.transpose(arr, (1, 2, 0))
|
||||||
|
if arr.ndim == 2:
|
||||||
|
arr = arr[..., None]
|
||||||
|
|
||||||
|
# Ultralytics expects BGR ordering when `channels=3`.
|
||||||
|
# For grayscale data we replicate channels (no scaling, no quantization).
|
||||||
|
if flags != cv2.IMREAD_GRAYSCALE:
|
||||||
|
if arr.shape[2] == 1:
|
||||||
|
arr = np.repeat(arr, 3, axis=2)
|
||||||
|
elif arr.shape[2] >= 3:
|
||||||
|
arr = arr[:, :, :3]
|
||||||
|
|
||||||
|
# Ensure contiguous array for downstream OpenCV ops.
|
||||||
|
return np.ascontiguousarray(arr)
|
||||||
|
|
||||||
|
return _original_imread(filename, flags)
|
||||||
|
|
||||||
|
# Patch the canonical reference.
|
||||||
|
ul_patches.imread = tifffile_imread
|
||||||
|
|
||||||
|
# Patch common module-level imports (some Ultralytics modules do `from ... import imread`).
|
||||||
|
# Importing these modules is safe and helps ensure the patched function is used.
|
||||||
|
try:
|
||||||
|
import ultralytics.data.base as _ul_base
|
||||||
|
|
||||||
|
_ul_base.imread = tifffile_imread
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
import ultralytics.data.loaders as _ul_loaders
|
||||||
|
|
||||||
|
_ul_loaders.imread = tifffile_imread
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Patch trainer normalization: default divides by 255 regardless of input dtype.
|
||||||
|
from ultralytics.models.yolo.detect import train as detect_train
|
||||||
|
|
||||||
|
_orig_preprocess_batch = detect_train.DetectionTrainer.preprocess_batch
|
||||||
|
|
||||||
|
def preprocess_batch_16bit(self, batch: dict) -> dict: # type: ignore[override]
|
||||||
|
# Start from upstream behavior to keep device placement + multiscale identical,
|
||||||
|
# but replace the 255 division with dtype-aware scaling.
|
||||||
|
for k, v in batch.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
|
||||||
|
|
||||||
|
img = batch.get("img")
|
||||||
|
if isinstance(img, torch.Tensor):
|
||||||
|
# Decide scaling denom based on dtype (avoid expensive reductions if possible).
|
||||||
|
if img.dtype == torch.uint8:
|
||||||
|
denom = 255.0
|
||||||
|
elif img.dtype == torch.uint16:
|
||||||
|
denom = 65535.0
|
||||||
|
elif img.dtype.is_floating_point:
|
||||||
|
# Assume already in 0-1 range if float.
|
||||||
|
denom = 1.0
|
||||||
|
else:
|
||||||
|
# Generic integer fallback.
|
||||||
|
try:
|
||||||
|
denom = float(torch.iinfo(img.dtype).max)
|
||||||
|
except Exception:
|
||||||
|
denom = 255.0
|
||||||
|
|
||||||
|
batch["img"] = img.float() / denom
|
||||||
|
|
||||||
|
# Multi-scale branch copied from upstream to avoid re-introducing `/255` scaling.
|
||||||
|
if getattr(self.args, "multi_scale", False):
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
imgs = batch["img"]
|
||||||
|
sz = (
|
||||||
|
random.randrange(
|
||||||
|
int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride)
|
||||||
|
)
|
||||||
|
// self.stride
|
||||||
|
* self.stride
|
||||||
|
)
|
||||||
|
sf = sz / max(imgs.shape[2:])
|
||||||
|
if sf != 1:
|
||||||
|
ns = [
|
||||||
|
math.ceil(x * sf / self.stride) * self.stride
|
||||||
|
for x in imgs.shape[2:]
|
||||||
|
]
|
||||||
|
imgs = nn.functional.interpolate(
|
||||||
|
imgs, size=ns, mode="bilinear", align_corners=False
|
||||||
|
)
|
||||||
|
batch["img"] = imgs
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
detect_train.DetectionTrainer.preprocess_batch = preprocess_batch_16bit
|
||||||
|
|
||||||
|
# Tag function to make it easier to detect patch state.
|
||||||
|
setattr(
|
||||||
|
detect_train.DetectionTrainer.preprocess_batch, "_ultralytics_16bit_patch", True
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user