Fixing pseudo rgb
This commit is contained in:
@@ -250,12 +250,10 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
# Get image data in a format compatible with Qt
|
# Get image data in a format compatible with Qt
|
||||||
if self.current_image.channels in (3, 4):
|
if self.current_image.channels in (3, 4):
|
||||||
image_data = self.current_image.get_rgb()
|
image_data = self.current_image.get_rgb()
|
||||||
height, width = image_data.shape[:2]
|
|
||||||
else:
|
else:
|
||||||
image_data = self.current_image.get_grayscale()
|
image_data = self.current_image.get_qt_rgb()
|
||||||
height, width = image_data.shape
|
|
||||||
|
|
||||||
image_data = np.ascontiguousarray(image_data)
|
height, width = image_data.shape[:2]
|
||||||
bytes_per_line = image_data.strides[0]
|
bytes_per_line = image_data.strides[0]
|
||||||
|
|
||||||
qimage = QImage(
|
qimage = QImage(
|
||||||
@@ -263,7 +261,7 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
bytes_per_line,
|
bytes_per_line,
|
||||||
self.current_image.qtimage_format,
|
QImage.Format_RGBX32FPx4, # self.current_image.qtimage_format,
|
||||||
).copy() # Copy so Qt owns the buffer even after numpy array goes out of scope
|
).copy() # Copy so Qt owns the buffer even after numpy array goes out of scope
|
||||||
|
|
||||||
self.original_pixmap = QPixmap.fromImage(qimage)
|
self.original_pixmap = QPixmap.fromImage(qimage)
|
||||||
|
|||||||
@@ -257,17 +257,11 @@ class YOLOWrapper:
|
|||||||
if source_path.is_file():
|
if source_path.is_file():
|
||||||
try:
|
try:
|
||||||
img_obj = Image(source_path)
|
img_obj = Image(source_path)
|
||||||
pil_img = img_obj.pil_image
|
|
||||||
if len(pil_img.getbands()) == 1:
|
|
||||||
rgb_img = img_obj.convert_grayscale_to_rgb_preserve_range()
|
|
||||||
else:
|
|
||||||
rgb_img = pil_img.convert("RGB")
|
|
||||||
|
|
||||||
suffix = source_path.suffix or ".png"
|
suffix = source_path.suffix or ".png"
|
||||||
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
|
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
|
||||||
tmp_path = tmp.name
|
tmp_path = tmp.name
|
||||||
tmp.close()
|
tmp.close()
|
||||||
rgb_img.save(tmp_path)
|
img_obj.save(tmp_path)
|
||||||
cleanup_path = tmp_path
|
cleanup_path = tmp_path
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Converted image {source_path} to RGB for inference at {tmp_path}"
|
f"Converted image {source_path} to RGB for inference at {tmp_path}"
|
||||||
|
|||||||
@@ -6,16 +6,49 @@ import cv2
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
from PIL import Image as PILImage
|
|
||||||
|
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
from src.utils.file_utils import validate_file_path, is_image_file
|
from src.utils.file_utils import validate_file_path, is_image_file
|
||||||
|
|
||||||
from PySide6.QtGui import QImage
|
from PySide6.QtGui import QImage
|
||||||
|
|
||||||
|
from tifffile import imread, imwrite
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _rescale_to_unit_peak(channel: np.ndarray) -> None:
    """Divide *channel* in place by its maximum; leave it untouched if that maximum is not positive."""
    peak = channel.max()
    # A constant image ends up all-zero after the percentile shift; dividing
    # by that zero peak would fill the channel with NaN.
    if peak > 0:
        channel /= peak


def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.3) -> np.ndarray:
    """
    Convert a grayscale image to a pseudo-RGB image using a gamma correction.

    Three float32 planes are derived from the same input:
      0. the image shifted by its 2nd percentile, clipped at 0 and scaled so
         its maximum is 1,
      1. plane 0 raised to the power ``gamma`` and rescaled,
      2. plane 0 clipped at its 99.99th percentile and rescaled.

    Args:
        arr: Input grayscale image as numpy array with shape (H, W)
        gamma: Exponent applied to the normalized image to build plane 1

    Returns:
        Pseudo-RGB image as float32 numpy array with shape (3, H, W)
        (channel-first). An image with no values above its 2nd percentile
        (e.g. a constant image) yields all zeros.

    Raises:
        ValueError: If ``arr`` is not 2-dimensional or contains no pixels
    """
    if arr.ndim != 2:
        raise ValueError("Input array must be a grayscale image with shape (H, W)")
    if arr.size == 0:
        raise ValueError("Input array must not be empty")

    # astype() already returns a fresh array, so the caller's data is never modified.
    a1 = arr.astype(np.float32)
    a1 -= np.percentile(a1, 2)
    a1[a1 < 0] = 0
    _rescale_to_unit_peak(a1)

    a2 = a1**gamma
    _rescale_to_unit_peak(a2)

    a3 = a1.copy()
    p9999 = np.percentile(a3, 99.99)
    # Masked assignment (rather than np.minimum) keeps the plane float32
    # regardless of the dtype np.percentile returns.
    a3[a3 > p9999] = p9999
    _rescale_to_unit_peak(a3)

    return np.stack([a1, a2, a3], axis=0)
|
||||||
|
|
||||||
|
|
||||||
class ImageLoadError(Exception):
|
class ImageLoadError(Exception):
|
||||||
"""Exception raised when an image cannot be loaded."""
|
"""Exception raised when an image cannot be loaded."""
|
||||||
|
|
||||||
@@ -54,7 +87,6 @@ class Image:
|
|||||||
"""
|
"""
|
||||||
self.path = Path(image_path)
|
self.path = Path(image_path)
|
||||||
self._data: Optional[np.ndarray] = None
|
self._data: Optional[np.ndarray] = None
|
||||||
self._pil_image: Optional[PILImage.Image] = None
|
|
||||||
self._width: int = 0
|
self._width: int = 0
|
||||||
self._height: int = 0
|
self._height: int = 0
|
||||||
self._channels: int = 0
|
self._channels: int = 0
|
||||||
@@ -85,30 +117,24 @@ class Image:
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load with OpenCV (returns BGR format)
|
if self.path.suffix.lower() in [".tif", ".tiff"]:
|
||||||
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
|
self._data = imread(str(self.path))
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
# Load with OpenCV (returns BGR format)
|
||||||
|
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
|
||||||
|
|
||||||
if self._data is None:
|
if self._data is None:
|
||||||
raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")
|
raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")
|
||||||
|
|
||||||
# Extract metadata
|
# Extract metadata
|
||||||
|
print(self._data.shape)
|
||||||
self._height, self._width = self._data.shape[:2]
|
self._height, self._width = self._data.shape[:2]
|
||||||
self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
|
self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
|
||||||
self._format = self.path.suffix.lower().lstrip(".")
|
self._format = self.path.suffix.lower().lstrip(".")
|
||||||
self._size_bytes = self.path.stat().st_size
|
self._size_bytes = self.path.stat().st_size
|
||||||
self._dtype = self._data.dtype
|
self._dtype = self._data.dtype
|
||||||
|
|
||||||
# Load PIL version for compatibility (convert BGR to RGB)
|
|
||||||
if self._channels == 3:
|
|
||||||
rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
|
|
||||||
self._pil_image = PILImage.fromarray(rgb_data)
|
|
||||||
elif self._channels == 4:
|
|
||||||
rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
|
|
||||||
self._pil_image = PILImage.fromarray(rgba_data)
|
|
||||||
else:
|
|
||||||
# Grayscale
|
|
||||||
self._pil_image = PILImage.fromarray(self._data)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Successfully loaded image: {self.path.name} "
|
f"Successfully loaded image: {self.path.name} "
|
||||||
f"({self._width}x{self._height}, {self._channels} channels, "
|
f"({self._width}x{self._height}, {self._channels} channels, "
|
||||||
@@ -131,18 +157,6 @@ class Image:
|
|||||||
raise ImageLoadError("Image data not available")
|
raise ImageLoadError("Image data not available")
|
||||||
return self._data
|
return self._data
|
||||||
|
|
||||||
@property
|
|
||||||
def pil_image(self) -> PILImage.Image:
|
|
||||||
"""
|
|
||||||
Get image data as PIL Image (RGB or grayscale).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PIL Image object
|
|
||||||
"""
|
|
||||||
if self._pil_image is None:
|
|
||||||
raise ImageLoadError("PIL image not available")
|
|
||||||
return self._pil_image
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def width(self) -> int:
|
def width(self) -> int:
|
||||||
"""Get image width in pixels."""
|
"""Get image width in pixels."""
|
||||||
@@ -187,6 +201,7 @@ class Image:
|
|||||||
@property
|
@property
|
||||||
def dtype(self) -> np.dtype:
|
def dtype(self) -> np.dtype:
|
||||||
"""Get the data type of the image array."""
|
"""Get the data type of the image array."""
|
||||||
|
|
||||||
if self._dtype is None:
|
if self._dtype is None:
|
||||||
raise ImageLoadError("Image dtype not available")
|
raise ImageLoadError("Image dtype not available")
|
||||||
return self._dtype
|
return self._dtype
|
||||||
@@ -206,8 +221,10 @@ class Image:
|
|||||||
elif self._channels == 1:
|
elif self._channels == 1:
|
||||||
if self._dtype == np.uint16:
|
if self._dtype == np.uint16:
|
||||||
return QImage.Format_Grayscale16
|
return QImage.Format_Grayscale16
|
||||||
else:
|
elif self._dtype == np.uint8:
|
||||||
return QImage.Format_Grayscale8
|
return QImage.Format_Grayscale8
|
||||||
|
elif self._dtype == np.float32:
|
||||||
|
return QImage.Format_BGR30
|
||||||
else:
|
else:
|
||||||
raise ImageLoadError(f"Unsupported number of channels: {self._channels}")
|
raise ImageLoadError(f"Unsupported number of channels: {self._channels}")
|
||||||
|
|
||||||
@@ -218,6 +235,12 @@ class Image:
|
|||||||
Returns:
|
Returns:
|
||||||
Image data in RGB format as numpy array
|
Image data in RGB format as numpy array
|
||||||
"""
|
"""
|
||||||
|
if self.channels == 1:
|
||||||
|
img = get_pseudo_rgb(self.data)
|
||||||
|
self._dtype = img.dtype
|
||||||
|
return img
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
if self._channels == 3:
|
if self._channels == 3:
|
||||||
return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
|
return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
|
||||||
elif self._channels == 4:
|
elif self._channels == 4:
|
||||||
@@ -225,6 +248,18 @@ class Image:
|
|||||||
else:
|
else:
|
||||||
return self._data
|
return self._data
|
||||||
|
|
||||||
|
def get_qt_rgb(self) -> np.ndarray:
    """
    Get the image as an interleaved float32 RGBA buffer.

    Takes the first three planes returned by ``get_rgb()`` and repacks them
    into a pixel-interleaved array with a fully opaque fourth channel.

    NOTE(review): this indexes the ``get_rgb()`` result channel-first
    (``planes[0]`` is a whole (H, W) plane). That matches the single-channel
    pseudo-RGB path; confirm before calling it on images for which
    ``get_rgb()`` returns (H, W, C) data.

    Returns:
        C-contiguous float32 numpy array with shape (height, width, 4);
        channels 0-2 are copied from ``get_rgb()``, channel 3 is 1.0.
    """
    # we keep data as (C, H, W)
    planes = self.get_rgb()

    # Every element is assigned below, so an uninitialized buffer is enough;
    # a freshly allocated array is already C-contiguous.
    img = np.empty((self.height, self.width, 4), dtype=np.float32)
    img[..., 0] = planes[0]
    img[..., 1] = planes[1]
    img[..., 2] = planes[2]
    img[..., 3] = 1.0  # A = 1.0 (opaque)

    return img
|
||||||
|
|
||||||
def get_grayscale(self) -> np.ndarray:
|
def get_grayscale(self) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Get image as grayscale numpy array.
|
Get image as grayscale numpy array.
|
||||||
@@ -277,37 +312,18 @@ class Image:
|
|||||||
"""
|
"""
|
||||||
return self._channels >= 3
|
return self._channels >= 3
|
||||||
|
|
||||||
def convert_grayscale_to_rgb_preserve_range(
|
def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None:
|
||||||
self,
|
|
||||||
) -> PILImage.Image:
|
|
||||||
"""Convert a single-channel PIL image to RGB while preserving dynamic range.
|
|
||||||
|
|
||||||
Returns:
|
if self.channels == 1:
|
||||||
PIL Image in RGB mode with intensities normalized to 0-255.
|
if pseudo_rgb:
|
||||||
"""
|
img = get_pseudo_rgb(self.data)
|
||||||
if self._channels == 3:
|
else:
|
||||||
return self.pil_image
|
img = np.repeat(self.data, 3, axis=2)
|
||||||
|
|
||||||
grayscale = self.data
|
|
||||||
if grayscale.ndim == 3:
|
|
||||||
grayscale = grayscale[:, :, 0]
|
|
||||||
|
|
||||||
original_dtype = grayscale.dtype
|
|
||||||
grayscale = grayscale.astype(np.float32)
|
|
||||||
|
|
||||||
if grayscale.size == 0:
|
|
||||||
return PILImage.new("RGB", self.shape, color=(0, 0, 0))
|
|
||||||
|
|
||||||
if np.issubdtype(original_dtype, np.integer):
|
|
||||||
denom = float(max(np.iinfo(original_dtype).max, 1))
|
|
||||||
else:
|
else:
|
||||||
max_val = float(grayscale.max())
|
raise NotImplementedError("Only grayscale images are supported for now.")
|
||||||
denom = max(max_val, 1.0)
|
|
||||||
|
|
||||||
grayscale = np.clip(grayscale / denom, 0.0, 1.0)
|
imwrite(path, data=img)
|
||||||
grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8)
|
|
||||||
rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
|
|
||||||
return PILImage.fromarray(rgb_arr, mode="RGB")
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""String representation of the Image object."""
|
"""String representation of the Image object."""
|
||||||
@@ -322,3 +338,15 @@ class Image:
|
|||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
"""String representation of the Image object."""
|
"""String representation of the Image object."""
|
||||||
return self.__repr__()
|
return self.__repr__()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--path", type=str, required=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
img = Image(args.path)
|
||||||
|
img.save(args.path + "test.tif")
|
||||||
|
print(img)
|
||||||
|
|||||||
@@ -32,8 +32,10 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
|
|||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tifffile
|
|
||||||
|
# import tifffile
|
||||||
import torch
|
import torch
|
||||||
|
from src.utils.image import Image
|
||||||
|
|
||||||
from ultralytics.utils import patches as ul_patches
|
from ultralytics.utils import patches as ul_patches
|
||||||
|
|
||||||
@@ -55,7 +57,7 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
|
|||||||
|
|
||||||
ext = os.path.splitext(filename)[1].lower()
|
ext = os.path.splitext(filename)[1].lower()
|
||||||
if ext in (".tif", ".tiff"):
|
if ext in (".tif", ".tiff"):
|
||||||
arr = tifffile.imread(filename)
|
arr = Image(filename).get_qt_rgb()[:, :, :3]
|
||||||
|
|
||||||
# Normalize common shapes:
|
# Normalize common shapes:
|
||||||
# - (H, W) -> (H, W, 1)
|
# - (H, W) -> (H, W, 1)
|
||||||
@@ -71,35 +73,6 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
|
|||||||
if arr.ndim == 2:
|
if arr.ndim == 2:
|
||||||
arr = arr[..., None]
|
arr = arr[..., None]
|
||||||
|
|
||||||
# Ultralytics expects BGR ordering when `channels=3`.
|
|
||||||
# For grayscale data we replicate channels (no scaling, no quantization).
|
|
||||||
if flags != cv2.IMREAD_GRAYSCALE:
|
|
||||||
if arr.shape[2] == 1:
|
|
||||||
if pseudo_rgb:
|
|
||||||
gamma = 0.3
|
|
||||||
a1 = arr.copy().astype(np.float32)
|
|
||||||
a1 -= np.percentile(a1, 2)
|
|
||||||
a1[a1 < 0] = 0
|
|
||||||
p98 = np.percentile(a1, 98)
|
|
||||||
a1[a1 > p98] = p98
|
|
||||||
a1 /= a1.max()
|
|
||||||
|
|
||||||
a2 = a1.copy()
|
|
||||||
a2 = a2**gamma
|
|
||||||
a2 /= a2.max()
|
|
||||||
|
|
||||||
a3 = a1.copy()
|
|
||||||
p90 = np.percentile(a3, 90)
|
|
||||||
a3[a3 > p90] = p90
|
|
||||||
a3 /= a3.max()
|
|
||||||
|
|
||||||
arr = np.concatenate([a1, a2, a3], axis=2)
|
|
||||||
|
|
||||||
else:
|
|
||||||
arr = np.repeat(arr, 3, axis=2)
|
|
||||||
elif arr.shape[2] >= 3:
|
|
||||||
arr = arr[:, :, :3]
|
|
||||||
|
|
||||||
# Ensure contiguous array for downstream OpenCV ops.
|
# Ensure contiguous array for downstream OpenCV ops.
|
||||||
logger.info(f"Loading with monkey-patched imread: {filename}")
|
logger.info(f"Loading with monkey-patched imread: {filename}")
|
||||||
return np.ascontiguousarray(arr)
|
return np.ascontiguousarray(arr)
|
||||||
|
|||||||
Reference in New Issue
Block a user