Update TIFF image patch: rescale loaded images to the full uint16 range and reformat long lines

This commit is contained in:
2026-01-02 12:44:06 +02:00
parent d25101de2d
commit e98d287b8a
3 changed files with 26 additions and 26 deletions

View File

@@ -45,9 +45,7 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
_original_imread = ul_patches.imread
def tifffile_imread(
filename: str, flags: int = cv2.IMREAD_COLOR, pseudo_rgb: bool = True
) -> Optional[np.ndarray]:
def tifffile_imread(filename: str, flags: int = cv2.IMREAD_COLOR, pseudo_rgb: bool = True) -> Optional[np.ndarray]:
"""Replacement for [`ultralytics.utils.patches.imread()`](venv/lib/python3.12/site-packages/ultralytics/utils/patches.py:20).
- For `.tif/.tiff`, uses `tifffile.imread()` and preserves dtype (e.g. uint16).
@@ -65,19 +63,17 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
# - (C, H, W) -> (H, W, C) (heuristic)
if arr is None:
return None
if (
arr.ndim == 3
and arr.shape[0] in (1, 3, 4)
and arr.shape[0] < arr.shape[1]
):
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[0] < arr.shape[1]:
arr = np.transpose(arr, (1, 2, 0))
if arr.ndim == 2:
arr = arr[..., None]
# Ensure contiguous array for downstream OpenCV ops.
# logger.info(f"Loading with monkey-patched imread: {filename}")
arr *= 2**8 - 1
arr = arr.astype(np.uint8)
arr = arr.astype(np.float32)
arr /= arr.max()
arr *= 2**16 - 1
arr = arr.astype(np.uint16)
return np.ascontiguousarray(arr)
# logger.info(f"Loading with original imread: {filename}")
@@ -142,21 +138,14 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
imgs = batch["img"]
sz = (
random.randrange(
int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride)
)
random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
// self.stride
* self.stride
)
sf = sz / max(imgs.shape[2:])
if sf != 1:
ns = [
math.ceil(x * sf / self.stride) * self.stride
for x in imgs.shape[2:]
]
imgs = nn.functional.interpolate(
imgs, size=ns, mode="bilinear", align_corners=False
)
ns = [math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]]
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
batch["img"] = imgs
return batch
@@ -164,6 +153,4 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
detect_train.DetectionTrainer.preprocess_batch = preprocess_batch_16bit
# Tag function to make it easier to detect patch state.
setattr(
detect_train.DetectionTrainer.preprocess_batch, "_ultralytics_16bit_patch", True
)
setattr(detect_train.DetectionTrainer.preprocess_batch, "_ultralytics_16bit_patch", True)