Fixing pseudo rgb

This commit is contained in:
2025-12-19 09:56:43 +02:00
parent a8e5db3135
commit 061f8b3ca2
4 changed files with 91 additions and 98 deletions

View File

@@ -250,12 +250,10 @@ class AnnotationCanvasWidget(QWidget):
# Get image data in a format compatible with Qt
if self.current_image.channels in (3, 4):
image_data = self.current_image.get_rgb()
height, width = image_data.shape[:2]
else:
image_data = self.current_image.get_grayscale()
height, width = image_data.shape
image_data = self.current_image.get_qt_rgb()
image_data = np.ascontiguousarray(image_data)
height, width = image_data.shape[:2]
bytes_per_line = image_data.strides[0]
qimage = QImage(
@@ -263,7 +261,7 @@ class AnnotationCanvasWidget(QWidget):
width,
height,
bytes_per_line,
self.current_image.qtimage_format,
QImage.Format_RGBX32FPx4, # self.current_image.qtimage_format,
).copy() # Copy so Qt owns the buffer even after numpy array goes out of scope
self.original_pixmap = QPixmap.fromImage(qimage)

View File

@@ -257,17 +257,11 @@ class YOLOWrapper:
if source_path.is_file():
try:
img_obj = Image(source_path)
pil_img = img_obj.pil_image
if len(pil_img.getbands()) == 1:
rgb_img = img_obj.convert_grayscale_to_rgb_preserve_range()
else:
rgb_img = pil_img.convert("RGB")
suffix = source_path.suffix or ".png"
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
tmp_path = tmp.name
tmp.close()
rgb_img.save(tmp_path)
img_obj.save(tmp_path)
cleanup_path = tmp_path
logger.info(
f"Converted image {source_path} to RGB for inference at {tmp_path}"

View File

@@ -6,16 +6,49 @@ import cv2
import numpy as np
from pathlib import Path
from typing import Optional, Tuple, Union
from PIL import Image as PILImage
from src.utils.logger import get_logger
from src.utils.file_utils import validate_file_path, is_image_file
from PySide6.QtGui import QImage
from tifffile import imread, imwrite
logger = get_logger(__name__)
def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.3) -> np.ndarray:
"""
Convert a grayscale image to a pseudo-RGB image using a gamma correction.
Args:
arr: Input grayscale image as numpy array
Returns:
Pseudo-RGB image as numpy array
"""
if arr.ndim != 2:
raise ValueError("Input array must be a grayscale image with shape (H, W)")
a1 = arr.copy().astype(np.float32)
a1 -= np.percentile(a1, 2)
a1[a1 < 0] = 0
# p999 = np.percentile(a1, 99.9)
# a1[a1 > p999] = p999
a1 /= a1.max()
a2 = a1.copy()
a2 = a2**gamma
a2 /= a2.max()
a3 = a1.copy()
p9999 = np.percentile(a3, 99.99)
a3[a3 > p9999] = p9999
a3 /= a3.max()
return np.stack([a1, a2, a3], axis=0)
class ImageLoadError(Exception):
"""Exception raised when an image cannot be loaded."""
@@ -54,7 +87,6 @@ class Image:
"""
self.path = Path(image_path)
self._data: Optional[np.ndarray] = None
self._pil_image: Optional[PILImage.Image] = None
self._width: int = 0
self._height: int = 0
self._channels: int = 0
@@ -85,6 +117,10 @@ class Image:
)
try:
if self.path.suffix.lower() in [".tif", ".tiff"]:
self._data = imread(str(self.path))
else:
raise NotImplementedError
# Load with OpenCV (returns BGR format)
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
@@ -92,23 +128,13 @@ class Image:
raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")
# Extract metadata
print(self._data.shape)
self._height, self._width = self._data.shape[:2]
self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
self._format = self.path.suffix.lower().lstrip(".")
self._size_bytes = self.path.stat().st_size
self._dtype = self._data.dtype
# Load PIL version for compatibility (convert BGR to RGB)
if self._channels == 3:
rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
self._pil_image = PILImage.fromarray(rgb_data)
elif self._channels == 4:
rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
self._pil_image = PILImage.fromarray(rgba_data)
else:
# Grayscale
self._pil_image = PILImage.fromarray(self._data)
logger.info(
f"Successfully loaded image: {self.path.name} "
f"({self._width}x{self._height}, {self._channels} channels, "
@@ -131,18 +157,6 @@ class Image:
raise ImageLoadError("Image data not available")
return self._data
@property
def pil_image(self) -> PILImage.Image:
"""
Get image data as PIL Image (RGB or grayscale).
Returns:
PIL Image object
"""
if self._pil_image is None:
raise ImageLoadError("PIL image not available")
return self._pil_image
@property
def width(self) -> int:
"""Get image width in pixels."""
@@ -187,6 +201,7 @@ class Image:
@property
def dtype(self) -> np.dtype:
"""Get the data type of the image array."""
if self._dtype is None:
raise ImageLoadError("Image dtype not available")
return self._dtype
@@ -206,8 +221,10 @@ class Image:
elif self._channels == 1:
if self._dtype == np.uint16:
return QImage.Format_Grayscale16
else:
elif self._dtype == np.uint8:
return QImage.Format_Grayscale8
elif self._dtype == np.float32:
return QImage.Format_BGR30
else:
raise ImageLoadError(f"Unsupported number of channels: {self._channels}")
@@ -218,6 +235,12 @@ class Image:
Returns:
Image data in RGB format as numpy array
"""
if self.channels == 1:
img = get_pseudo_rgb(self.data)
self._dtype = img.dtype
return img
raise NotImplementedError
if self._channels == 3:
return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
elif self._channels == 4:
@@ -225,6 +248,18 @@ class Image:
else:
return self._data
def get_qt_rgb(self) -> np.ascontiguousarray:
# we keep data as (C, H, W)
_img = self.get_rgb()
img = np.zeros((self.height, self.width, 4), dtype=np.float32)
img[..., 0] = _img[0] # R gradient
img[..., 1] = _img[1] # G gradient
img[..., 2] = _img[2] # B constant
img[..., 3] = 1.0 # A = 1.0 (opaque)
return np.ascontiguousarray(img)
def get_grayscale(self) -> np.ndarray:
"""
Get image as grayscale numpy array.
@@ -277,37 +312,18 @@ class Image:
"""
return self._channels >= 3
def convert_grayscale_to_rgb_preserve_range(
self,
) -> PILImage.Image:
"""Convert a single-channel PIL image to RGB while preserving dynamic range.
def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None:
Returns:
PIL Image in RGB mode with intensities normalized to 0-255.
"""
if self._channels == 3:
return self.pil_image
grayscale = self.data
if grayscale.ndim == 3:
grayscale = grayscale[:, :, 0]
original_dtype = grayscale.dtype
grayscale = grayscale.astype(np.float32)
if grayscale.size == 0:
return PILImage.new("RGB", self.shape, color=(0, 0, 0))
if np.issubdtype(original_dtype, np.integer):
denom = float(max(np.iinfo(original_dtype).max, 1))
if self.channels == 1:
if pseudo_rgb:
img = get_pseudo_rgb(self.data)
else:
max_val = float(grayscale.max())
denom = max(max_val, 1.0)
img = np.repeat(self.data, 3, axis=2)
grayscale = np.clip(grayscale / denom, 0.0, 1.0)
grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8)
rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
return PILImage.fromarray(rgb_arr, mode="RGB")
else:
raise NotImplementedError("Only grayscale images are supported for now.")
imwrite(path, data=img)
def __repr__(self) -> str:
"""String representation of the Image object."""
@@ -322,3 +338,15 @@ class Image:
def __str__(self) -> str:
"""String representation of the Image object."""
return self.__repr__()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--path", type=str, required=True)
args = parser.parse_args()
img = Image(args.path)
img.save(args.path + "test.tif")
print(img)

View File

@@ -32,8 +32,10 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
import cv2
import numpy as np
import tifffile
# import tifffile
import torch
from src.utils.image import Image
from ultralytics.utils import patches as ul_patches
@@ -55,7 +57,7 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
ext = os.path.splitext(filename)[1].lower()
if ext in (".tif", ".tiff"):
arr = tifffile.imread(filename)
arr = Image(filename).get_qt_rgb()[:, :, :3]
# Normalize common shapes:
# - (H, W) -> (H, W, 1)
@@ -71,35 +73,6 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
if arr.ndim == 2:
arr = arr[..., None]
# Ultralytics expects BGR ordering when `channels=3`.
# For grayscale data we replicate channels (no scaling, no quantization).
if flags != cv2.IMREAD_GRAYSCALE:
if arr.shape[2] == 1:
if pseudo_rgb:
gamma = 0.3
a1 = arr.copy().astype(np.float32)
a1 -= np.percentile(a1, 2)
a1[a1 < 0] = 0
p98 = np.percentile(a1, 98)
a1[a1 > p98] = p98
a1 /= a1.max()
a2 = a1.copy()
a2 = a2**gamma
a2 /= a2.max()
a3 = a1.copy()
p90 = np.percentile(a3, 90)
a3[a3 > p90] = p90
a3 /= a3.max()
arr = np.concatenate([a1, a2, a3], axis=2)
else:
arr = np.repeat(arr, 3, axis=2)
elif arr.shape[2] >= 3:
arr = arr[:, :, :3]
# Ensure contiguous array for downstream OpenCV ops.
logger.info(f"Loading with monkey-patched imread: {filename}")
return np.ascontiguousarray(arr)