From 061f8b3ca29049620f90ebef6a680e90c37d60d1 Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Fri, 19 Dec 2025 09:56:43 +0200 Subject: [PATCH] Fixing pseudo rgb --- src/gui/widgets/annotation_canvas_widget.py | 8 +- src/model/yolo_wrapper.py | 8 +- src/utils/image.py | 138 ++++++++++++-------- src/utils/ultralytics_16bit_patch.py | 35 +---- 4 files changed, 91 insertions(+), 98 deletions(-) diff --git a/src/gui/widgets/annotation_canvas_widget.py b/src/gui/widgets/annotation_canvas_widget.py index 94e03cd..f67118d 100644 --- a/src/gui/widgets/annotation_canvas_widget.py +++ b/src/gui/widgets/annotation_canvas_widget.py @@ -250,12 +250,10 @@ class AnnotationCanvasWidget(QWidget): # Get image data in a format compatible with Qt if self.current_image.channels in (3, 4): image_data = self.current_image.get_rgb() - height, width = image_data.shape[:2] else: - image_data = self.current_image.get_grayscale() - height, width = image_data.shape + image_data = self.current_image.get_qt_rgb() - image_data = np.ascontiguousarray(image_data) + height, width = image_data.shape[:2] bytes_per_line = image_data.strides[0] qimage = QImage( @@ -263,7 +261,7 @@ class AnnotationCanvasWidget(QWidget): width, height, bytes_per_line, - self.current_image.qtimage_format, + QImage.Format_RGBX32FPx4, # self.current_image.qtimage_format, ).copy() # Copy so Qt owns the buffer even after numpy array goes out of scope self.original_pixmap = QPixmap.fromImage(qimage) diff --git a/src/model/yolo_wrapper.py b/src/model/yolo_wrapper.py index 21b1931..c02c78e 100644 --- a/src/model/yolo_wrapper.py +++ b/src/model/yolo_wrapper.py @@ -257,17 +257,11 @@ class YOLOWrapper: if source_path.is_file(): try: img_obj = Image(source_path) - pil_img = img_obj.pil_image - if len(pil_img.getbands()) == 1: - rgb_img = img_obj.convert_grayscale_to_rgb_preserve_range() - else: - rgb_img = pil_img.convert("RGB") - suffix = source_path.suffix or ".png" tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) tmp_path = tmp.name tmp.close() - rgb_img.save(tmp_path) + img_obj.save(tmp_path) cleanup_path = tmp_path logger.info( f"Converted image {source_path} to RGB for inference at {tmp_path}" diff --git a/src/utils/image.py b/src/utils/image.py index 578ccde..c51590f 100644 --- a/src/utils/image.py +++ b/src/utils/image.py @@ -6,16 +6,49 @@ import cv2 import numpy as np from pathlib import Path from typing import Optional, Tuple, Union -from PIL import Image as PILImage from src.utils.logger import get_logger from src.utils.file_utils import validate_file_path, is_image_file from PySide6.QtGui import QImage +from tifffile import imread, imwrite + logger = get_logger(__name__) +def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.3) -> np.ndarray: + """ + Convert a grayscale image to a pseudo-RGB image using a gamma correction. + + Args: + arr: Input grayscale image as numpy array + + Returns: + Pseudo-RGB image as numpy array + """ + if arr.ndim != 2: + raise ValueError("Input array must be a grayscale image with shape (H, W)") + + a1 = arr.copy().astype(np.float32) + a1 -= np.percentile(a1, 2) + a1[a1 < 0] = 0 + # p999 = np.percentile(a1, 99.9) + # a1[a1 > p999] = p999 + a1 /= a1.max() + + a2 = a1.copy() + a2 = a2**gamma + a2 /= a2.max() + + a3 = a1.copy() + p9999 = np.percentile(a3, 99.99) + a3[a3 > p9999] = p9999 + a3 /= a3.max() + + return np.stack([a1, a2, a3], axis=0) + + class ImageLoadError(Exception): """Exception raised when an image cannot be loaded.""" @@ -54,7 +87,6 @@ class Image: """ self.path = Path(image_path) self._data: Optional[np.ndarray] = None - self._pil_image: Optional[PILImage.Image] = None self._width: int = 0 self._height: int = 0 self._channels: int = 0 @@ -85,30 +117,24 @@ class Image: ) try: - # Load with OpenCV (returns BGR format) - self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED) + if self.path.suffix.lower() in [".tif", ".tiff"]: + self._data = imread(str(self.path)) + else: + raise NotImplementedError + # Load with OpenCV (returns BGR format) + self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED) if self._data is None: raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}") # Extract metadata + print(self._data.shape) self._height, self._width = self._data.shape[:2] self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1 self._format = self.path.suffix.lower().lstrip(".") self._size_bytes = self.path.stat().st_size self._dtype = self._data.dtype - # Load PIL version for compatibility (convert BGR to RGB) - if self._channels == 3: - rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB) - self._pil_image = PILImage.fromarray(rgb_data) - elif self._channels == 4: - rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA) - self._pil_image = PILImage.fromarray(rgba_data) - else: - # Grayscale - self._pil_image = PILImage.fromarray(self._data) - logger.info( f"Successfully loaded image: {self.path.name} " f"({self._width}x{self._height}, {self._channels} channels, " @@ -131,18 +157,6 @@ class Image: raise ImageLoadError("Image data not available") return self._data - @property - def pil_image(self) -> PILImage.Image: - """ - Get image data as PIL Image (RGB or grayscale). - - Returns: - PIL Image object - """ - if self._pil_image is None: - raise ImageLoadError("PIL image not available") - return self._pil_image - @property def width(self) -> int: """Get image width in pixels.""" @@ -187,6 +201,7 @@ class Image: @property def dtype(self) -> np.dtype: """Get the data type of the image array.""" + if self._dtype is None: raise ImageLoadError("Image dtype not available") return self._dtype @@ -206,8 +221,10 @@ class Image: elif self._channels == 1: if self._dtype == np.uint16: return QImage.Format_Grayscale16 - else: + elif self._dtype == np.uint8: return QImage.Format_Grayscale8 + elif self._dtype == np.float32: + return QImage.Format_BGR30 else: raise ImageLoadError(f"Unsupported number of channels: {self._channels}") @@ -218,6 +235,12 @@ class Image: Returns: Image data in RGB format as numpy array """ + if self.channels == 1: + img = get_pseudo_rgb(self.data) + self._dtype = img.dtype + return img + raise NotImplementedError + if self._channels == 3: return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB) elif self._channels == 4: @@ -225,6 +248,18 @@ class Image: else: return self._data + def get_qt_rgb(self) -> np.ascontiguousarray: + # we keep data as (C, H, W) + _img = self.get_rgb() + + img = np.zeros((self.height, self.width, 4), dtype=np.float32) + img[..., 0] = _img[0] # R gradient + img[..., 1] = _img[1] # G gradient + img[..., 2] = _img[2] # B constant + img[..., 3] = 1.0 # A = 1.0 (opaque) + + return np.ascontiguousarray(img) + def get_grayscale(self) -> np.ndarray: """ Get image as grayscale numpy array. @@ -277,37 +312,18 @@ class Image: """ return self._channels >= 3 - def convert_grayscale_to_rgb_preserve_range( - self, - ) -> PILImage.Image: - """Convert a single-channel PIL image to RGB while preserving dynamic range. + def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None: - Returns: - PIL Image in RGB mode with intensities normalized to 0-255. - """ - if self._channels == 3: - return self.pil_image + if self.channels == 1: + if pseudo_rgb: + img = get_pseudo_rgb(self.data) + else: + img = np.repeat(self.data, 3, axis=2) - grayscale = self.data - if grayscale.ndim == 3: - grayscale = grayscale[:, :, 0] - - original_dtype = grayscale.dtype - grayscale = grayscale.astype(np.float32) - - if grayscale.size == 0: - return PILImage.new("RGB", self.shape, color=(0, 0, 0)) - - if np.issubdtype(original_dtype, np.integer): - denom = float(max(np.iinfo(original_dtype).max, 1)) else: - max_val = float(grayscale.max()) - denom = max(max_val, 1.0) + raise NotImplementedError("Only grayscale images are supported for now.") - grayscale = np.clip(grayscale / denom, 0.0, 1.0) - grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8) - rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2) - return PILImage.fromarray(rgb_arr, mode="RGB") + imwrite(path, data=img) def __repr__(self) -> str: """String representation of the Image object.""" @@ -322,3 +338,15 @@ class Image: def __str__(self) -> str: """String representation of the Image object.""" return self.__repr__() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--path", type=str, required=True) + args = parser.parse_args() + + img = Image(args.path) + img.save(args.path + "test.tif") + print(img) diff --git a/src/utils/ultralytics_16bit_patch.py b/src/utils/ultralytics_16bit_patch.py index dc3706f..c976bc1 100644 --- a/src/utils/ultralytics_16bit_patch.py +++ b/src/utils/ultralytics_16bit_patch.py @@ -32,8 +32,10 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None: import cv2 import numpy as np - import tifffile + + # import tifffile import torch + from src.utils.image import Image from ultralytics.utils import patches as ul_patches @@ -55,7 +57,7 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None: ext = os.path.splitext(filename)[1].lower() if ext in (".tif", ".tiff"): - arr = tifffile.imread(filename) + arr = Image(filename).get_qt_rgb()[:, :, :3] # Normalize common shapes: # - (H, W) -> (H, W, 1) @@ -71,35 +73,6 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None: if arr.ndim == 2: arr = arr[..., None] - # Ultralytics expects BGR ordering when `channels=3`. - # For grayscale data we replicate channels (no scaling, no quantization). - if flags != cv2.IMREAD_GRAYSCALE: - if arr.shape[2] == 1: - if pseudo_rgb: - gamma = 0.3 - a1 = arr.copy().astype(np.float32) - a1 -= np.percentile(a1, 2) - a1[a1 < 0] = 0 - p98 = np.percentile(a1, 98) - a1[a1 > p98] = p98 - a1 /= a1.max() - - a2 = a1.copy() - a2 = a2**gamma - a2 /= a2.max() - - a3 = a1.copy() - p90 = np.percentile(a3, 90) - a3[a3 > p90] = p90 - a3 /= a3.max() - - arr = np.concatenate([a1, a2, a3], axis=2) - - else: - arr = np.repeat(arr, 3, axis=2) - elif arr.shape[2] >= 3: - arr = arr[:, :, :3] - # Ensure contiguous array for downstream OpenCV ops. logger.info(f"Loading with monkey-patched imread: {filename}") return np.ascontiguousarray(arr)