Fixing pseudo rgb
@@ -250,12 +250,10 @@ class AnnotationCanvasWidget(QWidget):
         # Get image data in a format compatible with Qt
-        if self.current_image.channels in (3, 4):
-            image_data = self.current_image.get_rgb()
-            height, width = image_data.shape[:2]
-        else:
-            image_data = self.current_image.get_grayscale()
-            height, width = image_data.shape
+        image_data = self.current_image.get_qt_rgb()

         image_data = np.ascontiguousarray(image_data)
+        height, width = image_data.shape[:2]
         bytes_per_line = image_data.strides[0]

         qimage = QImage(
@@ -263,7 +261,7 @@ class AnnotationCanvasWidget(QWidget):
             width,
             height,
             bytes_per_line,
-            self.current_image.qtimage_format,
+            QImage.Format_RGBX32FPx4,  # self.current_image.qtimage_format,
         ).copy()  # Copy so Qt owns the buffer even after numpy array goes out of scope

         self.original_pixmap = QPixmap.fromImage(qimage)

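For reference, here is the same QImage-from-float-buffer pattern in isolation. This is a sketch, not code from the repository: it assumes PySide6 6.2 or newer (where the 32-bit float formats such as Format_RGBX32FPx4 exist) and pixel values already scaled to [0, 1].

    import numpy as np
    from PySide6.QtGui import QImage

    # Float RGBA buffer in [0, 1]; the fourth channel is ignored by Format_RGBX32FPx4.
    rgba = np.zeros((240, 320, 4), dtype=np.float32)
    rgba[..., 0] = np.linspace(0.0, 1.0, 320)[None, :]  # simple red gradient as test data
    rgba[..., 3] = 1.0

    rgba = np.ascontiguousarray(rgba)  # QImage expects a packed, row-contiguous buffer
    qimage = QImage(
        rgba,
        rgba.shape[1],             # width
        rgba.shape[0],             # height
        rgba.strides[0],           # bytes per line
        QImage.Format_RGBX32FPx4,  # four 32-bit floats per pixel
    ).copy()  # copy so Qt owns the memory after the numpy array is released
    print(qimage.width(), qimage.height(), qimage.format())
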
@@ -257,17 +257,11 @@ class YOLOWrapper:
         if source_path.is_file():
             try:
                 img_obj = Image(source_path)
-                pil_img = img_obj.pil_image
-                if len(pil_img.getbands()) == 1:
-                    rgb_img = img_obj.convert_grayscale_to_rgb_preserve_range()
-                else:
-                    rgb_img = pil_img.convert("RGB")
-
                 suffix = source_path.suffix or ".png"
                 tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
                 tmp_path = tmp.name
                 tmp.close()
-                rgb_img.save(tmp_path)
+                img_obj.save(tmp_path)
                 cleanup_path = tmp_path
                 logger.info(
                     f"Converted image {source_path} to RGB for inference at {tmp_path}"

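The convert-to-temp-file flow above, shown as a generic sketch. `convert` and `infer` are illustrative placeholders rather than functions from this repository; the point is the delete=False / close / cleanup sequence around the converted copy.

    import os
    import tempfile
    from pathlib import Path

    def run_on_converted_copy(source_path: Path, convert, infer):
        """convert(src, dst) writes an RGB copy of src to dst; infer(path) runs the model."""
        cleanup_path = None
        try:
            suffix = source_path.suffix or ".png"
            tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
            tmp_path = tmp.name
            tmp.close()  # close the handle so the converter can reopen the file (needed on Windows)
            convert(source_path, tmp_path)  # e.g. Image(source_path).save(tmp_path)
            cleanup_path = tmp_path
            return infer(tmp_path)
        finally:
            if cleanup_path is not None and os.path.exists(cleanup_path):
                os.unlink(cleanup_path)  # remove the temporary copy once inference is done
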
@@ -6,16 +6,49 @@ import cv2
 import numpy as np
 from pathlib import Path
 from typing import Optional, Tuple, Union
 from PIL import Image as PILImage

 from src.utils.logger import get_logger
 from src.utils.file_utils import validate_file_path, is_image_file

+from PySide6.QtGui import QImage
+
+from tifffile import imread, imwrite
+
 logger = get_logger(__name__)


+def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.3) -> np.ndarray:
+    """
+    Convert a grayscale image to a pseudo-RGB image using gamma correction.
+
+    Args:
+        arr: Input grayscale image as numpy array
+        gamma: Gamma exponent applied to the second (green) channel
+
+    Returns:
+        Pseudo-RGB image as a float32 numpy array of shape (3, H, W)
+    """
+    if arr.ndim != 2:
+        raise ValueError("Input array must be a grayscale image with shape (H, W)")
+
+    a1 = arr.copy().astype(np.float32)
+    a1 -= np.percentile(a1, 2)
+    a1[a1 < 0] = 0
+    # p999 = np.percentile(a1, 99.9)
+    # a1[a1 > p999] = p999
+    a1 /= a1.max()
+
+    a2 = a1.copy()
+    a2 = a2**gamma
+    a2 /= a2.max()
+
+    a3 = a1.copy()
+    p9999 = np.percentile(a3, 99.99)
+    a3[a3 > p9999] = p9999
+    a3 /= a3.max()
+
+    return np.stack([a1, a2, a3], axis=0)
+
+
 class ImageLoadError(Exception):
     """Exception raised when an image cannot be loaded."""

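A usage sketch for get_pseudo_rgb (the file name is a placeholder). Note that the function requires a strictly 2-D input and a non-zero dynamic range, since each channel is divided by its own maximum.

    import numpy as np
    from tifffile import imread

    from src.utils.image import get_pseudo_rgb  # added in the hunk above

    gray = imread("example_16bit.tif")        # (H, W) uint16 frame
    pseudo = get_pseudo_rgb(gray, gamma=0.3)  # (3, H, W) float32, values in [0, 1]

    # Channel-last 8-bit view for quick display or PNG export
    rgb_u8 = (np.transpose(pseudo, (1, 2, 0)) * 255).round().astype(np.uint8)
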
@@ -54,7 +87,6 @@ class Image:
         """
         self.path = Path(image_path)
         self._data: Optional[np.ndarray] = None
-        self._pil_image: Optional[PILImage.Image] = None
         self._width: int = 0
         self._height: int = 0
         self._channels: int = 0

@@ -85,6 +117,10 @@ class Image:
             )

         try:
+            if self.path.suffix.lower() in [".tif", ".tiff"]:
+                self._data = imread(str(self.path))
+            else:
+                raise NotImplementedError
             # Load with OpenCV (returns BGR format)
             self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)

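A side-by-side sketch of the two readers touched above (the file name is a placeholder): both preserve 16-bit depth, but OpenCV returns colour TIFFs in BGR channel order while tifffile returns the array as stored in the file.

    import cv2
    from tifffile import imread

    a = imread("frame_16bit.tif")                            # uint16, axes as stored in the file
    b = cv2.imread("frame_16bit.tif", cv2.IMREAD_UNCHANGED)  # uint16 preserved; colour arrives as BGR
    print(a.dtype, a.shape, b.dtype, b.shape)
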
@@ -92,23 +128,13 @@ class Image:
                 raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")

             # Extract metadata
             print(self._data.shape)
             self._height, self._width = self._data.shape[:2]
             self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
             self._format = self.path.suffix.lower().lstrip(".")
             self._size_bytes = self.path.stat().st_size
             self._dtype = self._data.dtype

-            # Load PIL version for compatibility (convert BGR to RGB)
-            if self._channels == 3:
-                rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
-                self._pil_image = PILImage.fromarray(rgb_data)
-            elif self._channels == 4:
-                rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
-                self._pil_image = PILImage.fromarray(rgba_data)
-            else:
-                # Grayscale
-                self._pil_image = PILImage.fromarray(self._data)
-
             logger.info(
                 f"Successfully loaded image: {self.path.name} "
                 f"({self._width}x{self._height}, {self._channels} channels, "

@@ -131,18 +157,6 @@ class Image:
             raise ImageLoadError("Image data not available")
         return self._data

-    @property
-    def pil_image(self) -> PILImage.Image:
-        """
-        Get image data as PIL Image (RGB or grayscale).
-
-        Returns:
-            PIL Image object
-        """
-        if self._pil_image is None:
-            raise ImageLoadError("PIL image not available")
-        return self._pil_image
-
     @property
     def width(self) -> int:
         """Get image width in pixels."""

@@ -187,6 +201,7 @@ class Image:
     @property
     def dtype(self) -> np.dtype:
         """Get the data type of the image array."""
+
         if self._dtype is None:
             raise ImageLoadError("Image dtype not available")
         return self._dtype

@@ -206,8 +221,10 @@ class Image:
         elif self._channels == 1:
             if self._dtype == np.uint16:
                 return QImage.Format_Grayscale16
-            else:
+            elif self._dtype == np.uint8:
                 return QImage.Format_Grayscale8
+            elif self._dtype == np.float32:
+                return QImage.Format_BGR30
         else:
             raise ImageLoadError(f"Unsupported number of channels: {self._channels}")

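The same dtype dispatch can also be written as a lookup table; this is only a sketch of the idea, not repository code. Note that Format_BGR30 is a packed 10-bit-per-channel integer format rather than a float format.

    import numpy as np
    from PySide6.QtGui import QImage

    _GRAYSCALE_QT_FORMATS = {
        np.dtype(np.uint16): QImage.Format_Grayscale16,
        np.dtype(np.uint8): QImage.Format_Grayscale8,
    }

    def grayscale_qt_format(dtype) -> QImage.Format:
        """Map a numpy dtype to a Qt grayscale image format, or fail loudly."""
        try:
            return _GRAYSCALE_QT_FORMATS[np.dtype(dtype)]
        except KeyError:
            raise ValueError(f"No QImage format mapped for dtype {dtype}") from None
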
@@ -218,6 +235,12 @@ class Image:
         Returns:
             Image data in RGB format as numpy array
         """
+        if self.channels == 1:
+            img = get_pseudo_rgb(self.data)
+            self._dtype = img.dtype
+            return img
+        raise NotImplementedError
+
         if self._channels == 3:
             return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
         elif self._channels == 4:

@@ -225,6 +248,18 @@ class Image:
         else:
             return self._data

+    def get_qt_rgb(self) -> np.ndarray:
+        # get_rgb() returns channel-first (C, H, W) data for grayscale sources
+        _img = self.get_rgb()
+
+        img = np.zeros((self.height, self.width, 4), dtype=np.float32)
+        img[..., 0] = _img[0]  # R channel
+        img[..., 1] = _img[1]  # G channel
+        img[..., 2] = _img[2]  # B channel
+        img[..., 3] = 1.0  # A = 1.0 (opaque)
+
+        return np.ascontiguousarray(img)
+
     def get_grayscale(self) -> np.ndarray:
         """
         Get image as grayscale numpy array.

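The per-channel copies in get_qt_rgb can also be expressed with a single transpose. A small sketch under the assumption of a (3, H, W) float32 input; a random stand-in array replaces the real image data.

    import numpy as np

    pseudo = np.random.rand(3, 480, 640).astype(np.float32)  # stand-in for get_rgb() output

    rgba = np.empty((pseudo.shape[1], pseudo.shape[2], 4), dtype=np.float32)
    rgba[..., :3] = np.transpose(pseudo, (1, 2, 0))  # (3, H, W) -> (H, W, 3)
    rgba[..., 3] = 1.0                               # opaque alpha
    rgba = np.ascontiguousarray(rgba)                # packed layout, ready for QImage
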
@@ -277,37 +312,18 @@ class Image:
         """
         return self._channels >= 3

-    def convert_grayscale_to_rgb_preserve_range(
-        self,
-    ) -> PILImage.Image:
-        """Convert a single-channel PIL image to RGB while preserving dynamic range.
-
-        Returns:
-            PIL Image in RGB mode with intensities normalized to 0-255.
-        """
-        if self._channels == 3:
-            return self.pil_image
-
-        grayscale = self.data
-        if grayscale.ndim == 3:
-            grayscale = grayscale[:, :, 0]
-
-        original_dtype = grayscale.dtype
-        grayscale = grayscale.astype(np.float32)
-
-        if grayscale.size == 0:
-            return PILImage.new("RGB", self.shape, color=(0, 0, 0))
-
-        if np.issubdtype(original_dtype, np.integer):
-            denom = float(max(np.iinfo(original_dtype).max, 1))
-        else:
-            max_val = float(grayscale.max())
-            denom = max(max_val, 1.0)
-
-        grayscale = np.clip(grayscale / denom, 0.0, 1.0)
-        grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8)
-        rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
-        return PILImage.fromarray(rgb_arr, mode="RGB")
+    def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None:
+        if self.channels == 1:
+            if pseudo_rgb:
+                img = get_pseudo_rgb(self.data)
+            else:
+                img = np.repeat(self.data, 3, axis=2)
+        else:
+            raise NotImplementedError("Only grayscale images are supported for now.")
+
+        imwrite(path, data=img)

     def __repr__(self) -> str:
         """String representation of the Image object."""

@@ -322,3 +338,15 @@ class Image:
     def __str__(self) -> str:
         """String representation of the Image object."""
         return self.__repr__()
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--path", type=str, required=True)
+    args = parser.parse_args()
+
+    img = Image(args.path)
+    img.save(args.path + "test.tif")
+    print(img)

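A round-trip sketch for the new save() path (file names are placeholders): write the pseudo-RGB TIFF and read it back to check what ends up on disk.

    from tifffile import imread

    from src.utils.image import Image

    img = Image("sample_16bit.tif")
    img.save("sample_pseudo_rgb.tif", pseudo_rgb=True)

    saved = imread("sample_pseudo_rgb.tif")
    print(saved.dtype, saved.shape)  # expected: float32 with shape (3, H, W), as produced by get_pseudo_rgb
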
@@ -32,8 +32,10 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:

    import cv2
    import numpy as np
-    import tifffile

+    # import tifffile
    import torch
+    from src.utils.image import Image

    from ultralytics.utils import patches as ul_patches

@@ -55,7 +57,7 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:

         ext = os.path.splitext(filename)[1].lower()
         if ext in (".tif", ".tiff"):
-            arr = tifffile.imread(filename)
+            arr = Image(filename).get_qt_rgb()[:, :, :3]

             # Normalize common shapes:
             # - (H, W) -> (H, W, 1)

@@ -71,35 +73,6 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
         if arr.ndim == 2:
             arr = arr[..., None]

-        # Ultralytics expects BGR ordering when `channels=3`.
-        # For grayscale data we replicate channels (no scaling, no quantization).
-        if flags != cv2.IMREAD_GRAYSCALE:
-            if arr.shape[2] == 1:
-                if pseudo_rgb:
-                    gamma = 0.3
-                    a1 = arr.copy().astype(np.float32)
-                    a1 -= np.percentile(a1, 2)
-                    a1[a1 < 0] = 0
-                    p98 = np.percentile(a1, 98)
-                    a1[a1 > p98] = p98
-                    a1 /= a1.max()
-
-                    a2 = a1.copy()
-                    a2 = a2**gamma
-                    a2 /= a2.max()
-
-                    a3 = a1.copy()
-                    p90 = np.percentile(a3, 90)
-                    a3[a3 > p90] = p90
-                    a3 /= a3.max()
-
-                    arr = np.concatenate([a1, a2, a3], axis=2)
-
-                else:
-                    arr = np.repeat(arr, 3, axis=2)
-            elif arr.shape[2] >= 3:
-                arr = arr[:, :, :3]
-
         # Ensure contiguous array for downstream OpenCV ops.
         logger.info(f"Loading with monkey-patched imread: {filename}")
         return np.ascontiguousarray(arr)

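The diff shows the replacement reader but not the line that installs it. As a rough sketch of how such a monkey patch is usually wired up (the attribute name ul_patches.imread is the assumed target; adapt to the project's actual hook):

    from ultralytics.utils import patches as ul_patches

    _original_imread = ul_patches.imread  # keep a handle so the patch can be reverted

    def _patched_imread(filename, flags=None):
        # Delegate to a TIFF-aware loader for .tif/.tiff (as in the hunk above)
        # and fall back to the stock reader otherwise; this stub only forwards.
        if flags is None:
            return _original_imread(filename)
        return _original_imread(filename, flags)

    ul_patches.imread = _patched_imread    # install the monkey patch
    # ... run training / inference ...
    ul_patches.imread = _original_imread   # restore when finished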