7 Commits

7 changed files with 476 additions and 121 deletions

View File

@@ -1,57 +0,0 @@
database:
  path: data/detections.db
image_repository:
  base_path: ''
  allowed_extensions:
  - .jpg
  - .jpeg
  - .png
  - .tif
  - .tiff
  - .bmp
models:
  default_base_model: yolov8s-seg.pt
  models_directory: data/models
  base_model_choices:
  - yolov8s-seg.pt
  - yolo11s-seg.pt
training:
  default_epochs: 100
  default_batch_size: 16
  default_imgsz: 1024
  default_patience: 50
  default_lr0: 0.01
  two_stage:
    enabled: false
    stage1:
      epochs: 20
      lr0: 0.0005
      patience: 10
      freeze: 10
    stage2:
      epochs: 150
      lr0: 0.0003
      patience: 30
  last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml
  last_dataset_dir: /home/martin/code/object_detection/data/datasets
detection:
  default_confidence: 0.25
  default_iou: 0.45
  max_batch_size: 100
visualization:
  bbox_colors:
    organelle: '#FF6B6B'
    membrane_branch: '#4ECDC4'
    default: '#00FF00'
  bbox_thickness: 2
  font_size: 12
export:
  formats:
  - csv
  - json
  - excel
  default_format: csv
logging:
  level: INFO
  file: logs/app.log
  format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

View File

@@ -34,7 +34,7 @@ from PySide6.QtWidgets import (
 from src.database.db_manager import DatabaseManager
 from src.model.yolo_wrapper import YOLOWrapper
 from src.utils.config_manager import ConfigManager
-from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range
+from src.utils.image import Image
 from src.utils.logger import get_logger
@@ -1303,6 +1303,14 @@ class TrainingTab(QWidget):
         sample_image = self._find_first_image(images_dir)
         if not sample_image:
             return False
+        # Do not force an RGB cache for TIFF datasets.
+        # We handle grayscale/16-bit TIFFs via runtime Ultralytics patches that:
+        # - load TIFFs with `tifffile`
+        # - replicate grayscale to 3 channels without quantization
+        # - normalize uint16 correctly during training
+        if sample_image.suffix.lower() in {".tif", ".tiff"}:
+            return False
         try:
             img = Image(sample_image)
             return img.pil_image.mode.upper() != "RGB"
@@ -1368,7 +1376,7 @@ class TrainingTab(QWidget):
             img_obj = Image(src)
             pil_img = img_obj.pil_image
             if len(pil_img.getbands()) == 1:
-                rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img)
+                rgb_img = img_obj.convert_grayscale_to_rgb_preserve_range()
             else:
                 rgb_img = pil_img.convert("RGB")
             rgb_img.save(dst)
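Note: a minimal NumPy sketch (values are illustrative) of the quantization an 8-bit RGB cache would introduce for 16-bit TIFFs, which is what the patch-based path above avoids:

import numpy as np

img16 = np.array([[0, 256, 65535]], dtype=np.uint16)  # spans the full 16-bit range

# An 8-bit RGB cache collapses 65536 intensity levels into 256:
img8 = (img16.astype(np.float32) / 65535.0 * 255.0).round().astype(np.uint8)
print(img8)  # [[  0   1 255]] -> 0 and 256 become nearly indistinguishable

# The runtime patches instead keep uint16 through the dataloader and divide
# by 65535 only at normalization time, so no precision is lost before training.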

View File

@@ -1,16 +1,21 @@
""" """YOLO model wrapper for the microscopy object detection application.
YOLO model wrapper for the microscopy object detection application.
Provides a clean interface to YOLOv8 for training, validation, and inference. Notes on 16-bit TIFF support:
- Ultralytics training defaults assume 8-bit images and normalize by dividing by 255.
- This project can patch Ultralytics at runtime to decode TIFFs via `tifffile` and
normalize `uint16` correctly.
See [`apply_ultralytics_16bit_tiff_patches()`](src/utils/ultralytics_16bit_patch.py:1).
""" """
from ultralytics import YOLO
from pathlib import Path from pathlib import Path
from typing import Optional, List, Dict, Callable, Any from typing import Optional, List, Dict, Callable, Any
import torch import torch
import tempfile import tempfile
import os import os
from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range from src.utils.image import Image
from src.utils.logger import get_logger from src.utils.logger import get_logger
from src.utils.ultralytics_16bit_patch import apply_ultralytics_16bit_tiff_patches
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -31,6 +36,9 @@ class YOLOWrapper:
self.device = "cuda" if torch.cuda.is_available() else "cpu" self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"YOLOWrapper initialized with device: {self.device}") logger.info(f"YOLOWrapper initialized with device: {self.device}")
# Apply Ultralytics runtime patches early (before first import/instantiation of YOLO datasets/trainers).
apply_ultralytics_16bit_tiff_patches()
def load_model(self) -> bool: def load_model(self) -> bool:
""" """
Load YOLO model from path. Load YOLO model from path.
@@ -40,6 +48,9 @@ class YOLOWrapper:
""" """
try: try:
logger.info(f"Loading YOLO model from {self.model_path}") logger.info(f"Loading YOLO model from {self.model_path}")
# Import YOLO lazily to ensure runtime patches are applied first.
from ultralytics import YOLO
self.model = YOLO(self.model_path) self.model = YOLO(self.model_path)
self.model.to(self.device) self.model.to(self.device)
logger.info("Model loaded successfully") logger.info("Model loaded successfully")
@@ -89,6 +100,16 @@ class YOLOWrapper:
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}" f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
) )
# Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255.
# Users can override by passing explicit kwargs.
kwargs.setdefault("mosaic", 0.0)
kwargs.setdefault("mixup", 0.0)
kwargs.setdefault("cutmix", 0.0)
kwargs.setdefault("copy_paste", 0.0)
kwargs.setdefault("hsv_h", 0.0)
kwargs.setdefault("hsv_s", 0.0)
kwargs.setdefault("hsv_v", 0.0)
# Train the model # Train the model
results = self.model.train( results = self.model.train(
data=data_yaml, data=data_yaml,
@@ -238,7 +259,7 @@ class YOLOWrapper:
img_obj = Image(source_path) img_obj = Image(source_path)
pil_img = img_obj.pil_image pil_img = img_obj.pil_image
if len(pil_img.getbands()) == 1: if len(pil_img.getbands()) == 1:
rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img) rgb_img = img_obj.convert_grayscale_to_rgb_preserve_range()
else: else:
rgb_img = pil_img.convert("RGB") rgb_img = pil_img.convert("RGB")

View File

@@ -277,36 +277,18 @@ class Image:
""" """
return self._channels >= 3 return self._channels >= 3
def __repr__(self) -> str: def convert_grayscale_to_rgb_preserve_range(
"""String representation of the Image object.""" self,
return ( ) -> PILImage.Image:
f"Image(path='{self.path.name}', "
f"shape=({self._width}x{self._height}x{self._channels}), "
f"format={self._format}, "
f"size={self.size_mb:.2f}MB)"
)
def __str__(self) -> str:
"""String representation of the Image object."""
return self.__repr__()
def convert_grayscale_to_rgb_preserve_range(
pil_image: PILImage.Image,
) -> PILImage.Image:
"""Convert a single-channel PIL image to RGB while preserving dynamic range. """Convert a single-channel PIL image to RGB while preserving dynamic range.
Args:
pil_image: Single-channel PIL image (e.g., 16-bit grayscale).
Returns: Returns:
PIL Image in RGB mode with intensities normalized to 0-255. PIL Image in RGB mode with intensities normalized to 0-255.
""" """
if self._channels == 3:
return self.pil_image
if pil_image.mode == "RGB": grayscale = self.data
return pil_image
grayscale = np.array(pil_image)
if grayscale.ndim == 3: if grayscale.ndim == 3:
grayscale = grayscale[:, :, 0] grayscale = grayscale[:, :, 0]
@@ -314,7 +296,7 @@ def convert_grayscale_to_rgb_preserve_range(
grayscale = grayscale.astype(np.float32) grayscale = grayscale.astype(np.float32)
if grayscale.size == 0: if grayscale.size == 0:
return PILImage.new("RGB", pil_image.size, color=(0, 0, 0)) return PILImage.new("RGB", self.shape, color=(0, 0, 0))
if np.issubdtype(original_dtype, np.integer): if np.issubdtype(original_dtype, np.integer):
denom = float(max(np.iinfo(original_dtype).max, 1)) denom = float(max(np.iinfo(original_dtype).max, 1))
@@ -326,3 +308,17 @@ def convert_grayscale_to_rgb_preserve_range(
grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8) grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8)
rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2) rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
return PILImage.fromarray(rgb_arr, mode="RGB") return PILImage.fromarray(rgb_arr, mode="RGB")
def __repr__(self) -> str:
"""String representation of the Image object."""
return (
f"Image(path='{self.path.name}', "
# Display as HxWxC to match the conventional NumPy shape semantics.
f"shape=({self._height}x{self._width}x{self._channels}), "
f"format={self._format}, "
f"size={self.size_mb:.2f}MB)"
)
def __str__(self) -> str:
"""String representation of the Image object."""
return self.__repr__()
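A self-contained sketch of the preserve-range conversion the method performs, per the diff above: integer inputs are scaled by the dtype's full range (65535 for uint16), not the observed max, so absolute intensities stay comparable across images. Values are illustrative:

import numpy as np
from PIL import Image as PILImage

gray = np.array([[0, 4096, 65535]], dtype=np.uint16)
denom = float(np.iinfo(gray.dtype).max)  # 65535.0, as in the diff above
u8 = (gray.astype(np.float32) / denom * 255.0).round().astype(np.uint8)
rgb = PILImage.fromarray(np.repeat(u8[:, :, None], 3, axis=2), mode="RGB")
print(u8)  # [[  0  16 255]]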

View File

@@ -12,23 +12,32 @@ class UT:
     Operetta files along with rois drawn in ImageJ
     """

-    def __init__(self, roifile_fn: Path):
+    def __init__(self, roifile_fn: Path, no_labels: bool):
         self.roifile_fn = roifile_fn
-        self.rois = ImagejRoi.fromfile(self.roifile_fn)
-        self.stem = self.roifile_fn.stem.strip("-RoiSet")
+        print("is file", self.roifile_fn.is_file())
+        self.rois = None
+        if no_labels:
+            self.rois = ImagejRoi.fromfile(self.roifile_fn)
+            self.stem = self.roifile_fn.stem.split("Roi-")[1]
+        else:
+            self.roifile_fn = roifile_fn / roifile_fn.parts[-1]
+            self.stem = self.roifile_fn.stem
+        print(self.roifile_fn)
+        print(self.stem)
         self.image, self.image_props = self._load_images()

     def _load_images(self):
         """Loading sequence of tif files
         array sequence is CZYX
         """
-        print(self.roifile_fn.parent, self.stem)
-        fns = list(self.roifile_fn.parent.glob(f"{self.stem}*.tif*"))
+        print("Loading images:", self.roifile_fn.parent, self.stem)
+        fns = list(self.roifile_fn.parent.glob(f"{self.stem.lower()}*.tif*"))
         stems = [fn.stem.split(self.stem)[-1] for fn in fns]
         n_ch = len(set([stem.split("-ch")[-1].split("t")[0] for stem in stems]))
         n_p = len(set([stem.split("-")[0] for stem in stems]))
         n_t = len(set([stem.split("t")[1] for stem in stems]))
+        print(n_ch, n_p, n_t)

         with TiffFile(fns[0]) as tif:
             img = tif.asarray()
@@ -42,6 +51,7 @@
             "height": h,
             "dtype": dtype,
         }
+        print("Image props", self.image_props)

         image_stack = np.zeros((n_ch, n_p, w, h), dtype=dtype)
         for fn in fns:
@@ -49,7 +59,7 @@
                 img = tif.asarray()
             stem = fn.stem.split(self.stem)[-1]
             ch = int(stem.split("-ch")[-1].split("t")[0])
-            p = int(stem.split("-")[0].lstrip("p"))
+            p = int(stem.split("-")[0].split("p")[1])
             t = int(stem.split("t")[1])
             print(fn.stem, "ch", ch, "p", p, "t", t)
             image_stack[ch - 1, p - 1] = img
@@ -82,11 +92,22 @@
     ):
         """Export rois to a file"""
         with open(path / subfolder / f"{self.stem}.txt", "w") as f:
-            for roi in self.rois:
-                # TODO add image coordinates normalization
-                coords = ""
-                for x, y in roi.subpixel_coordinates:
-                    coords += f"{x/self.width} {y/self.height}"
+            for i, roi in enumerate(self.rois):
+                rc = roi.subpixel_coordinates
+                if rc is None:
+                    print(
+                        f"No coordinates: {self.roifile_fn}, element {i}, out of {len(self.rois)}"
+                    )
+                    continue
+                xmn, ymn = rc.min(axis=0)
+                xmx, ymx = rc.max(axis=0)
+                xc = (xmn + xmx) / 2
+                yc = (ymn + ymx) / 2
+                bw = xmx - xmn
+                bh = ymx - ymn
+                coords = f"{xc/self.width} {yc/self.height} {bw/self.width} {bh/self.height} "
+                for x, y in rc:
+                    coords += f"{x/self.width} {y/self.height} "
                 f.write(f"{class_index} {coords}\n")
         return
@@ -104,6 +125,7 @@
         self.image = np.max(self.image[channel], axis=0)
         print(self.image.shape)
+        print(path / subfolder / f"{self.stem}.tif")
         with TiffWriter(path / subfolder / f"{self.stem}.tif") as tif:
             tif.write(self.image)
@@ -112,11 +134,27 @@ if __name__ == "__main__":
     import argparse

     parser = argparse.ArgumentParser()
-    parser.add_argument("input", type=Path)
-    parser.add_argument("output", type=Path)
+    parser.add_argument("-i", "--input", nargs="*", type=Path)
+    parser.add_argument("-o", "--output", type=Path)
+    parser.add_argument(
+        "--no-labels",
+        action="store_false",
+        help="Source does not have labels, export only images",
+    )
     args = parser.parse_args()

-    for rfn in args.input.glob("*.zip"):
-        ut = UT(rfn)
-        ut.export_rois(args.output, class_index=0)
-        ut.export_image(args.output, plane_mode="max projection", channel=0)
+    for path in args.input:
+        print("Path:", path)
+        if not args.no_labels:
+            print("No labels")
+            ut = UT(path, args.no_labels)
+            ut.export_image(args.output, plane_mode="max projection", channel=0)
+        else:
+            for rfn in Path(path).glob("*.zip"):
+                print("Roi FN:", rfn)
+                ut = UT(rfn, args.no_labels)
+                ut.export_rois(args.output, class_index=0)
+                ut.export_image(args.output, plane_mode="max projection", channel=0)
+        print()
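For reference, each exported label line is "class xc yc bw bh x1 y1 x2 y2 ...", all values normalized by image width/height; this is the bbox-prefixed polygon format that tests/show_yolo_seg.py below consumes. A sketch with made-up ROI coordinates:

import numpy as np

rc = np.array([(100.0, 200.0), (300.0, 200.0), (300.0, 400.0)])  # polygon (x, y), illustrative
w = h = 1024
(xmn, ymn), (xmx, ymx) = rc.min(axis=0), rc.max(axis=0)
xc, yc = (xmn + xmx) / 2, (ymn + ymx) / 2
bw, bh = xmx - xmn, ymx - ymn
line = f"0 {xc/w} {yc/h} {bw/w} {bh/h} " + " ".join(f"{x/w} {y/h}" for x, y in rc)
print(line)  # 0 0.1953125 0.29296875 0.1953125 0.1953125 0.09765625 0.1953125 ...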

View File

@@ -0,0 +1,165 @@
"""Ultralytics runtime patches for 16-bit TIFF training.
Goals:
- Use `tifffile` to decode `.tif/.tiff` reliably (OpenCV can silently drop bit-depth depending on codec).
- Preserve 16-bit data through the dataloader as `uint16` tensors.
- Fix Ultralytics trainer normalization (default divides by 255) to scale `uint16` correctly.
- Avoid uint8-forcing augmentations by recommending/setting hyp values (handled by caller).
This module is intended to be imported/called **before** instantiating/using YOLO.
"""
from __future__ import annotations
from typing import Optional
def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
"""Apply runtime monkey-patches to Ultralytics to better support 16-bit TIFFs.
This function is safe to call multiple times.
Args:
force: If True, re-apply patches even if already applied.
"""
# Import inside function to ensure patching occurs before YOLO model/dataset is created.
import os
import cv2
import numpy as np
import tifffile
import torch
from ultralytics.utils import patches as ul_patches
already_patched = getattr(ul_patches.imread, "__name__", "") == "tifffile_imread"
if already_patched and not force:
return
_original_imread = ul_patches.imread
def tifffile_imread(
filename: str, flags: int = cv2.IMREAD_COLOR
) -> Optional[np.ndarray]:
"""Replacement for [`ultralytics.utils.patches.imread()`](venv/lib/python3.12/site-packages/ultralytics/utils/patches.py:20).
- For `.tif/.tiff`, uses `tifffile.imread()` and preserves dtype (e.g. uint16).
- For other formats, falls back to Ultralytics' original implementation.
- Always returns HWC (3 dims). For grayscale, returns (H, W, 1) or (H, W, 3) depending on requested flags.
"""
ext = os.path.splitext(filename)[1].lower()
if ext in (".tif", ".tiff"):
arr = tifffile.imread(filename)
# Normalize common shapes:
# - (H, W) -> (H, W, 1)
# - (C, H, W) -> (H, W, C) (heuristic)
if arr is None:
return None
if (
arr.ndim == 3
and arr.shape[0] in (1, 3, 4)
and arr.shape[0] < arr.shape[1]
):
arr = np.transpose(arr, (1, 2, 0))
if arr.ndim == 2:
arr = arr[..., None]
# Ultralytics expects BGR ordering when `channels=3`.
# For grayscale data we replicate channels (no scaling, no quantization).
if flags != cv2.IMREAD_GRAYSCALE:
if arr.shape[2] == 1:
arr = np.repeat(arr, 3, axis=2)
elif arr.shape[2] >= 3:
arr = arr[:, :, :3]
# Ensure contiguous array for downstream OpenCV ops.
return np.ascontiguousarray(arr)
return _original_imread(filename, flags)
    # Patch the canonical reference.
    ul_patches.imread = tifffile_imread

    # Patch common module-level imports (some Ultralytics modules do `from ... import imread`).
    # Importing these modules is safe and helps ensure the patched function is used.
    try:
        import ultralytics.data.base as _ul_base

        _ul_base.imread = tifffile_imread
    except Exception:
        pass
    try:
        import ultralytics.data.loaders as _ul_loaders

        _ul_loaders.imread = tifffile_imread
    except Exception:
        pass

    # Patch trainer normalization: default divides by 255 regardless of input dtype.
    from ultralytics.models.yolo.detect import train as detect_train

    _orig_preprocess_batch = detect_train.DetectionTrainer.preprocess_batch

    def preprocess_batch_16bit(self, batch: dict) -> dict:  # type: ignore[override]
        # Start from upstream behavior to keep device placement + multiscale identical,
        # but replace the 255 division with dtype-aware scaling.
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
        img = batch.get("img")
        if isinstance(img, torch.Tensor):
            # Decide scaling denom based on dtype (avoid expensive reductions if possible).
            if img.dtype == torch.uint8:
                denom = 255.0
            elif img.dtype == torch.uint16:
                denom = 65535.0
            elif img.dtype.is_floating_point:
                # Assume already in 0-1 range if float.
                denom = 1.0
            else:
                # Generic integer fallback.
                try:
                    denom = float(torch.iinfo(img.dtype).max)
                except Exception:
                    denom = 255.0
            batch["img"] = img.float() / denom

        # Multi-scale branch copied from upstream to avoid re-introducing `/255` scaling.
        if getattr(self.args, "multi_scale", False):
            import math
            import random

            import torch.nn as nn

            imgs = batch["img"]
            sz = (
                random.randrange(
                    int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride)
                )
                // self.stride
                * self.stride
            )
            sf = sz / max(imgs.shape[2:])
            if sf != 1:
                ns = [
                    math.ceil(x * sf / self.stride) * self.stride
                    for x in imgs.shape[2:]
                ]
                imgs = nn.functional.interpolate(
                    imgs, size=ns, mode="bilinear", align_corners=False
                )
            batch["img"] = imgs
        return batch

    detect_train.DetectionTrainer.preprocess_batch = preprocess_batch_16bit
    # Tag function to make it easier to detect patch state.
    setattr(
        detect_train.DetectionTrainer.preprocess_batch, "_ultralytics_16bit_patch", True
    )
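A minimal usage sketch of the intended call order, per the module docstring (patch before Ultralytics trainers/datasets are constructed):

from src.utils.ultralytics_16bit_patch import apply_ultralytics_16bit_tiff_patches

apply_ultralytics_16bit_tiff_patches()  # patch first ...

from ultralytics import YOLO  # ... then import/instantiate YOLO

model = YOLO("yolov8s-seg.pt")
# .tif/.tiff files now decode via tifffile and uint16 batches normalize by 65535.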

tests/show_yolo_seg.py (new file, 184 lines)
View File

@@ -0,0 +1,184 @@
#!/usr/bin/env python3
"""
show_yolo_seg.py

Usage:
    python show_yolo_seg.py /path/to/image.jpg /path/to/labels.txt

Supports:
- Segmentation polygons with a leading bbox: "class xc yc w h x1 y1 x2 y2 ... xn yn"
  (the format written by the ROI exporter above; the first four values are drawn as a box)
- YOLO bbox lines as fallback: "class x_center y_center width height"

Coordinates can be normalized [0..1] or absolute pixels (auto-detected).
"""
import sys
import cv2
import numpy as np
import matplotlib.pyplot as plt
import argparse
from pathlib import Path
import random


def parse_label_line(line):
    parts = line.strip().split()
    if not parts:
        return None
    cls = int(float(parts[0]))
    coords = [float(x) for x in parts[1:]]
    return cls, coords


def coords_are_normalized(coords):
    # If every coordinate is between 0 and 1 (inclusive-ish), assume normalized
    if not coords:
        return False
    return max(coords) <= 1.001


def yolo_bbox_to_xyxy(coords, img_w, img_h):
    # coords: [xc, yc, w, h] normalized or absolute
    xc, yc, w, h = coords[:4]
    if max(coords) <= 1.001:
        xc *= img_w
        yc *= img_h
        w *= img_w
        h *= img_h
    x1 = int(round(xc - w / 2))
    y1 = int(round(yc - h / 2))
    x2 = int(round(xc + w / 2))
    y2 = int(round(yc + h / 2))
    return x1, y1, x2, y2


def poly_to_pts(coords, img_w, img_h):
    # coords: [x1 y1 x2 y2 ...] either normalized or absolute.
    # The caller has already stripped the leading bbox, so check all coordinates
    # here (checking coords[4:] would skip the first two polygon points).
    if coords_are_normalized(coords):
        coords = [
            coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))
        ]
    pts = np.array(coords, dtype=np.int32).reshape(-1, 2)
    return pts


def random_color_for_class(cls):
    random.seed(cls)  # deterministic per class
    return tuple(int(x) for x in np.array([random.randint(0, 255) for _ in range(3)]))
def draw_annotations(img, labels, alpha=0.4, draw_bbox_for_poly=True):
    # img: BGR numpy array
    overlay = img.copy()
    h, w = img.shape[:2]
    for cls, coords in labels:
        if not coords:
            continue
        # polygon case (>=6 coordinates)
        if len(coords) >= 6:
            color = random_color_for_class(cls)
            # Honor --no-bbox: only draw the leading bbox when requested.
            if draw_bbox_for_poly:
                x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords[:4], w, h)
                cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
            pts = poly_to_pts(coords[4:], w, h)
            # fill on overlay
            cv2.fillPoly(overlay, [pts], color)
            # outline on base image
            cv2.polylines(img, [pts], isClosed=True, color=color, thickness=2)
            # put class text at first point
            x, y = int(pts[0, 0]), int(pts[0, 1]) - 6
            cv2.putText(
                img,
                str(cls),
                (x, max(6, y)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                (255, 255, 255),
                2,
                cv2.LINE_AA,
            )
        # YOLO bbox case (4 coords)
        elif len(coords) == 4:
            x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords, w, h)
            color = random_color_for_class(cls)
            cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
            cv2.putText(
                img,
                str(cls),
                (x1, max(6, y1 - 4)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                (255, 255, 255),
                2,
                cv2.LINE_AA,
            )
        else:
            # Unknown / invalid format, skip
            continue
    # blend overlay for filled polygons
    cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img)
    return img
def load_labels_file(label_path):
    labels = []
    with open(label_path, "r") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            parsed = parse_label_line(line)
            if parsed:
                labels.append(parsed)
    return labels


def main():
    parser = argparse.ArgumentParser(
        description="Show YOLO segmentation / polygon annotations"
    )
    parser.add_argument("image", type=str, help="Path to image file")
    parser.add_argument("labels", type=str, help="Path to YOLO label file (polygons)")
    parser.add_argument(
        "--alpha", type=float, default=0.4, help="Polygon fill alpha (0..1)"
    )
    parser.add_argument(
        "--no-bbox", action="store_true", help="Don't draw bounding boxes for polygons"
    )
    args = parser.parse_args()

    img_path = Path(args.image)
    lbl_path = Path(args.labels)
    if not img_path.exists():
        print("Image not found:", img_path)
        sys.exit(1)
    if not lbl_path.exists():
        print("Label file not found:", lbl_path)
        sys.exit(1)

    img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
    if img is None:
        print("Could not load image:", img_path)
        sys.exit(1)

    labels = load_labels_file(str(lbl_path))
    if not labels:
        print("No labels parsed from", lbl_path)
        # continue and just show image

    out = draw_annotations(
        img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox)
    )

    # Convert BGR -> RGB for matplotlib display
    out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
    plt.imshow(out_rgb)
    plt.axis("off")
    plt.title(f"{img_path.name} ({lbl_path.name})")
    plt.show()


if __name__ == "__main__":
    main()
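An example invocation against the exporter's output above (paths are illustrative):

python tests/show_yolo_seg.py exported/images/sample.tif exported/labels/sample.txt --alpha 0.3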