Fixing grayscale conversion

This commit is contained in:
2025-12-11 15:15:38 +02:00
parent e4ce882a18
commit 8eb1cc8c86
3 changed files with 56 additions and 33 deletions

View File

@@ -34,6 +34,7 @@ from PySide6.QtWidgets import (
from src.database.db_manager import DatabaseManager from src.database.db_manager import DatabaseManager
from src.model.yolo_wrapper import YOLOWrapper from src.model.yolo_wrapper import YOLOWrapper
from src.utils.config_manager import ConfigManager from src.utils.config_manager import ConfigManager
from src.utils.image import convert_grayscale_to_rgb_preserve_range
from src.utils.logger import get_logger from src.utils.logger import get_logger
@@ -1361,7 +1362,10 @@ class TrainingTab(QWidget):
dst.parent.mkdir(parents=True, exist_ok=True) dst.parent.mkdir(parents=True, exist_ok=True)
try: try:
with PILImage.open(src) as img: with PILImage.open(src) as img:
rgb_img = img.convert("RGB") if len(img.getbands()) == 1:
rgb_img = convert_grayscale_to_rgb_preserve_range(img)
else:
rgb_img = img.convert("RGB")
rgb_img.save(dst) rgb_img.save(dst)
except Exception as exc: except Exception as exc:
logger.warning(f"Failed to convert {src} to RGB: {exc}") logger.warning(f"Failed to convert {src} to RGB: {exc}")

View File

@@ -10,7 +10,7 @@ import torch
from PIL import Image from PIL import Image
import tempfile import tempfile
import os import os
import numpy as np from src.utils.image import convert_grayscale_to_rgb_preserve_range
from src.utils.logger import get_logger from src.utils.logger import get_logger
@@ -234,38 +234,20 @@ class YOLOWrapper:
try: try:
with Image.open(source_path) as img: with Image.open(source_path) as img:
if len(img.getbands()) == 1: if len(img.getbands()) == 1:
grayscale = np.array(img) rgb_img = convert_grayscale_to_rgb_preserve_range(img)
if grayscale.ndim == 3: else:
grayscale = grayscale[:, :, 0] rgb_img = img.convert("RGB")
original_dtype = grayscale.dtype
grayscale = grayscale.astype(np.float32)
if np.issubdtype(original_dtype, np.integer): suffix = source_path.suffix or ".png"
dtype_info = np.iinfo(original_dtype) tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
denom = float(max(dtype_info.max, 1)) tmp_path = tmp.name
else: tmp.close()
max_val = ( rgb_img.save(tmp_path)
float(grayscale.max()) if grayscale.size else 0.0 cleanup_path = tmp_path
) logger.info(
denom = max(max_val, 1.0) f"Converted image {source_path} to RGB for inference at {tmp_path}"
)
grayscale = np.clip(grayscale / denom, 0.0, 1.0) return tmp_path, cleanup_path
grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8)
rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
rgb_img = Image.fromarray(rgb_arr, mode="RGB")
suffix = source_path.suffix or ".png"
tmp = tempfile.NamedTemporaryFile(
suffix=suffix, delete=False
)
tmp_path = tmp.name
tmp.close()
rgb_img.save(tmp_path)
cleanup_path = tmp_path
logger.info(
f"Converted single-channel image {source_path} to RGB for inference at {tmp_path}"
)
return tmp_path, cleanup_path
except Exception as convert_error: except Exception as convert_error:
logger.warning( logger.warning(
f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}" f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"

View File

@@ -289,3 +289,40 @@ class Image:
def __str__(self) -> str: def __str__(self) -> str:
"""String representation of the Image object.""" """String representation of the Image object."""
return self.__repr__() return self.__repr__()
def convert_grayscale_to_rgb_preserve_range(
    pil_image: PILImage.Image,
) -> PILImage.Image:
    """Convert a single-channel PIL image to RGB while preserving dynamic range.

    Integer images are scaled by the maximum value representable in their
    dtype (so a 16-bit image is divided by 65535 rather than truncated);
    non-integer images are scaled by their own maximum, with a floor of 1.0.
    The scaled intensities are clipped to [0, 1], mapped to 0-255 and
    replicated across the three RGB channels.

    Args:
        pil_image: Single-channel PIL image (e.g., 16-bit grayscale). Images
            already in RGB mode are returned unchanged, and palette ("P")
            images are expanded through their palette.

    Returns:
        PIL Image in RGB mode with intensities normalized to 0-255.
    """
    if pil_image.mode == "RGB":
        return pil_image

    # Palette images also report a single band, but their pixel values are
    # palette indices rather than intensities. Normalising the indices would
    # discard the real colours, so let Pillow resolve the palette instead.
    if pil_image.mode == "P":
        return pil_image.convert("RGB")

    grayscale = np.array(pil_image)
    if grayscale.ndim == 3:
        # Multi-band input: only the first band is used as the intensity.
        grayscale = grayscale[:, :, 0]

    if grayscale.size == 0:
        # Nothing to normalise; hand back an all-black image of the same size.
        return PILImage.new("RGB", pil_image.size, color=(0, 0, 0))

    original_dtype = grayscale.dtype
    grayscale = grayscale.astype(np.float32)

    if np.issubdtype(original_dtype, np.integer):
        # Scale by the dtype's full range so relative brightness is kept.
        # NOTE(review): for 32-bit integer data this divisor is the int32
        # maximum, so values that only span a 16-bit range would come out
        # almost black - confirm whether such inputs occur in practice.
        denom = float(max(np.iinfo(original_dtype).max, 1))
    else:
        # Non-integer data has no fixed range; use the observed maximum, but
        # never less than 1.0 so already-normalised data is left unscaled.
        max_val = float(grayscale.max())
        denom = max(max_val, 1.0)

    grayscale = np.clip(grayscale / denom, 0.0, 1.0)
    grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8)
    rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
    return PILImage.fromarray(rgb_arr, mode="RGB")