20 Commits

SHA1 Message Date
d998c65665 Updating image splitter 2026-01-12 13:28:00 +02:00
510eabfa94 Adding splitter method 2026-01-05 13:56:57 +02:00
395d263900 Update 2026-01-05 08:59:36 +02:00
e98d287b8a Updating tiff image patch 2026-01-02 12:44:06 +02:00
d25101de2d adding files 2026-01-02 12:40:44 +02:00
f88beef188 Another test 2025-12-19 13:50:49 +02:00
2fd9a2acf4 RGB 2025-12-19 13:31:24 +02:00
2bcd18cc75 Bug fix 2025-12-19 13:13:12 +02:00
5d25378c46 Testing with uint conversion 2025-12-19 13:10:36 +02:00
2b0b48921e Testing more grayscale 2025-12-19 12:02:11 +02:00
b0c05f0225 testing grayscale 2025-12-19 11:55:38 +02:00
97badaa390 Small update 2025-12-19 11:31:12 +02:00
8f8132ce61 Testing detect 2025-12-19 10:44:11 +02:00
6ae7481e25 Adding debug messages 2025-12-19 10:15:53 +02:00
061f8b3ca2 Fixing pseudo rgb 2025-12-19 09:56:43 +02:00
a8e5db3135 Small change 2025-12-18 13:03:12 +02:00
268ed5175e Applying pseudo channels for RGB 2025-12-18 12:52:13 +02:00
5e9d3b1dc4 Adding logger 2025-12-18 12:04:41 +02:00
7d83e9b9b1 Adding important file 2025-12-17 00:45:56 +02:00
e364d06217 Implementing uint16 reading with tifffile 2025-12-16 23:02:45 +02:00
12 changed files with 743 additions and 184 deletions

View File

@@ -1,57 +0,0 @@
database:
  path: data/detections.db
image_repository:
  base_path: ''
  allowed_extensions:
  - .jpg
  - .jpeg
  - .png
  - .tif
  - .tiff
  - .bmp
models:
  default_base_model: yolov8s-seg.pt
  models_directory: data/models
  base_model_choices:
  - yolov8s-seg.pt
  - yolo11s-seg.pt
training:
  default_epochs: 100
  default_batch_size: 16
  default_imgsz: 1024
  default_patience: 50
  default_lr0: 0.01
  two_stage:
    enabled: false
    stage1:
      epochs: 20
      lr0: 0.0005
      patience: 10
      freeze: 10
    stage2:
      epochs: 150
      lr0: 0.0003
      patience: 30
  last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml
  last_dataset_dir: /home/martin/code/object_detection/data/datasets
detection:
  default_confidence: 0.25
  default_iou: 0.45
  max_batch_size: 100
visualization:
  bbox_colors:
    organelle: '#FF6B6B'
    membrane_branch: '#4ECDC4'
    default: '#00FF00'
  bbox_thickness: 2
  font_size: 12
export:
  formats:
  - csv
  - json
  - excel
  default_format: csv
logging:
  level: INFO
  file: logs/app.log
  format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
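
Reviewer note: this deleted file is plain YAML, so a minimal sketch of reading it back and reaching nested keys looks like the following (assumes PyYAML; the file name is hypothetical since the diff does not show it):

import yaml

# Hypothetical path; the deleted file's name is not shown in this diff.
with open("config/app.yaml", "r") as f:
    cfg = yaml.safe_load(f)

print(cfg["training"]["two_stage"]["stage1"]["lr0"])  # 0.0005
print(cfg["models"]["base_model_choices"])  # ['yolov8s-seg.pt', 'yolo11s-seg.pt']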

View File

@@ -82,12 +82,12 @@ include-package-data = true
 "src.database" = ["*.sql"]

 [tool.black]
-line-length = 88
+line-length = 120
 target-version = ['py38', 'py39', 'py310', 'py311']
 include = '\.pyi?$'

 [tool.pylint.messages_control]
-max-line-length = 88
+max-line-length = 120

 [tool.mypy]
 python_version = "3.8"

View File

@@ -1303,6 +1303,14 @@ class TrainingTab(QWidget):
         sample_image = self._find_first_image(images_dir)
         if not sample_image:
             return False
+        # Do not force an RGB cache for TIFF datasets.
+        # We handle grayscale/16-bit TIFFs via runtime Ultralytics patches that:
+        # - load TIFFs with `tifffile`
+        # - replicate grayscale to 3 channels without quantization
+        # - normalize uint16 correctly during training
+        if sample_image.suffix.lower() in {".tif", ".tiff"}:
+            return False
+
         try:
             img = Image(sample_image)
             return img.pil_image.mode.upper() != "RGB"
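
Reviewer note on "without quantization" above: caching a uint16 grayscale TIFF as 8-bit RGB collapses 65536 intensity levels into 256, which is exactly what the early return avoids. An illustrative numpy-only sketch (not code from this repo):

import numpy as np

gray16 = np.array([[100, 130, 160]], dtype=np.uint16)  # nearby low intensities
cached8 = (gray16 / 65535.0 * 255.0).round().astype(np.uint8)
print(cached8)  # [[0 1 1]]: 100 and 130 become nearly indistinguishable

# Replicating to 3 channels instead keeps the full 16-bit range:
rgb16 = np.repeat(gray16[..., None], 3, axis=2)
print(rgb16.dtype, rgb16.shape)  # uint16 (1, 3, 3)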

View File

@@ -250,12 +250,10 @@ class AnnotationCanvasWidget(QWidget):
         # Get image data in a format compatible with Qt
         if self.current_image.channels in (3, 4):
             image_data = self.current_image.get_rgb()
-            height, width = image_data.shape[:2]
         else:
-            image_data = self.current_image.get_grayscale()
-            height, width = image_data.shape
-        image_data = np.ascontiguousarray(image_data)
+            image_data = self.current_image.get_qt_rgb()
+        height, width = image_data.shape[:2]
         bytes_per_line = image_data.strides[0]
         qimage = QImage(
@@ -263,7 +261,7 @@ class AnnotationCanvasWidget(QWidget):
             width,
             height,
             bytes_per_line,
-            self.current_image.qtimage_format,
+            QImage.Format_RGBX32FPx4,  # self.current_image.qtimage_format,
         ).copy()  # Copy so Qt owns the buffer even after numpy array goes out of scope
         self.original_pixmap = QPixmap.fromImage(qimage)
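
For context on Format_RGBX32FPx4 (a four-channel 32-bit-float format, available in Qt 6.2 and later), a self-contained sketch of handing such a buffer to QImage; the array contents here are invented:

import numpy as np
from PySide6.QtGui import QImage

h, w = 64, 96
buf = np.zeros((h, w, 4), dtype=np.float32)
buf[..., 0] = np.linspace(0.0, 1.0, w)  # red gradient
buf[..., 3] = 1.0  # X/alpha channel, fully opaque

buf = np.ascontiguousarray(buf)
qimage = QImage(buf.data, w, h, buf.strides[0], QImage.Format_RGBX32FPx4).copy()
# .copy() so Qt owns the memory after `buf` goes out of scope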

View File

@@ -1,9 +1,13 @@
-"""
-YOLO model wrapper for the microscopy object detection application.
-Provides a clean interface to YOLOv8 for training, validation, and inference.
+"""YOLO model wrapper for the microscopy object detection application.
+
+Notes on 16-bit TIFF support:
+- Ultralytics training defaults assume 8-bit images and normalize by dividing by 255.
+- This project can patch Ultralytics at runtime to decode TIFFs via `tifffile` and
+  normalize `uint16` correctly.
+See [`apply_ultralytics_16bit_tiff_patches()`](src/utils/ultralytics_16bit_patch.py:1).
 """
-from ultralytics import YOLO
 from pathlib import Path
 from typing import Optional, List, Dict, Callable, Any
 import torch
@@ -11,6 +15,7 @@ import tempfile
 import os
 from src.utils.image import Image
 from src.utils.logger import get_logger
+from src.utils.ultralytics_16bit_patch import apply_ultralytics_16bit_tiff_patches

 logger = get_logger(__name__)
@@ -31,6 +36,9 @@ class YOLOWrapper:
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f"YOLOWrapper initialized with device: {self.device}")

+        # Apply Ultralytics runtime patches early (before first import/instantiation of YOLO datasets/trainers).
+        apply_ultralytics_16bit_tiff_patches()
+
     def load_model(self) -> bool:
@@ -40,6 +48,9 @@
         """
         try:
             logger.info(f"Loading YOLO model from {self.model_path}")
+            # Import YOLO lazily to ensure runtime patches are applied first.
+            from ultralytics import YOLO
+
             self.model = YOLO(self.model_path)
             self.model.to(self.device)
             logger.info("Model loaded successfully")
@@ -89,6 +100,16 @@
             f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
         )

+        # Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255.
+        # Users can override by passing explicit kwargs.
+        kwargs.setdefault("mosaic", 0.0)
+        kwargs.setdefault("mixup", 0.0)
+        kwargs.setdefault("cutmix", 0.0)
+        kwargs.setdefault("copy_paste", 0.0)
+        kwargs.setdefault("hsv_h", 0.0)
+        kwargs.setdefault("hsv_s", 0.0)
+        kwargs.setdefault("hsv_v", 0.0)
+
         # Train the model
         results = self.model.train(
             data=data_yaml,
@@ -171,9 +192,11 @@
         prepared_source, cleanup_path = self._prepare_source(source)
         try:
-            logger.info(f"Running inference on {source}")
+            logger.info(
+                f"Running inference on {source} -> prepared_source {prepared_source}"
+            )
             results = self.model.predict(
-                source=prepared_source,
+                source=source,
                 conf=conf,
                 iou=iou,
                 save=save,
@@ -236,17 +259,11 @@
         if source_path.is_file():
             try:
                 img_obj = Image(source_path)
-                pil_img = img_obj.pil_image
-                if len(pil_img.getbands()) == 1:
-                    rgb_img = img_obj.convert_grayscale_to_rgb_preserve_range()
-                else:
-                    rgb_img = pil_img.convert("RGB")
                 suffix = source_path.suffix or ".png"
                 tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
                 tmp_path = tmp.name
                 tmp.close()
-                rgb_img.save(tmp_path)
+                img_obj.save(tmp_path)
                 cleanup_path = tmp_path
                 logger.info(
                     f"Converted image {source_path} to RGB for inference at {tmp_path}"

View File

@@ -54,7 +54,7 @@ class ConfigManager:
                 "models_directory": "data/models",
                 "base_model_choices": [
                     "yolov8s-seg.pt",
-                    "yolov11s-seg.pt",
+                    "yolo11s-seg.pt",
                 ],
             },
             "training": {
@@ -225,6 +225,4 @@ class ConfigManager:
     def get_allowed_extensions(self) -> list:
         """Get list of allowed image file extensions."""
-        return self.get(
-            "image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS
-        )
+        return self.get("image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS)
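
The dotted-path lookup used by get_allowed_extensions() could resolve as in this hypothetical sketch; the repository's actual ConfigManager.get() is not shown in this diff:

# Hypothetical stand-in for ConfigManager.get(key, default); not the repo's code.
def get(config: dict, dotted_key: str, default=None):
    node = config
    for part in dotted_key.split("."):
        if not isinstance(node, dict) or part not in node:
            return default
        node = node[part]
    return node

cfg = {"image_repository": {"allowed_extensions": [".tif", ".tiff"]}}
print(get(cfg, "image_repository.allowed_extensions", []))  # ['.tif', '.tiff']
print(get(cfg, "models.default_base_model", "yolov8s-seg.pt"))  # default fallback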

View File

@@ -6,16 +6,52 @@ import cv2
 import numpy as np
 from pathlib import Path
 from typing import Optional, Tuple, Union
-from PIL import Image as PILImage
 from src.utils.logger import get_logger
 from src.utils.file_utils import validate_file_path, is_image_file
 from PySide6.QtGui import QImage
+from tifffile import imread, imwrite

 logger = get_logger(__name__)

+
+def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
+    """
+    Convert a grayscale image to a pseudo-RGB image using a gamma correction.
+
+    Args:
+        arr: Input grayscale image as numpy array
+
+    Returns:
+        Pseudo-RGB image as numpy array
+    """
+    if arr.ndim != 2:
+        raise ValueError("Input array must be a grayscale image with shape (H, W)")
+    a1 = arr.copy().astype(np.float32)
+    a1 -= np.percentile(a1, 2)
+    a1[a1 < 0] = 0
+    p999 = np.percentile(a1, 99.9)
+    a1[a1 > p999] = p999
+    a1 /= a1.max()
+    if 0:
+        a2 = a1.copy()
+        a2 = a2**gamma
+        a2 /= a2.max()
+        a3 = a1.copy()
+        p9999 = np.percentile(a3, 99.99)
+        a3[a3 > p9999] = p9999
+        a3 /= a3.max()
+    return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
+    # return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
+    # return np.stack([a1, a2, a3], axis=0)
+
+
 class ImageLoadError(Exception):
     """Exception raised when an image cannot be loaded."""
@@ -54,7 +90,6 @@ class Image:
         """
         self.path = Path(image_path)
         self._data: Optional[np.ndarray] = None
-        self._pil_image: Optional[PILImage.Image] = None
        self._width: int = 0
         self._height: int = 0
         self._channels: int = 0
@@ -80,40 +115,39 @@
         if not is_image_file(str(self.path), self.SUPPORTED_EXTENSIONS):
             ext = self.path.suffix.lower()
             raise ImageLoadError(
-                f"Unsupported image format: {ext}. "
-                f"Supported formats: {', '.join(self.SUPPORTED_EXTENSIONS)}"
+                f"Unsupported image format: {ext}. " f"Supported formats: {', '.join(self.SUPPORTED_EXTENSIONS)}"
             )

         try:
-            # Load with OpenCV (returns BGR format)
-            self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
+            if self.path.suffix.lower() in [".tif", ".tiff"]:
+                self._data = imread(str(self.path))
+            else:
+                raise NotImplementedError("RGB is not implemented")
+                # Load with OpenCV (returns BGR format)
+                self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)

             if self._data is None:
                 raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")

             # Extract metadata
-            self._height, self._width = self._data.shape[:2]
-            self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
+            # print(self._data.shape)
+            if len(self._data.shape) == 2:
+                self._height, self._width = self._data.shape[:2]
+                self._channels = 1
+            else:
+                self._height, self._width = self._data.shape[1:]
+                self._channels = self._data.shape[0]
+            # self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
             self._format = self.path.suffix.lower().lstrip(".")
             self._size_bytes = self.path.stat().st_size
             self._dtype = self._data.dtype

-            # Load PIL version for compatibility (convert BGR to RGB)
-            if self._channels == 3:
-                rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
-                self._pil_image = PILImage.fromarray(rgb_data)
-            elif self._channels == 4:
-                rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
-                self._pil_image = PILImage.fromarray(rgba_data)
-            else:
-                # Grayscale
-                self._pil_image = PILImage.fromarray(self._data)
-
-            logger.info(
-                f"Successfully loaded image: {self.path.name} "
-                f"({self._width}x{self._height}, {self._channels} channels, "
-                f"{self._format.upper()})"
-            )
+            if 0:
+                logger.info(
+                    f"Successfully loaded image: {self.path.name} "
+                    f"({self._width}x{self._height}, {self._channels} channels, "
+                    f"{self._format.upper()})"
+                )

         except Exception as e:
             logger.error(f"Error loading image {self.path}: {e}")
@@ -131,18 +165,6 @@
             raise ImageLoadError("Image data not available")
         return self._data

-    @property
-    def pil_image(self) -> PILImage.Image:
-        """
-        Get image data as PIL Image (RGB or grayscale).
-
-        Returns:
-            PIL Image object
-        """
-        if self._pil_image is None:
-            raise ImageLoadError("PIL image not available")
-        return self._pil_image
-
     @property
     def width(self) -> int:
         """Get image width in pixels."""
@@ -187,6 +209,7 @@
     @property
     def dtype(self) -> np.dtype:
         """Get the data type of the image array."""
+
         if self._dtype is None:
             raise ImageLoadError("Image dtype not available")
         return self._dtype
@@ -206,8 +229,10 @@
         elif self._channels == 1:
             if self._dtype == np.uint16:
                 return QImage.Format_Grayscale16
-            else:
+            elif self._dtype == np.uint8:
                 return QImage.Format_Grayscale8
+            elif self._dtype == np.float32:
+                return QImage.Format_BGR30
         else:
             raise ImageLoadError(f"Unsupported number of channels: {self._channels}")
@@ -218,6 +243,12 @@
         Returns:
             Image data in RGB format as numpy array
         """
+        if self.channels == 1:
+            img = get_pseudo_rgb(self.data)
+            self._dtype = img.dtype
+            return img
+        raise NotImplementedError
+
         if self._channels == 3:
             return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
         elif self._channels == 4:
@@ -225,6 +256,18 @@
         else:
             return self._data

+    def get_qt_rgb(self) -> np.ascontiguousarray:
+        # we keep data as (C, H, W)
+        _img = self.get_rgb()
+        img = np.zeros((self.height, self.width, 4), dtype=np.float32)
+        img[..., 0] = _img[0]  # R gradient
+        img[..., 1] = _img[1]  # G gradient
+        img[..., 2] = _img[2]  # B constant
+        img[..., 3] = 1.0  # A = 1.0 (opaque)
+        return np.ascontiguousarray(img)
+
     def get_grayscale(self) -> np.ndarray:
         """
         Get image as grayscale numpy array.
@@ -277,43 +320,26 @@
         """
         return self._channels >= 3

-    def convert_grayscale_to_rgb_preserve_range(
-        self,
-    ) -> PILImage.Image:
-        """Convert a single-channel PIL image to RGB while preserving dynamic range.
-
-        Returns:
-            PIL Image in RGB mode with intensities normalized to 0-255.
-        """
-        if self._channels == 3:
-            return self.pil_image
-
-        grayscale = self.data
-        if grayscale.ndim == 3:
-            grayscale = grayscale[:, :, 0]
-        original_dtype = grayscale.dtype
-        grayscale = grayscale.astype(np.float32)
-        if grayscale.size == 0:
-            return PILImage.new("RGB", self.shape, color=(0, 0, 0))
-        if np.issubdtype(original_dtype, np.integer):
-            denom = float(max(np.iinfo(original_dtype).max, 1))
-        else:
-            max_val = float(grayscale.max())
-            denom = max(max_val, 1.0)
-        grayscale = np.clip(grayscale / denom, 0.0, 1.0)
-        grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8)
-        rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
-        return PILImage.fromarray(rgb_arr, mode="RGB")
+    def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None:
+        if self.channels == 1:
+            if pseudo_rgb:
+                img = get_pseudo_rgb(self.data)
+                print("Image.save", img.shape)
+            else:
+                img = np.repeat(self.data, 3, axis=2)
+        else:
+            raise NotImplementedError("Only grayscale images are supported for now.")
+        imwrite(path, data=img)

     def __repr__(self) -> str:
         """String representation of the Image object."""
         return (
             f"Image(path='{self.path.name}', "
-            f"shape=({self._width}x{self._height}x{self._channels}), "
+            # Display as HxWxC to match the conventional NumPy shape semantics.
+            f"shape=({self._height}x{self._width}x{self._channels}), "
             f"format={self._format}, "
             f"size={self.size_mb:.2f}MB)"
         )
@@ -321,3 +347,15 @@
     def __str__(self) -> str:
         """String representation of the Image object."""
         return self.__repr__()
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--path", type=str, required=True)
+    args = parser.parse_args()
+
+    img = Image(args.path)
+    img.save(args.path + "test.tif")
+    print(img)
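
Worth noting for callers: get_pseudo_rgb() returns channel-first (3, H, W) float data with the signal in the red channel only. A usage sketch with synthetic input:

import numpy as np
from src.utils.image import get_pseudo_rgb

arr = np.random.randint(0, 65535, size=(128, 128)).astype(np.uint16)
rgb = get_pseudo_rgb(arr)
print(rgb.shape)  # (3, 128, 128): channel-first, red carries the signal
print(rgb.dtype)  # float64: np.zeros() defaults to float64 and promotes the stack
print(rgb[1].max(), rgb[2].max())  # 0.0 0.0: green and blue are all zeros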

View File

@@ -18,7 +18,13 @@ class UT:
         self.rois = None
         if no_labels:
             self.rois = ImagejRoi.fromfile(self.roifile_fn)
-            self.stem = self.roifile_fn.stem.split("Roi-")[1]
+            print(self.roifile_fn.stem)
+            print(self.roifile_fn.parent.parts[-1])
+            if "Roi-" in self.roifile_fn.stem:
+                self.stem = self.roifile_fn.stem.split("Roi-")[1]
+            else:
+                self.stem = self.roifile_fn.parent.parts[-1]
         else:
             self.roifile_fn = roifile_fn / roifile_fn.parts[-1]
             self.stem = self.roifile_fn.stem
@@ -95,9 +101,7 @@
         for i, roi in enumerate(self.rois):
             rc = roi.subpixel_coordinates
             if rc is None:
-                print(
-                    f"No coordinates: {self.roifile_fn}, element {i}, out of {len(self.rois)}"
-                )
+                print(f"No coordinates: {self.roifile_fn}, element {i}, out of {len(self.rois)}")
                 continue
             xmn, ymn = rc.min(axis=0)
             xmx, ymx = rc.max(axis=0)
@@ -143,6 +147,9 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()

+    # print(args)
+    # aa
+
     for path in args.input:
         print("Path:", path)
         if not args.no_labels:
@@ -152,6 +159,7 @@
         else:
             for rfn in Path(path).glob("*.zip"):
+                # if Path(path).suffix == ".zip":
                 print("Roi FN:", rfn)
                 ut = UT(rfn, args.no_labels)
                 ut.export_rois(args.output, class_index=0)
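
For orientation, a hedged sketch of the roifile calls this script builds on: reading a ROI archive and emitting one normalized YOLO polygon line (file name and image size are assumptions):

import numpy as np
from roifile import ImagejRoi

rois = ImagejRoi.fromfile("RoiSet.zip")  # hypothetical archive; .zip yields a list
img_w, img_h = 2048, 2048  # assumed image dimensions

rc = rois[0].subpixel_coordinates  # (N, 2) array of (x, y), or None
if rc is not None:
    xmn, ymn = rc.min(axis=0)
    xmx, ymx = rc.max(axis=0)
    xc, yc = (xmn + xmx) / 2 / img_w, (ymn + ymx) / 2 / img_h
    bw, bh = (xmx - xmn) / img_w, (ymx - ymn) / img_h
    poly = " ".join(f"{x / img_w:.6f} {y / img_h:.6f}" for x, y in rc)
    print(f"0 {xc:.6f} {yc:.6f} {bw:.6f} {bh:.6f} {poly}")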

src/utils/image_splitter.py (new file, 353 lines)
View File

@@ -0,0 +1,353 @@
import numpy as np
from pathlib import Path
from tifffile import imread, imwrite
from shapely.geometry import LineString
from copy import deepcopy
from scipy.ndimage import zoom

# debug
from src.utils.image import Image
from show_yolo_seg import draw_annotations
import pylab as plt
import cv2


class Label:
    def __init__(self, yolo_annotation: str):
        class_id, bbox, polygon = self.parse_yolo_annotation(yolo_annotation)
        self.class_id = class_id
        self.bbox = bbox
        self.polygon = polygon

    def parse_yolo_annotation(self, yolo_annotation: str):
        class_id, *coords = yolo_annotation.split()
        class_id = int(class_id)
        bbox = np.array(coords[:4], dtype=np.float32)
        polygon = np.array(coords[4:], dtype=np.float32).reshape(-1, 2) if len(coords) > 4 else None
        if not any(np.isclose(polygon[0], polygon[-1])):
            polygon = np.vstack([polygon, polygon[0]])
        return class_id, bbox, polygon

    def offset_label(
        self,
        img_w,
        img_h,
        distance: float = 1.0,
        cap_style: int = 2,
        join_style: int = 2,
    ):
        if self.polygon is None:
            self.bbox = np.array(
                [
                    self.bbox[0] - distance if self.bbox[0] - distance > 0 else 0,
                    self.bbox[1] - distance if self.bbox[1] - distance > 0 else 0,
                    self.bbox[2] + distance if self.bbox[2] + distance < 1 else 1,
                    self.bbox[3] + distance if self.bbox[3] + distance < 1 else 1,
                ],
                dtype=np.float32,
            )
            return self.bbox

        def coords_are_normalized(coords):
            # If every coordinate is between 0 and 1 (inclusive-ish), assume normalized
            print(coords)
            # if not coords:
            #     return False
            return all(max(coords.flatten)) <= 1.001

        def poly_to_pts(coords, img_w, img_h):
            # coords: [x1 y1 x2 y2 ...] either normalized or absolute
            # if coords_are_normalized(coords):
            coords = [coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))]
            pts = np.array(coords, dtype=np.int32).reshape(-1, 2)
            return pts

        pts = poly_to_pts(self.polygon, img_w, img_h)
        line = LineString(pts)
        # Buffer distance in pixels
        buffered = line.buffer(distance=distance, cap_style=cap_style, join_style=join_style)
        self.polygon = np.array(buffered.exterior.coords, dtype=np.float32) / (img_w, img_h)
        xmn, ymn = self.polygon.min(axis=0)
        xmx, ymx = self.polygon.max(axis=0)
        xc = (xmn + xmx) / 2
        yc = (ymn + ymx) / 2
        bw = xmx - xmn
        bh = ymx - ymn
        self.bbox = np.array([xc, yc, bw, bh], dtype=np.float32)
        return self.bbox, self.polygon

    def translate(self, x, y, scale_x, scale_y):
        self.bbox[0] -= x
        self.bbox[0] *= scale_x
        self.bbox[1] -= y
        self.bbox[1] *= scale_y
        self.bbox[2] *= scale_x
        self.bbox[3] *= scale_y
        if self.polygon is not None:
            self.polygon[:, 0] -= x
            self.polygon[:, 0] *= scale_x
            self.polygon[:, 1] -= y
            self.polygon[:, 1] *= scale_y

    def in_range(self, hrange, wrange):
        xc, yc, h, w = self.bbox
        x1 = xc - w / 2
        y1 = yc - h / 2
        x2 = xc + w / 2
        y2 = yc + h / 2
        truth_val = (
            xc >= wrange[0]
            and x1 <= wrange[1]
            and x2 >= wrange[0]
            and x2 <= wrange[1]
            and y1 >= hrange[0]
            and y1 <= hrange[1]
            and y2 >= hrange[0]
            and y2 <= hrange[1]
        )
        print(x1, x2, wrange, y1, y2, hrange, truth_val)
        return truth_val

    def to_string(self, bbox: list = None, polygon: list = None):
        if bbox is None:
            bbox = self.bbox
        if polygon is None:
            polygon = self.polygon
        coords = " ".join([f"{x:.6f}" for x in self.bbox])
        if self.polygon is not None:
            coords += " " + " ".join([f"{x:.6f} {y:.6f}" for x, y in self.polygon])
        return f"{self.class_id} {coords}"

    def __str__(self):
        return f"Class: {self.class_id}, BBox: {self.bbox}, Polygon: {self.polygon}"


class YoloLabelReader:
    def __init__(self, label_path: Path):
        self.label_path = label_path
        self.labels = self._read_labels()

    def _read_labels(self):
        with open(self.label_path, "r") as f:
            labels = [Label(line) for line in f.readlines()]
        return labels

    def get_labels(self, hrange, wrange):
        """hrange and wrange are tuples of (start, end) normalized to [0, 1]"""
        labels = []
        # print(hrange, wrange)
        for lbl in self.labels:
            # print(lbl)
            if lbl.in_range(hrange, wrange):
                labels.append(lbl)
        return labels if len(labels) > 0 else None

    def __get_item__(self, index):
        return self.labels[index]

    def __len__(self):
        return len(self.labels)

    def __iter__(self):
        return iter(self.labels)


class ImageSplitter:
    def __init__(self, image_path: Path, label_path: Path):
        self.image = imread(image_path)
        self.image_path = image_path
        self.label_path = label_path
        if not label_path.exists():
            print(f"Label file {label_path} not found")
            self.labels = None
        else:
            self.labels = YoloLabelReader(label_path)

    def split_into_tiles(self, patch_size: tuple = (2, 2)):
        """Split image into patches of size patch_size"""
        hstep, wstep = (
            self.image.shape[0] // patch_size[0],
            self.image.shape[1] // patch_size[1],
        )
        h, w = self.image.shape[:2]
        for i in range(patch_size[0]):
            for j in range(patch_size[1]):
                tile_reference = f"i{i}j{j}"
                hrange = (i * hstep / h, (i + 1) * hstep / h)
                wrange = (j * wstep / w, (j + 1) * wstep / w)
                tile = self.image[i * hstep : (i + 1) * hstep, j * wstep : (j + 1) * wstep]
                labels = None
                if self.labels is not None:
                    labels = deepcopy(self.labels.get_labels(hrange, wrange))
                    print(id(labels))
                    if labels is not None:
                        print(hrange[0], wrange[0])
                        for l in labels:
                            print(l.bbox)
                        [l.translate(wrange[0], hrange[0], 2, 2) for l in labels]
                        print("translated")
                        for l in labels:
                            print(l.bbox)
                # print(labels)
                yield tile_reference, tile, labels

    def split_respective_to_label(self, padding: int = 67):
        if self.labels is None:
            raise ValueError("No labels found. Only images having labels can be split.")
        for i, label in enumerate(self.labels):
            tile_reference = f"_lbl-{i+1:02d}"
            # print(label.bbox)
            xc_norm, yc_norm, h_norm, w_norm = label.bbox  # normalized coords
            xc, yc, h, w = [
                int(np.round(f))
                for f in [
                    xc_norm * self.image.shape[1],
                    yc_norm * self.image.shape[0],
                    h_norm * self.image.shape[0],
                    w_norm * self.image.shape[1],
                ]
            ]  # image coords
            # print("img coords:", xc, yc, h, w)
            pad_xneg = padding + 1  # int(w / 2) + padding
            pad_xpos = padding  # int(w / 2) + padding
            pad_yneg = padding + 1  # int(h / 2) + padding
            pad_ypos = padding  # int(h / 2) + padding
            if xc - pad_xneg < 0:
                pad_xneg = xc
            if pad_xpos + xc > self.image.shape[1]:
                pad_xpos = self.image.shape[1] - xc
            if yc - pad_yneg < 0:
                pad_yneg = yc
            if pad_ypos + yc > self.image.shape[0]:
                pad_ypos = self.image.shape[0] - yc
            # print("pads:", pad_xneg, pad_xpos, pad_yneg, pad_ypos)
            tile = self.image[
                yc - pad_yneg : yc + pad_ypos,
                xc - pad_xneg : xc + pad_xpos,
            ]
            ny, nx = tile.shape
            x_offset = pad_xneg
            y_offset = pad_yneg
            # print("tile shape:", tile.shape)
            yolo_annotation = f"{label.class_id} {x_offset/nx} {y_offset/ny} {h_norm} {w_norm} "
            print(yolo_annotation)
            yolo_annotation += " ".join(
                [
                    f"{(x*self.image.shape[1]-(xc - x_offset))/nx:.6f} {(y*self.image.shape[0]-(yc-y_offset))/ny:.6f}"
                    for x, y in label.polygon
                ]
            )
            new_label = Label(yolo_annotation=yolo_annotation)
            yield tile_reference, tile, [new_label]


def main(args):
    if args.output:
        args.output.mkdir(exist_ok=True, parents=True)
        (args.output / "images").mkdir(exist_ok=True)
        (args.output / "images-zoomed").mkdir(exist_ok=True)
        (args.output / "labels").mkdir(exist_ok=True)
    for image_path in (args.input / "images").glob("*.tif"):
        data = ImageSplitter(
            image_path=image_path,
            label_path=(args.input / "labels" / image_path.stem).with_suffix(".txt"),
        )
        if args.split_around_label:
            data = data.split_respective_to_label(padding=args.padding)
        else:
            data = data.split_into_tiles(patch_size=args.patch_size)
        for tile_reference, tile, labels in data:
            print()
            print(tile_reference, tile.shape, labels)  # len(labels) if labels else None)
            # { debug
            debug = False
            if debug:
                plt.figure(figsize=(10, 10 * tile.shape[0] / tile.shape[1]))
                if labels is None:
                    plt.imshow(tile, cmap="gray")
                    plt.axis("off")
                    plt.title(f"{image_path.name} ({tile_reference})")
                    plt.show()
                    continue
                print(labels[0].bbox)
                # Draw annotations
                out = draw_annotations(
                    cv2.cvtColor((tile / tile.max() * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR),
                    [l.to_string() for l in labels],
                    alpha=0.1,
                )
                # Convert BGR -> RGB for matplotlib display
                out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
                plt.imshow(out_rgb)
                plt.axis("off")
                plt.title(f"{image_path.name} ({tile_reference})")
                plt.show()
            # } debug
            if args.output:
                imwrite(args.output / "images" / f"{image_path.stem}_{tile_reference}.tif", tile)
                scale = 5
                tile_zoomed = zoom(tile, zoom=scale)
                imwrite(args.output / "images-zoomed" / f"{image_path.stem}_{tile_reference}.tif", tile_zoomed)
                if labels is not None:
                    with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f:
                        for label in labels:
                            label.offset_label(tile.shape[1], tile.shape[0])
                            f.write(label.to_string() + "\n")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", type=Path)
    parser.add_argument("-o", "--output", type=Path)
    parser.add_argument(
        "-p",
        "--patch-size",
        nargs=2,
        type=int,
        default=[2, 2],
        help="Number of patches along height and width, rows and columns, respectively",
    )
    parser.add_argument(
        "-sal",
        "--split-around-label",
        action="store_true",
        help="If enabled, the image will be split around the label and for each label, a separate image will be created.",
    )
    parser.add_argument(
        "--padding",
        type=int,
        default=67,
        help="Padding around the label when splitting around the label.",
    )
    args = parser.parse_args()
    main(args)
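
A programmatic usage sketch for the splitter, assuming the images/ and labels/ sibling layout that main() expects (the paths are invented):

from pathlib import Path
from src.utils.image_splitter import ImageSplitter

image_path = Path("data/datasets/images/sample.tif")  # hypothetical
label_path = Path("data/datasets/labels/sample.txt")  # hypothetical

splitter = ImageSplitter(image_path=image_path, label_path=label_path)
for tile_reference, tile, labels in splitter.split_into_tiles(patch_size=(2, 2)):
    n = len(labels) if labels else 0
    print(tile_reference, tile.shape, f"{n} label(s)")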

src/utils/show_yolo_seg.py (symbolic link)
View File

@@ -0,0 +1 @@
../../tests/show_yolo_seg.py

View File

@@ -0,0 +1,156 @@
"""Ultralytics runtime patches for 16-bit TIFF training.
Goals:
- Use `tifffile` to decode `.tif/.tiff` reliably (OpenCV can silently drop bit-depth depending on codec).
- Preserve 16-bit data through the dataloader as `uint16` tensors.
- Fix Ultralytics trainer normalization (default divides by 255) to scale `uint16` correctly.
- Avoid uint8-forcing augmentations by recommending/setting hyp values (handled by caller).
This module is intended to be imported/called **before** instantiating/using YOLO.
"""
from __future__ import annotations
from typing import Optional
from src.utils.logger import get_logger
logger = get_logger(__name__)
def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
"""Apply runtime monkey-patches to Ultralytics to better support 16-bit TIFFs.
This function is safe to call multiple times.
Args:
force: If True, re-apply patches even if already applied.
"""
# Import inside function to ensure patching occurs before YOLO model/dataset is created.
import os
import cv2
import numpy as np
# import tifffile
import torch
from src.utils.image import Image
from ultralytics.utils import patches as ul_patches
already_patched = getattr(ul_patches.imread, "__name__", "") == "tifffile_imread"
if already_patched and not force:
return
_original_imread = ul_patches.imread
def tifffile_imread(filename: str, flags: int = cv2.IMREAD_COLOR, pseudo_rgb: bool = True) -> Optional[np.ndarray]:
"""Replacement for [`ultralytics.utils.patches.imread()`](venv/lib/python3.12/site-packages/ultralytics/utils/patches.py:20).
- For `.tif/.tiff`, uses `tifffile.imread()` and preserves dtype (e.g. uint16).
- For other formats, falls back to Ultralytics' original implementation.
- Always returns HWC (3 dims). For grayscale, returns (H, W, 1) or (H, W, 3) depending on requested flags.
"""
# print("here")
# return _original_imread(filename, flags)
ext = os.path.splitext(filename)[1].lower()
if ext in (".tif", ".tiff"):
arr = Image(filename).get_qt_rgb()[:, :, :3]
# Normalize common shapes:
# - (H, W) -> (H, W, 1)
# - (C, H, W) -> (H, W, C) (heuristic)
if arr is None:
return None
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[0] < arr.shape[1]:
arr = np.transpose(arr, (1, 2, 0))
if arr.ndim == 2:
arr = arr[..., None]
# Ensure contiguous array for downstream OpenCV ops.
# logger.info(f"Loading with monkey-patched imread: {filename}")
arr = arr.astype(np.float32)
arr /= arr.max()
arr *= 2**16 - 1
arr = arr.astype(np.uint16)
return np.ascontiguousarray(arr)
# logger.info(f"Loading with original imread: {filename}")
return _original_imread(filename, flags)
# Patch the canonical reference.
ul_patches.imread = tifffile_imread
# Patch common module-level imports (some Ultralytics modules do `from ... import imread`).
# Importing these modules is safe and helps ensure the patched function is used.
try:
import ultralytics.data.base as _ul_base
_ul_base.imread = tifffile_imread
except Exception:
pass
try:
import ultralytics.data.loaders as _ul_loaders
_ul_loaders.imread = tifffile_imread
except Exception:
pass
# Patch trainer normalization: default divides by 255 regardless of input dtype.
from ultralytics.models.yolo.detect import train as detect_train
_orig_preprocess_batch = detect_train.DetectionTrainer.preprocess_batch
def preprocess_batch_16bit(self, batch: dict) -> dict: # type: ignore[override]
# Start from upstream behavior to keep device placement + multiscale identical,
# but replace the 255 division with dtype-aware scaling.
logger.info(f"Preprocessing batch with monkey-patched preprocess_batch")
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
img = batch.get("img")
if isinstance(img, torch.Tensor):
# Decide scaling denom based on dtype (avoid expensive reductions if possible).
if img.dtype == torch.uint8:
denom = 255.0
elif img.dtype == torch.uint16:
denom = 65535.0
elif img.dtype.is_floating_point:
# Assume already in 0-1 range if float.
denom = 1.0
else:
# Generic integer fallback.
try:
denom = float(torch.iinfo(img.dtype).max)
except Exception:
denom = 255.0
batch["img"] = img.float() / denom
# Multi-scale branch copied from upstream to avoid re-introducing `/255` scaling.
if getattr(self.args, "multi_scale", False):
import math
import random
import torch.nn as nn
imgs = batch["img"]
sz = (
random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
// self.stride
* self.stride
)
sf = sz / max(imgs.shape[2:])
if sf != 1:
ns = [math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]]
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
batch["img"] = imgs
return batch
detect_train.DetectionTrainer.preprocess_batch = preprocess_batch_16bit
# Tag function to make it easier to detect patch state.
setattr(detect_train.DetectionTrainer.preprocess_batch, "_ultralytics_16bit_patch", True)
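
Because the patches tag what they replace, a quick check that they took effect could look like this (attribute names taken from the module above):

from src.utils.ultralytics_16bit_patch import apply_ultralytics_16bit_tiff_patches

apply_ultralytics_16bit_tiff_patches()

from ultralytics.utils import patches as ul_patches
from ultralytics.models.yolo.detect import train as detect_train

assert ul_patches.imread.__name__ == "tifffile_imread"
assert getattr(detect_train.DetectionTrainer.preprocess_batch, "_ultralytics_16bit_patch", False)
print("16-bit TIFF patches active")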

View File

@@ -17,6 +17,9 @@ import matplotlib.pyplot as plt
 import argparse
 from pathlib import Path
 import random
+from shapely.geometry import LineString
+from src.utils.image import Image


 def parse_label_line(line):
@@ -53,23 +56,30 @@ def yolo_bbox_to_xyxy(coords, img_w, img_h):
 def poly_to_pts(coords, img_w, img_h):
     # coords: [x1 y1 x2 y2 ...] either normalized or absolute
     if coords_are_normalized(coords[4:]):
-        coords = [
-            coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))
-        ]
+        coords = [coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))]
     pts = np.array(coords, dtype=np.int32).reshape(-1, 2)
     return pts


 def random_color_for_class(cls):
     random.seed(cls)  # deterministic per class
-    return tuple(int(x) for x in np.array([random.randint(0, 255) for _ in range(3)]))
+    return (
+        0,
+        0,
+        255,
+    )  # tuple(int(x) for x in np.array([random.randint(0, 255) for _ in range(3)]))


 def draw_annotations(img, labels, alpha=0.4, draw_bbox_for_poly=True):
     # img: BGR numpy array
     overlay = img.copy()
     h, w = img.shape[:2]
-    for cls, coords in labels:
+    for line in labels:
+        if isinstance(line, str):
+            cls, coords = parse_label_line(line)
+        if isinstance(line, tuple):
+            cls, coords = line
         if not coords:
             continue
         # polygon case (>=6 coordinates)
@@ -77,25 +87,33 @@ def draw_annotations(img, labels, alpha=0.4, draw_bbox_for_poly=True):
             color = random_color_for_class(cls)
             x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords[:4], w, h)
-            cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
+            print(x1, y1, x2, y2)
+            cv2.rectangle(img, (x1, y1), (x2, y2), color, 1)
         pts = poly_to_pts(coords[4:], w, h)
+        # line = LineString(pts)
+        # # Buffer distance in pixels
+        # buffered = line.buffer(3, cap_style=2, join_style=2)
+        # coords = np.array(buffered.exterior.coords, dtype=np.int32)
+        # cv2.fillPoly(overlay, [coords], color=(255, 255, 255))
         # fill on overlay
         cv2.fillPoly(overlay, [pts], color)
         # outline on base image
-        cv2.polylines(img, [pts], isClosed=True, color=color, thickness=2)
+        cv2.polylines(img, [pts], isClosed=True, color=color, thickness=1)
         # put class text at first point
         x, y = int(pts[0, 0]), int(pts[0, 1]) - 6
-        cv2.putText(
-            img,
-            str(cls),
-            (x, max(6, y)),
-            cv2.FONT_HERSHEY_SIMPLEX,
-            0.6,
-            (255, 255, 255),
-            2,
-            cv2.LINE_AA,
-        )
+        if 0:
+            cv2.putText(
+                img,
+                str(cls),
+                (x, max(6, y)),
+                cv2.FONT_HERSHEY_SIMPLEX,
+                0.6,
+                (255, 255, 255),
+                2,
+                cv2.LINE_AA,
+            )

         # YOLO bbox case (4 coords)
         elif len(coords) == 4:
@@ -135,21 +153,21 @@
 def main():
-    parser = argparse.ArgumentParser(
-        description="Show YOLO segmentation / polygon annotations"
-    )
+    parser = argparse.ArgumentParser(description="Show YOLO segmentation / polygon annotations")
     parser.add_argument("image", type=str, help="Path to image file")
-    parser.add_argument("labels", type=str, help="Path to YOLO label file (polygons)")
-    parser.add_argument(
-        "--alpha", type=float, default=0.4, help="Polygon fill alpha (0..1)"
-    )
-    parser.add_argument(
-        "--no-bbox", action="store_true", help="Don't draw bounding boxes for polygons"
-    )
+    parser.add_argument("--labels", type=str, help="Path to YOLO label file (polygons)")
+    parser.add_argument("--alpha", type=float, default=0.4, help="Polygon fill alpha (0..1)")
+    parser.add_argument("--no-bbox", action="store_true", help="Don't draw bounding boxes for polygons")
     args = parser.parse_args()
+    print(args)

     img_path = Path(args.image)
-    lbl_path = Path(args.labels)
+    if args.labels:
+        lbl_path = Path(args.labels)
+    else:
+        lbl_path = img_path.with_suffix(".txt")
+        lbl_path = Path(str(lbl_path).replace("images", "labels"))

     if not img_path.exists():
         print("Image not found:", img_path)
@@ -158,7 +176,9 @@ def main():
         print("Label file not found:", lbl_path)
         sys.exit(1)

-    img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
+    # img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
+    img = (Image(img_path).get_qt_rgb() * 255).astype(np.uint8)
     if img is None:
         print("Could not load image:", img_path)
         sys.exit(1)
@@ -167,15 +187,34 @@
     if not labels:
         print("No labels parsed from", lbl_path)
         # continue and just show image
-    out = draw_annotations(
-        img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox)
-    )
+    out = draw_annotations(img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox))
+
+    lclass, coords = labels[0]
+    print(lclass, coords)
+    bbox = coords[:4]
+    print("bbox", bbox)
+    bbox = np.array(bbox) * np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
+    yc, xc, h, w = bbox
+    print("bbox", bbox)
+    polyline = np.array(coords[4:]).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
+    print("pl", coords[4:])
+    print("pl", polyline)

     # Convert BGR -> RGB for matplotlib display
-    out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
+    # out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
+    out_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    # out_rgb = Image()
     plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
     plt.imshow(out_rgb)
-    plt.axis("off")
+    plt.plot(polyline[:, 0], polyline[:, 1], "y", linewidth=2)
+    plt.plot(
+        [yc - h / 2, yc - h / 2, yc + h / 2, yc + h / 2, yc - h / 2],
+        [xc - w / 2, xc + w / 2, xc + w / 2, xc - w / 2, xc - w / 2],
+        "r",
+        linewidth=2,
+    )
+    # plt.axis("off")
     plt.title(f"{img_path.name} ({lbl_path.name})")
     plt.show()
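
Finally, as a reference for the label lines handled throughout this PR (class, normalized bbox, then polygon vertices, per this repo's convention), a short parsing sketch with invented values:

line = "0 0.50 0.40 0.20 0.10 0.45 0.36 0.55 0.36 0.55 0.44 0.45 0.44"
fields = line.split()
cls = int(fields[0])
coords = [float(v) for v in fields[1:]]
xc, yc, bw, bh = coords[:4]  # normalized bbox center and size
polygon = list(zip(coords[4::2], coords[5::2]))  # normalized (x, y) vertices
print(cls, (xc, yc, bw, bh), polygon)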