Fixing grayscale conversion
This commit is contained in:
@@ -34,6 +34,7 @@ from PySide6.QtWidgets import (
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.model.yolo_wrapper import YOLOWrapper
|
||||
from src.utils.config_manager import ConfigManager
|
||||
from src.utils.image import convert_grayscale_to_rgb_preserve_range
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
|
||||
@@ -1361,6 +1362,9 @@ class TrainingTab(QWidget):
|
||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
with PILImage.open(src) as img:
|
||||
if len(img.getbands()) == 1:
|
||||
rgb_img = convert_grayscale_to_rgb_preserve_range(img)
|
||||
else:
|
||||
rgb_img = img.convert("RGB")
|
||||
rgb_img.save(dst)
|
||||
except Exception as exc:
|
||||
|
||||
@@ -10,7 +10,7 @@ import torch
|
||||
from PIL import Image
|
||||
import tempfile
|
||||
import os
|
||||
import numpy as np
|
||||
from src.utils.image import convert_grayscale_to_rgb_preserve_range
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
|
||||
@@ -234,36 +234,18 @@ class YOLOWrapper:
|
||||
try:
|
||||
with Image.open(source_path) as img:
|
||||
if len(img.getbands()) == 1:
|
||||
grayscale = np.array(img)
|
||||
if grayscale.ndim == 3:
|
||||
grayscale = grayscale[:, :, 0]
|
||||
original_dtype = grayscale.dtype
|
||||
grayscale = grayscale.astype(np.float32)
|
||||
|
||||
if np.issubdtype(original_dtype, np.integer):
|
||||
dtype_info = np.iinfo(original_dtype)
|
||||
denom = float(max(dtype_info.max, 1))
|
||||
rgb_img = convert_grayscale_to_rgb_preserve_range(img)
|
||||
else:
|
||||
max_val = (
|
||||
float(grayscale.max()) if grayscale.size else 0.0
|
||||
)
|
||||
denom = max(max_val, 1.0)
|
||||
|
||||
grayscale = np.clip(grayscale / denom, 0.0, 1.0)
|
||||
grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8)
|
||||
rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
|
||||
rgb_img = Image.fromarray(rgb_arr, mode="RGB")
|
||||
rgb_img = img.convert("RGB")
|
||||
|
||||
suffix = source_path.suffix or ".png"
|
||||
tmp = tempfile.NamedTemporaryFile(
|
||||
suffix=suffix, delete=False
|
||||
)
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
rgb_img.save(tmp_path)
|
||||
cleanup_path = tmp_path
|
||||
logger.info(
|
||||
f"Converted single-channel image {source_path} to RGB for inference at {tmp_path}"
|
||||
f"Converted image {source_path} to RGB for inference at {tmp_path}"
|
||||
)
|
||||
return tmp_path, cleanup_path
|
||||
except Exception as convert_error:
|
||||
|
||||
@@ -289,3 +289,40 @@ class Image:
|
||||
def __str__(self) -> str:
    """Return the same text as ``__repr__`` so ``str()`` and ``repr()`` agree."""
    representation = self.__repr__()
    return representation
|
||||
|
||||
|
||||
def convert_grayscale_to_rgb_preserve_range(
    pil_image: PILImage.Image,
) -> PILImage.Image:
    """Convert a single-channel PIL image to 8-bit RGB, scaling by bit depth.

    The intensities are divided by a full-scale value, clipped to ``[0, 1]``,
    mapped to ``0-255`` and replicated into three identical channels.

    Full-scale value used for the division:

    * 8/16-bit integer data: the maximum of the source dtype, so a dim image
      stays dim instead of being contrast-stretched.
    * Wider integer data (e.g. Pillow mode ``"I"``, 32-bit signed): ``65535``
      when the data fits in 16 bits, otherwise the largest value present.
      Dividing by the int32 maximum would render such images entirely black.
      NOTE(review): this is a heuristic; it assumes mode ``"I"`` images
      usually carry 16-bit data.
    * Non-integer data (float, bool): the largest finite value present, but
      never less than ``1.0``.

    Args:
        pil_image: Single-channel PIL image (e.g., 16-bit grayscale).

    Returns:
        PIL Image in RGB mode with intensities normalized to 0-255. An image
        that is already RGB is returned unchanged (same object). A palette
        (``"P"``) image is converted through its palette. An image with no
        pixels yields a black RGB image of the same size.
    """
    if pil_image.mode == "RGB":
        return pil_image

    # Palette images also have exactly one band, but their pixels are palette
    # indices rather than intensities; let Pillow resolve the real colours.
    if pil_image.mode == "P":
        return pil_image.convert("RGB")

    pixels = np.array(pil_image)
    if pixels.ndim == 3:
        # Multi-band input: only the first band is used as the intensity.
        pixels = pixels[:, :, 0]

    if pixels.size == 0:
        return PILImage.new("RGB", pil_image.size, color=(0, 0, 0))

    source_dtype = pixels.dtype
    values = pixels.astype(np.float32)

    if np.issubdtype(source_dtype, np.integer):
        if source_dtype.itemsize <= 2:
            # 8/16-bit containers: scale by the container's full range.
            denom = float(max(np.iinfo(source_dtype).max, 1))
        else:
            # 32-bit (or wider) containers rarely use their full range.
            peak = float(values.max())
            denom = 65535.0 if peak <= 65535.0 else peak
    else:
        # NaN/inf would poison max() and make the uint8 cast undefined, so
        # non-finite samples are rendered as black.
        finite_mask = np.isfinite(values)
        if not finite_mask.all():
            values = np.where(finite_mask, values, np.float32(0.0))
        denom = max(float(values.max()), 1.0)

    scaled = np.clip(values / denom, 0.0, 1.0)
    gray_u8 = (scaled * 255.0).round().astype(np.uint8)
    rgb_arr = np.repeat(gray_u8[:, :, None], 3, axis=2)
    # An (H, W, 3) uint8 array is detected as RGB automatically, so the
    # explicit ``mode`` argument is not needed.
    return PILImage.fromarray(rgb_arr)
|
||||
|
||||
Reference in New Issue
Block a user