diff --git a/src/gui/tabs/training_tab.py b/src/gui/tabs/training_tab.py index 3c2ca9e..257d3c2 100644 --- a/src/gui/tabs/training_tab.py +++ b/src/gui/tabs/training_tab.py @@ -34,6 +34,7 @@ from PySide6.QtWidgets import ( from src.database.db_manager import DatabaseManager from src.model.yolo_wrapper import YOLOWrapper from src.utils.config_manager import ConfigManager +from src.utils.image import convert_grayscale_to_rgb_preserve_range from src.utils.logger import get_logger @@ -1361,7 +1362,10 @@ class TrainingTab(QWidget): dst.parent.mkdir(parents=True, exist_ok=True) try: with PILImage.open(src) as img: - rgb_img = img.convert("RGB") + if len(img.getbands()) == 1: + rgb_img = convert_grayscale_to_rgb_preserve_range(img) + else: + rgb_img = img.convert("RGB") rgb_img.save(dst) except Exception as exc: logger.warning(f"Failed to convert {src} to RGB: {exc}") diff --git a/src/model/yolo_wrapper.py b/src/model/yolo_wrapper.py index fc4ac0c..f165c7b 100644 --- a/src/model/yolo_wrapper.py +++ b/src/model/yolo_wrapper.py @@ -10,7 +10,7 @@ import torch from PIL import Image import tempfile import os -import numpy as np +from src.utils.image import convert_grayscale_to_rgb_preserve_range from src.utils.logger import get_logger @@ -234,38 +234,20 @@ class YOLOWrapper: try: with Image.open(source_path) as img: if len(img.getbands()) == 1: - grayscale = np.array(img) - if grayscale.ndim == 3: - grayscale = grayscale[:, :, 0] - original_dtype = grayscale.dtype - grayscale = grayscale.astype(np.float32) + rgb_img = convert_grayscale_to_rgb_preserve_range(img) + else: + rgb_img = img.convert("RGB") - if np.issubdtype(original_dtype, np.integer): - dtype_info = np.iinfo(original_dtype) - denom = float(max(dtype_info.max, 1)) - else: - max_val = ( - float(grayscale.max()) if grayscale.size else 0.0 - ) - denom = max(max_val, 1.0) - - grayscale = np.clip(grayscale / denom, 0.0, 1.0) - grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8) - rgb_arr = 
def convert_grayscale_to_rgb_preserve_range(
    pil_image: PILImage.Image,
) -> PILImage.Image:
    """Convert a single-channel PIL image to RGB while preserving dynamic range.

    Integer images are scaled by the maximum of their storage type (so an
    8-bit image is unchanged and a 16-bit image maps 0..65535 onto 0..255).
    Float images are scaled by their largest finite sample (at least 1.0).
    Negative samples are clipped to 0.

    Args:
        pil_image: Single-channel PIL image (e.g., 16-bit grayscale). An
            image that is already RGB is returned unchanged; a palette
            ("P") image is expanded through its palette.

    Returns:
        PIL Image in RGB mode with intensities normalized to 0-255.
    """
    if pil_image.mode == "RGB":
        return pil_image

    # Palette images have a single band, but that band holds palette
    # indices, not intensities; let Pillow resolve the palette instead.
    if pil_image.mode == "P":
        return pil_image.convert("RGB")

    pixels = np.asarray(pil_image)
    if pixels.ndim == 3:
        # Multi-band input (e.g. "LA"): keep only the first band.
        pixels = pixels[:, :, 0]

    if pixels.size == 0:
        return PILImage.new("RGB", pil_image.size, color=(0, 0, 0))

    source_dtype = pixels.dtype
    scaled = pixels.astype(np.float32)

    if np.issubdtype(source_dtype, np.integer):
        type_max = float(np.iinfo(source_dtype).max)
        if type_max > 65535.0:
            # 32-bit containers (Pillow mode "I") commonly carry 16-bit
            # data; dividing by the 32-bit maximum would flatten every
            # pixel to 0. Treat the data as 16-bit unless it really
            # exceeds that range, in which case scale by the observed peak.
            denom = max(float(pixels.max()), 65535.0)
        else:
            denom = max(type_max, 1.0)
    else:
        # Ignore NaN/inf when looking for the peak so one bad sample does
        # not poison the scale, then map them to well-defined values.
        finite = scaled[np.isfinite(scaled)]
        peak = float(finite.max()) if finite.size else 0.0
        denom = max(peak, 1.0)
        scaled = np.nan_to_num(scaled, nan=0.0, posinf=denom, neginf=0.0)

    scaled = np.clip(scaled / denom, 0.0, 1.0)
    gray_u8 = (scaled * 255.0).round().astype(np.uint8)
    rgb_arr = np.repeat(gray_u8[:, :, None], 3, axis=2)
    # An (H, W, 3) uint8 array is inferred as RGB; the explicit ``mode``
    # argument is deprecated in recent Pillow releases.
    return PILImage.fromarray(rgb_arr)