Compare commits
2 Commits
main ... monkey-pat

| Author | SHA1 | Date |
|---|---|---|
|  | 7d83e9b9b1 |  |
|  | e364d06217 |  |
@@ -1,57 +0,0 @@
database:
  path: data/detections.db
image_repository:
  base_path: ''
  allowed_extensions:
    - .jpg
    - .jpeg
    - .png
    - .tif
    - .tiff
    - .bmp
models:
  default_base_model: yolov8s-seg.pt
  models_directory: data/models
  base_model_choices:
    - yolov8s-seg.pt
    - yolo11s-seg.pt
training:
  default_epochs: 100
  default_batch_size: 16
  default_imgsz: 1024
  default_patience: 50
  default_lr0: 0.01
  two_stage:
    enabled: false
    stage1:
      epochs: 20
      lr0: 0.0005
      patience: 10
      freeze: 10
    stage2:
      epochs: 150
      lr0: 0.0003
      patience: 30
  last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml
  last_dataset_dir: /home/martin/code/object_detection/data/datasets
detection:
  default_confidence: 0.25
  default_iou: 0.45
  max_batch_size: 100
visualization:
  bbox_colors:
    organelle: '#FF6B6B'
    membrane_branch: '#4ECDC4'
    default: '#00FF00'
  bbox_thickness: 2
  font_size: 12
export:
  formats:
    - csv
    - json
    - excel
  default_format: csv
logging:
  level: INFO
  file: logs/app.log
  format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
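For orientation, a config with this shape is normally read once at startup and handed around as a nested dict. A minimal loading sketch, assuming PyYAML is installed and the file is saved as `config.yaml` (the actual filename is not shown in this diff):

from pathlib import Path

import yaml  # PyYAML


def load_config(path: str = "config.yaml") -> dict:
    """Read the application config into a plain nested dict."""
    with Path(path).open("r", encoding="utf-8") as f:
        return yaml.safe_load(f)


cfg = load_config()
# Keys mirror the YAML structure above, e.g.:
epochs = cfg["training"]["default_epochs"]                 # 100
stage1_lr = cfg["training"]["two_stage"]["stage1"]["lr0"]  # 0.0005
db_path = cfg["database"]["path"]                          # data/detections.db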
@@ -1303,6 +1303,14 @@ class TrainingTab(QWidget):
        sample_image = self._find_first_image(images_dir)
        if not sample_image:
            return False

        # Do not force an RGB cache for TIFF datasets.
        # We handle grayscale/16-bit TIFFs via runtime Ultralytics patches that:
        # - load TIFFs with `tifffile`
        # - replicate grayscale to 3 channels without quantization
        # - normalize uint16 correctly during training
        if sample_image.suffix.lower() in {".tif", ".tiff"}:
            return False
        try:
            img = Image(sample_image)
            return img.pil_image.mode.upper() != "RGB"
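The comment block above is the heart of this change: 16-bit TIFFs are no longer pushed through an 8-bit RGB cache. A quick illustrative comparison of the two read paths, assuming `tifffile` and OpenCV are installed and `sample.tif` is a hypothetical 16-bit grayscale image:

import cv2
import tifffile

path = "sample.tif"  # hypothetical 16-bit grayscale TIFF

arr_tiff = tifffile.imread(path)              # dtype preserved, e.g. uint16, shape (H, W)
arr_cv8 = cv2.imread(path, cv2.IMREAD_COLOR)  # forced to 8-bit BGR: the 16-bit range is quantized

print(arr_tiff.dtype)  # uint16 -- full dynamic range still available for normalization
print(arr_cv8.dtype)   # uint8  -- precision already lost before training sees the data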
@@ -1,9 +1,13 @@
"""
YOLO model wrapper for the microscopy object detection application.
Provides a clean interface to YOLOv8 for training, validation, and inference.
"""YOLO model wrapper for the microscopy object detection application.

Notes on 16-bit TIFF support:
- Ultralytics training defaults assume 8-bit images and normalize by dividing by 255.
- This project can patch Ultralytics at runtime to decode TIFFs via `tifffile` and
  normalize `uint16` correctly.

See [`apply_ultralytics_16bit_tiff_patches()`](src/utils/ultralytics_16bit_patch.py:1).
"""
from ultralytics import YOLO
from pathlib import Path
from typing import Optional, List, Dict, Callable, Any
import torch
@@ -11,6 +15,7 @@ import tempfile
import os
from src.utils.image import Image
from src.utils.logger import get_logger
from src.utils.ultralytics_16bit_patch import apply_ultralytics_16bit_tiff_patches


logger = get_logger(__name__)
@@ -31,6 +36,9 @@ class YOLOWrapper:
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"YOLOWrapper initialized with device: {self.device}")

        # Apply Ultralytics runtime patches early (before first import/instantiation of YOLO datasets/trainers).
        apply_ultralytics_16bit_tiff_patches()

    def load_model(self) -> bool:
        """
        Load YOLO model from path.
@@ -40,6 +48,9 @@ class YOLOWrapper:
        """
        try:
            logger.info(f"Loading YOLO model from {self.model_path}")
            # Import YOLO lazily to ensure runtime patches are applied first.
            from ultralytics import YOLO

            self.model = YOLO(self.model_path)
            self.model.to(self.device)
            logger.info("Model loaded successfully")
@@ -89,6 +100,16 @@ class YOLOWrapper:
            f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
        )

        # Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255.
        # Users can override by passing explicit kwargs.
        kwargs.setdefault("mosaic", 0.0)
        kwargs.setdefault("mixup", 0.0)
        kwargs.setdefault("cutmix", 0.0)
        kwargs.setdefault("copy_paste", 0.0)
        kwargs.setdefault("hsv_h", 0.0)
        kwargs.setdefault("hsv_s", 0.0)
        kwargs.setdefault("hsv_v", 0.0)

        # Train the model
        results = self.model.train(
            data=data_yaml,
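A hedged usage sketch for the wrapper changes above. The constructor and `train()` signatures are only partially visible in this diff, so the parameter names below are assumptions inferred from the logging line and the `self.model.train(...)` call:

# Assumed API; adjust names to the actual YOLOWrapper signature.
wrapper = YOLOWrapper("data/models/yolov8s-seg.pt")  # __init__ applies the 16-bit patches early
wrapper.load_model()                                 # imports YOLO lazily, after patching

results = wrapper.train(
    data_yaml="data/datasets/data.yaml",
    epochs=100,
    batch=16,
    imgsz=1024,
    mosaic=1.0,  # example override: explicit kwargs win over the 16-bit-safety defaults
)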
@@ -313,7 +313,8 @@ class Image:
        """String representation of the Image object."""
        return (
            f"Image(path='{self.path.name}', "
            f"shape=({self._width}x{self._height}x{self._channels}), "
            # Display as HxWxC to match the conventional NumPy shape semantics.
            f"shape=({self._height}x{self._width}x{self._channels}), "
            f"format={self._format}, "
            f"size={self.size_mb:.2f}MB)"
        )
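The corrected ordering matches what NumPy reports for decoded images; a quick illustrative check (hypothetical array, not project code):

import numpy as np

arr = np.zeros((480, 640, 3), dtype=np.uint8)  # an image 640 px wide and 480 px tall
print(arr.shape)  # (480, 640, 3) -> H x W x C, the ordering the repr now displays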
src/utils/ultralytics_16bit_patch.py (new file, 165 lines)
@@ -0,0 +1,165 @@
"""Ultralytics runtime patches for 16-bit TIFF training.

Goals:
- Use `tifffile` to decode `.tif/.tiff` reliably (OpenCV can silently drop bit-depth depending on codec).
- Preserve 16-bit data through the dataloader as `uint16` tensors.
- Fix Ultralytics trainer normalization (default divides by 255) to scale `uint16` correctly.
- Avoid uint8-forcing augmentations by recommending/setting hyp values (handled by caller).

This module is intended to be imported/called **before** instantiating/using YOLO.
"""

from __future__ import annotations

from typing import Optional


def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
    """Apply runtime monkey-patches to Ultralytics to better support 16-bit TIFFs.

    This function is safe to call multiple times.

    Args:
        force: If True, re-apply patches even if already applied.
    """
    # Import inside function to ensure patching occurs before YOLO model/dataset is created.
    import os

    import cv2
    import numpy as np
    import tifffile
    import torch

    from ultralytics.utils import patches as ul_patches

    already_patched = getattr(ul_patches.imread, "__name__", "") == "tifffile_imread"
    if already_patched and not force:
        return

    _original_imread = ul_patches.imread

    def tifffile_imread(
        filename: str, flags: int = cv2.IMREAD_COLOR
    ) -> Optional[np.ndarray]:
        """Replacement for [`ultralytics.utils.patches.imread()`](venv/lib/python3.12/site-packages/ultralytics/utils/patches.py:20).

        - For `.tif/.tiff`, uses `tifffile.imread()` and preserves dtype (e.g. uint16).
        - For other formats, falls back to Ultralytics' original implementation.
        - Always returns HWC (3 dims). For grayscale, returns (H, W, 1) or (H, W, 3) depending on requested flags.
        """
        ext = os.path.splitext(filename)[1].lower()
        if ext in (".tif", ".tiff"):
            arr = tifffile.imread(filename)

            # Normalize common shapes:
            # - (H, W) -> (H, W, 1)
            # - (C, H, W) -> (H, W, C) (heuristic)
            if arr is None:
                return None
            if (
                arr.ndim == 3
                and arr.shape[0] in (1, 3, 4)
                and arr.shape[0] < arr.shape[1]
            ):
                arr = np.transpose(arr, (1, 2, 0))
            if arr.ndim == 2:
                arr = arr[..., None]

            # Ultralytics expects BGR ordering when `channels=3`.
            # For grayscale data we replicate channels (no scaling, no quantization).
            if flags != cv2.IMREAD_GRAYSCALE:
                if arr.shape[2] == 1:
                    arr = np.repeat(arr, 3, axis=2)
                elif arr.shape[2] >= 3:
                    arr = arr[:, :, :3]

            # Ensure contiguous array for downstream OpenCV ops.
            return np.ascontiguousarray(arr)

        return _original_imread(filename, flags)

    # Patch the canonical reference.
    ul_patches.imread = tifffile_imread

    # Patch common module-level imports (some Ultralytics modules do `from ... import imread`).
    # Importing these modules is safe and helps ensure the patched function is used.
    try:
        import ultralytics.data.base as _ul_base

        _ul_base.imread = tifffile_imread
    except Exception:
        pass
    try:
        import ultralytics.data.loaders as _ul_loaders

        _ul_loaders.imread = tifffile_imread
    except Exception:
        pass

    # Patch trainer normalization: default divides by 255 regardless of input dtype.
    from ultralytics.models.yolo.detect import train as detect_train

    _orig_preprocess_batch = detect_train.DetectionTrainer.preprocess_batch

    def preprocess_batch_16bit(self, batch: dict) -> dict:  # type: ignore[override]
        # Start from upstream behavior to keep device placement + multiscale identical,
        # but replace the 255 division with dtype-aware scaling.
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")

        img = batch.get("img")
        if isinstance(img, torch.Tensor):
            # Decide scaling denom based on dtype (avoid expensive reductions if possible).
            if img.dtype == torch.uint8:
                denom = 255.0
            elif img.dtype == torch.uint16:
                denom = 65535.0
            elif img.dtype.is_floating_point:
                # Assume already in 0-1 range if float.
                denom = 1.0
            else:
                # Generic integer fallback.
                try:
                    denom = float(torch.iinfo(img.dtype).max)
                except Exception:
                    denom = 255.0

            batch["img"] = img.float() / denom

        # Multi-scale branch copied from upstream to avoid re-introducing `/255` scaling.
        if getattr(self.args, "multi_scale", False):
            import math
            import random

            import torch.nn as nn

            imgs = batch["img"]
            sz = (
                random.randrange(
                    int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride)
                )
                // self.stride
                * self.stride
            )
            sf = sz / max(imgs.shape[2:])
            if sf != 1:
                ns = [
                    math.ceil(x * sf / self.stride) * self.stride
                    for x in imgs.shape[2:]
                ]
                imgs = nn.functional.interpolate(
                    imgs, size=ns, mode="bilinear", align_corners=False
                )
            batch["img"] = imgs

        return batch

    detect_train.DetectionTrainer.preprocess_batch = preprocess_batch_16bit

    # Tag function to make it easier to detect patch state.
    setattr(
        detect_train.DetectionTrainer.preprocess_batch, "_ultralytics_16bit_patch", True
    )
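A short sketch of how this module is meant to be exercised, following its own docstring: patch first, then build and train the model. The TIFF path and training arguments below are placeholders:

from src.utils.ultralytics_16bit_patch import apply_ultralytics_16bit_tiff_patches

# 1. Patch before any YOLO dataset/trainer is constructed.
apply_ultralytics_16bit_tiff_patches()

# 2. Optional sanity check: the patched reader keeps the TIFF dtype.
from ultralytics.utils import patches as ul_patches
arr = ul_patches.imread("sample.tif")  # placeholder path
print(arr.dtype, arr.shape)            # e.g. uint16 (1024, 1024, 3)

# 3. Only now create and train the model; the patched preprocess_batch scales
#    uint16 batches by 65535 instead of 255.
from ultralytics import YOLO

model = YOLO("yolov8s-seg.pt")
model.train(
    data="data/datasets/data.yaml",
    epochs=100,
    imgsz=1024,
    mosaic=0.0,  # keep uint8-forcing augmentations off for 16-bit data, as the wrapper does
    hsv_h=0.0,
    hsv_s=0.0,
    hsv_v=0.0,
)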