Fixing pseudo rgb

2025-12-19 09:56:43 +02:00
parent a8e5db3135
commit 061f8b3ca2
4 changed files with 91 additions and 98 deletions

View File

@@ -250,12 +250,10 @@ class AnnotationCanvasWidget(QWidget):
         # Get image data in a format compatible with Qt
         if self.current_image.channels in (3, 4):
             image_data = self.current_image.get_rgb()
-            height, width = image_data.shape[:2]
         else:
-            image_data = self.current_image.get_grayscale()
-            height, width = image_data.shape
-        image_data = np.ascontiguousarray(image_data)
+            image_data = self.current_image.get_qt_rgb()
+        height, width = image_data.shape[:2]
         bytes_per_line = image_data.strides[0]
         qimage = QImage(
@@ -263,7 +261,7 @@ class AnnotationCanvasWidget(QWidget):
             width,
             height,
             bytes_per_line,
-            self.current_image.qtimage_format,
+            QImage.Format_RGBX32FPx4,  # self.current_image.qtimage_format,
         ).copy()  # Copy so Qt owns the buffer even after numpy array goes out of scope
         self.original_pixmap = QPixmap.fromImage(qimage)
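
Note on the format switch above: QImage.Format_RGBX32FPx4 stores each pixel as four 32-bit floats (R, G, B, X), so the buffer handed to QImage must be a contiguous float32 (H, W, 4) array, which is what get_qt_rgb now returns. A minimal standalone sketch with random data (not the widget's actual code):

import numpy as np
from PySide6.QtGui import QGuiApplication, QImage, QPixmap

app = QGuiApplication([])  # QPixmap needs a running Qt application

# Contiguous float32 (H, W, 4) buffer with values in [0, 1]; the 4th channel is ignored by RGBX.
rgba = np.ascontiguousarray(np.random.rand(480, 640, 4).astype(np.float32))
rgba[..., 3] = 1.0
height, width = rgba.shape[:2]
qimage = QImage(
    rgba.data,
    width,
    height,
    rgba.strides[0],
    QImage.Format_RGBX32FPx4,
).copy()  # copy so Qt owns the buffer after the numpy array goes away
pixmap = QPixmap.fromImage(qimage)
print(pixmap.size())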

View File

@@ -257,17 +257,11 @@ class YOLOWrapper:
         if source_path.is_file():
             try:
                 img_obj = Image(source_path)
-                pil_img = img_obj.pil_image
-                if len(pil_img.getbands()) == 1:
-                    rgb_img = img_obj.convert_grayscale_to_rgb_preserve_range()
-                else:
-                    rgb_img = pil_img.convert("RGB")
                 suffix = source_path.suffix or ".png"
                 tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
                 tmp_path = tmp.name
                 tmp.close()
-                rgb_img.save(tmp_path)
+                img_obj.save(tmp_path)
                 cleanup_path = tmp_path
                 logger.info(
                     f"Converted image {source_path} to RGB for inference at {tmp_path}"

View File

@@ -6,16 +6,49 @@ import cv2
 import numpy as np
 from pathlib import Path
 from typing import Optional, Tuple, Union
-from PIL import Image as PILImage
 from src.utils.logger import get_logger
 from src.utils.file_utils import validate_file_path, is_image_file
 from PySide6.QtGui import QImage
+from tifffile import imread, imwrite

 logger = get_logger(__name__)


+def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.3) -> np.ndarray:
+    """
+    Convert a grayscale image to a pseudo-RGB image using a gamma correction.
+
+    Args:
+        arr: Input grayscale image as numpy array
+
+    Returns:
+        Pseudo-RGB image as numpy array
+    """
+    if arr.ndim != 2:
+        raise ValueError("Input array must be a grayscale image with shape (H, W)")
+    a1 = arr.copy().astype(np.float32)
+    a1 -= np.percentile(a1, 2)
+    a1[a1 < 0] = 0
+    # p999 = np.percentile(a1, 99.9)
+    # a1[a1 > p999] = p999
+    a1 /= a1.max()
+    a2 = a1.copy()
+    a2 = a2**gamma
+    a2 /= a2.max()
+    a3 = a1.copy()
+    p9999 = np.percentile(a3, 99.99)
+    a3[a3 > p9999] = p9999
+    a3 /= a3.max()
+    return np.stack([a1, a2, a3], axis=0)
+
+
 class ImageLoadError(Exception):
     """Exception raised when an image cannot be loaded."""
@@ -54,7 +87,6 @@ class Image:
""" """
self.path = Path(image_path) self.path = Path(image_path)
self._data: Optional[np.ndarray] = None self._data: Optional[np.ndarray] = None
self._pil_image: Optional[PILImage.Image] = None
self._width: int = 0 self._width: int = 0
self._height: int = 0 self._height: int = 0
self._channels: int = 0 self._channels: int = 0
@@ -85,6 +117,10 @@ class Image:
             )

         try:
+            if self.path.suffix.lower() in [".tif", ".tiff"]:
+                self._data = imread(str(self.path))
+            else:
+                raise NotImplementedError
             # Load with OpenCV (returns BGR format)
             self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
@@ -92,23 +128,13 @@ class Image:
raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}") raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")
# Extract metadata # Extract metadata
print(self._data.shape)
self._height, self._width = self._data.shape[:2] self._height, self._width = self._data.shape[:2]
self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1 self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
self._format = self.path.suffix.lower().lstrip(".") self._format = self.path.suffix.lower().lstrip(".")
self._size_bytes = self.path.stat().st_size self._size_bytes = self.path.stat().st_size
self._dtype = self._data.dtype self._dtype = self._data.dtype
# Load PIL version for compatibility (convert BGR to RGB)
if self._channels == 3:
rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
self._pil_image = PILImage.fromarray(rgb_data)
elif self._channels == 4:
rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
self._pil_image = PILImage.fromarray(rgba_data)
else:
# Grayscale
self._pil_image = PILImage.fromarray(self._data)
logger.info( logger.info(
f"Successfully loaded image: {self.path.name} " f"Successfully loaded image: {self.path.name} "
f"({self._width}x{self._height}, {self._channels} channels, " f"({self._width}x{self._height}, {self._channels} channels, "
@@ -131,18 +157,6 @@ class Image:
raise ImageLoadError("Image data not available") raise ImageLoadError("Image data not available")
return self._data return self._data
@property
def pil_image(self) -> PILImage.Image:
"""
Get image data as PIL Image (RGB or grayscale).
Returns:
PIL Image object
"""
if self._pil_image is None:
raise ImageLoadError("PIL image not available")
return self._pil_image
@property @property
def width(self) -> int: def width(self) -> int:
"""Get image width in pixels.""" """Get image width in pixels."""
@@ -187,6 +201,7 @@ class Image:
     @property
     def dtype(self) -> np.dtype:
         """Get the data type of the image array."""
         if self._dtype is None:
             raise ImageLoadError("Image dtype not available")
         return self._dtype
@@ -206,8 +221,10 @@ class Image:
         elif self._channels == 1:
             if self._dtype == np.uint16:
                 return QImage.Format_Grayscale16
-            else:
+            elif self._dtype == np.uint8:
                 return QImage.Format_Grayscale8
+            elif self._dtype == np.float32:
+                return QImage.Format_BGR30
         else:
             raise ImageLoadError(f"Unsupported number of channels: {self._channels}")
@@ -218,6 +235,12 @@ class Image:
         Returns:
             Image data in RGB format as numpy array
         """
+        if self.channels == 1:
+            img = get_pseudo_rgb(self.data)
+            self._dtype = img.dtype
+            return img
+        raise NotImplementedError
+
         if self._channels == 3:
             return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
         elif self._channels == 4:
@@ -225,6 +248,18 @@ class Image:
         else:
             return self._data

+    def get_qt_rgb(self) -> np.ascontiguousarray:
+        # we keep data as (C, H, W)
+        _img = self.get_rgb()
+        img = np.zeros((self.height, self.width, 4), dtype=np.float32)
+        img[..., 0] = _img[0]  # R gradient
+        img[..., 1] = _img[1]  # G gradient
+        img[..., 2] = _img[2]  # B constant
+        img[..., 3] = 1.0  # A = 1.0 (opaque)
+        return np.ascontiguousarray(img)
+
     def get_grayscale(self) -> np.ndarray:
         """
         Get image as grayscale numpy array.
@@ -277,37 +312,18 @@ class Image:
""" """
return self._channels >= 3 return self._channels >= 3
def convert_grayscale_to_rgb_preserve_range( def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None:
self,
) -> PILImage.Image:
"""Convert a single-channel PIL image to RGB while preserving dynamic range.
Returns: if self.channels == 1:
PIL Image in RGB mode with intensities normalized to 0-255. if pseudo_rgb:
""" img = get_pseudo_rgb(self.data)
if self._channels == 3:
return self.pil_image
grayscale = self.data
if grayscale.ndim == 3:
grayscale = grayscale[:, :, 0]
original_dtype = grayscale.dtype
grayscale = grayscale.astype(np.float32)
if grayscale.size == 0:
return PILImage.new("RGB", self.shape, color=(0, 0, 0))
if np.issubdtype(original_dtype, np.integer):
denom = float(max(np.iinfo(original_dtype).max, 1))
else: else:
max_val = float(grayscale.max()) img = np.repeat(self.data, 3, axis=2)
denom = max(max_val, 1.0)
grayscale = np.clip(grayscale / denom, 0.0, 1.0) else:
grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8) raise NotImplementedError("Only grayscale images are supported for now.")
rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
return PILImage.fromarray(rgb_arr, mode="RGB") imwrite(path, data=img)
def __repr__(self) -> str: def __repr__(self) -> str:
"""String representation of the Image object.""" """String representation of the Image object."""
@@ -322,3 +338,15 @@ class Image:
     def __str__(self) -> str:
         """String representation of the Image object."""
         return self.__repr__()
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--path", type=str, required=True)
+    args = parser.parse_args()
+    img = Image(args.path)
+    img.save(args.path + "test.tif")
+    print(img)
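
For reference, get_pseudo_rgb returns a channel-first (3, H, W) float32 stack in [0, 1]: a linearly stretched copy, a gamma-boosted copy (gamma 0.3), and a 99.99th-percentile-clipped copy. A small sketch (hypothetical helper, not part of the commit) for turning that output into an 8-bit (H, W, 3) preview for quick visual checks with OpenCV:

import cv2
import numpy as np

def to_uint8_preview(pseudo: np.ndarray) -> np.ndarray:
    """Convert a (3, H, W) float32 array in [0, 1] to an (H, W, 3) uint8 image."""
    rgb = np.transpose(pseudo, (1, 2, 0))  # channel-first -> channel-last
    return (np.clip(rgb, 0.0, 1.0) * 255.0).round().astype(np.uint8)

pseudo = np.random.rand(3, 64, 64).astype(np.float32)  # stand-in for get_pseudo_rgb output
preview = to_uint8_preview(pseudo)
cv2.imwrite("preview.png", preview[:, :, ::-1])  # OpenCV expects BGR channel order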

View File

@@ -32,8 +32,10 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
     import cv2
     import numpy as np
-    import tifffile
+
+    # import tifffile
     import torch
+    from src.utils.image import Image
     from ultralytics.utils import patches as ul_patches
@@ -55,7 +57,7 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
         ext = os.path.splitext(filename)[1].lower()
         if ext in (".tif", ".tiff"):
-            arr = tifffile.imread(filename)
+            arr = Image(filename).get_qt_rgb()[:, :, :3]

             # Normalize common shapes:
             # - (H, W) -> (H, W, 1)
@@ -71,35 +73,6 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
             if arr.ndim == 2:
                 arr = arr[..., None]

-            # Ultralytics expects BGR ordering when `channels=3`.
-            # For grayscale data we replicate channels (no scaling, no quantization).
-            if flags != cv2.IMREAD_GRAYSCALE:
-                if arr.shape[2] == 1:
-                    if pseudo_rgb:
-                        gamma = 0.3
-                        a1 = arr.copy().astype(np.float32)
-                        a1 -= np.percentile(a1, 2)
-                        a1[a1 < 0] = 0
-                        p98 = np.percentile(a1, 98)
-                        a1[a1 > p98] = p98
-                        a1 /= a1.max()
-                        a2 = a1.copy()
-                        a2 = a2**gamma
-                        a2 /= a2.max()
-                        a3 = a1.copy()
-                        p90 = np.percentile(a3, 90)
-                        a3[a3 > p90] = p90
-                        a3 /= a3.max()
-                        arr = np.concatenate([a1, a2, a3], axis=2)
-                    else:
-                        arr = np.repeat(arr, 3, axis=2)
-                elif arr.shape[2] >= 3:
-                    arr = arr[:, :, :3]
-
             # Ensure contiguous array for downstream OpenCV ops.
             logger.info(f"Loading with monkey-patched imread: {filename}")
             return np.ascontiguousarray(arr)
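
The patch above follows the usual monkey-patch pattern: keep a reference to the library's reader, install a TIFF-aware wrapper, and fall back to the original for everything else. A rough standalone sketch under that assumption (the real patch routes TIFFs through Image(filename).get_qt_rgb()[:, :, :3]; here a plain cv2 read stands in so the snippet is self-contained):

import cv2
import numpy as np
from ultralytics.utils import patches as ul_patches

_original_imread = ul_patches.imread  # keep the unpatched reader

def _tiff_aware_imread(filename, flags=cv2.IMREAD_COLOR):
    if str(filename).lower().endswith((".tif", ".tiff")):
        arr = cv2.imread(str(filename), cv2.IMREAD_UNCHANGED)  # stand-in for Image(...).get_qt_rgb()[:, :, :3]
        if arr.ndim == 2:
            arr = np.repeat(arr[..., None], 3, axis=2)
        return np.ascontiguousarray(arr)
    return _original_imread(filename, flags)

ul_patches.imread = _tiff_aware_imread  # roughly what apply_ultralytics_16bit_tiff_patches installs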