Files
object-segmentation/src/utils/image.py

408 lines
13 KiB
Python

"""
Image loading and management utilities for the microscopy object detection application.
"""
import cv2
import numpy as np
from pathlib import Path
from typing import Optional, Tuple, Union
from PIL import Image as PILImage
import tifffile
from src.utils.logger import get_logger
from src.utils.file_utils import validate_file_path, is_image_file
from PySide6.QtGui import QImage
# Module-level logger, obtained from the project's get_logger helper and
# keyed by this module's name.
logger = get_logger(__name__)
class ImageLoadError(Exception):
    """Signals that an image file could not be read or is not usable."""
class Image:
    """
    Load an image file from disk and expose its pixel data and metadata.

    Supports multiple image formats: .jpg, .jpeg, .png, .tif, .tiff, .bmp
    TIFF files are read with tifffile; every other format is read with
    OpenCV using ``cv2.IMREAD_UNCHANGED``.

    Attributes:
        path: Path to the image file
        data: Image data as numpy array, exactly as the loader returned it
            (BGR/BGRA channel order for colour files read by OpenCV; TIFF
            arrays keep the channel order tifffile returned)
        pil_image: Image data as PIL Image (RGB/RGBA for colour files read
            by OpenCV)
        width: Image width in pixels
        height: Image height in pixels
        channels: Number of color channels
        format: Image file format (lower-case extension without the dot)
        size_bytes: File size in bytes
    """

    SUPPORTED_EXTENSIONS = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]

    def __init__(self, image_path: Union[str, Path]):
        """
        Initialize an Image object by loading from a file path.

        Args:
            image_path: Path to the image file (string or Path object)

        Raises:
            ImageLoadError: If the image cannot be loaded or is invalid
        """
        self.path = Path(image_path)
        self._data: Optional[np.ndarray] = None
        self._pil_image: Optional[PILImage.Image] = None
        self._width: int = 0
        self._height: int = 0
        self._channels: int = 0
        self._format: str = ""
        self._size_bytes: int = 0
        self._dtype: Optional[np.dtype] = None
        # True only when the array came from OpenCV and therefore uses
        # BGR/BGRA channel order; colour conversions depend on this.
        self._is_bgr: bool = False

        # Load the image
        self._load()

    def _load(self) -> None:
        """
        Load the image from disk.

        Validates the path and extension, then dispatches to the TIFF or
        OpenCV loader.

        Raises:
            ImageLoadError: If the path is invalid, the extension is not
                supported, or the file cannot be read
        """
        # Validate path
        if not validate_file_path(str(self.path), must_exist=True):
            raise ImageLoadError(f"Invalid or non-existent file path: {self.path}")

        # Check file extension
        if not is_image_file(str(self.path), self.SUPPORTED_EXTENSIONS):
            ext = self.path.suffix.lower()
            raise ImageLoadError(
                f"Unsupported image format: {ext}. "
                f"Supported formats: {', '.join(self.SUPPORTED_EXTENSIONS)}"
            )

        try:
            # TIFF files go through tifffile for better format support
            if self.path.suffix.lower() in (".tif", ".tiff"):
                self._load_tiff()
            else:
                self._load_with_opencv()
        except ImageLoadError as e:
            # Already the right exception type: log it and let it propagate
            # unchanged instead of wrapping it in a second ImageLoadError.
            logger.error(f"Error loading image {self.path}: {e}")
            raise
        except Exception as e:
            logger.error(f"Error loading image {self.path}: {e}")
            raise ImageLoadError(f"Failed to load image: {e}") from e

    def _load_tiff(self) -> None:
        """Read a TIFF file with tifffile and build the PIL view from it."""
        self._data = tifffile.imread(str(self.path))
        if self._data is None:
            raise ImageLoadError(
                f"Failed to load TIFF with tifffile: {self.path}"
            )
        # tifffile does not reorder samples into OpenCV's BGR layout.
        self._is_bgr = False
        self._record_metadata()

        # Load PIL version for compatibility (no channel swap needed)
        self._pil_image = PILImage.fromarray(self._data)
        logger.info(
            f"Successfully loaded TIFF image: {self.path.name} "
            f"({self._width}x{self._height}, {self._channels} channels, "
            f"dtype={self._dtype}, {self._format.upper()})"
        )

    def _load_with_opencv(self) -> None:
        """Read a non-TIFF file with OpenCV and build an RGB(A) PIL view."""
        # IMREAD_UNCHANGED keeps the bit depth and any alpha channel
        self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
        if self._data is None:
            raise ImageLoadError(
                f"Failed to load image with OpenCV: {self.path}"
            )
        self._is_bgr = True
        self._record_metadata()

        # Load PIL version for compatibility (convert BGR to RGB)
        if self._channels == 3:
            pil_source = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
        elif self._channels == 4:
            pil_source = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
        else:
            # Grayscale
            pil_source = self._data
        self._pil_image = PILImage.fromarray(pil_source)
        logger.info(
            f"Successfully loaded image: {self.path.name} "
            f"({self._width}x{self._height}, {self._channels} channels, "
            f"{self._format.upper()})"
        )

    def _record_metadata(self) -> None:
        """Derive size, channel count, format, file size and dtype from the loaded array."""
        dims = self._data.shape
        if len(dims) >= 2:
            self._height, self._width = dims[:2]
        else:
            # One-dimensional array: treat it as a single-pixel-wide column
            self._height, self._width = dims[0], 1
        # Only a 3-D array carries an explicit channel axis
        self._channels = dims[2] if len(dims) == 3 else 1
        self._format = self.path.suffix.lower().lstrip(".")
        self._size_bytes = self.path.stat().st_size
        self._dtype = self._data.dtype

    @property
    def data(self) -> np.ndarray:
        """
        Get image data as numpy array, exactly as the loader returned it.

        Colour files read by OpenCV are in BGR/BGRA order; TIFF arrays are
        returned as tifffile produced them.

        Returns:
            Image data as numpy array

        Raises:
            ImageLoadError: If no image data is held
        """
        if self._data is None:
            raise ImageLoadError("Image data not available")
        return self._data

    @property
    def pil_image(self) -> PILImage.Image:
        """
        Get image data as PIL Image (RGB or grayscale).

        Returns:
            PIL Image object

        Raises:
            ImageLoadError: If no PIL image is held
        """
        if self._pil_image is None:
            raise ImageLoadError("PIL image not available")
        return self._pil_image

    @property
    def width(self) -> int:
        """Get image width in pixels."""
        return self._width

    @property
    def height(self) -> int:
        """Get image height in pixels."""
        return self._height

    @property
    def shape(self) -> Tuple[int, int, int]:
        """
        Get image shape as (height, width, channels).

        Returns:
            Tuple of (height, width, channels)
        """
        return (self._height, self._width, self._channels)

    @property
    def channels(self) -> int:
        """Get number of color channels."""
        return self._channels

    @property
    def format(self) -> str:
        """Get image file format (e.g., 'jpg', 'png')."""
        return self._format

    @property
    def size_bytes(self) -> int:
        """Get file size in bytes."""
        return self._size_bytes

    @property
    def size_mb(self) -> float:
        """Get file size in megabytes."""
        return self._size_bytes / (1024 * 1024)

    @property
    def dtype(self) -> np.dtype:
        """
        Get the data type of the image array.

        Raises:
            ImageLoadError: If no dtype has been recorded
        """
        if self._dtype is None:
            raise ImageLoadError("Image dtype not available")
        return self._dtype

    @property
    def qtimage_format(self) -> QImage.Format:
        """
        Get the appropriate QImage format for the image.

        The choice is driven by the channel count; the dtype is consulted
        only for single-channel images (uint16 -> 16-bit grayscale).

        Returns:
            QImage.Format enum value

        Raises:
            ImageLoadError: If the channel count is not 1, 3 or 4
        """
        if self._channels == 3:
            return QImage.Format_RGB888
        if self._channels == 4:
            return QImage.Format_RGBA8888
        if self._channels == 1:
            if self._dtype == np.uint16:
                return QImage.Format_Grayscale16
            return QImage.Format_Grayscale8
        raise ImageLoadError(f"Unsupported number of channels: {self._channels}")

    def get_rgb(self) -> np.ndarray:
        """
        Get image data as RGB numpy array.

        Returns:
            Image data in RGB format (RGBA for 4-channel images) as numpy
            array; single-channel data is returned unchanged
        """
        if not self._is_bgr:
            # TIFF arrays were never put in BGR order, so swapping channels
            # here would exchange red and blue.
            return self._data
        if self._channels == 3:
            return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
        if self._channels == 4:
            return cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
        return self._data

    def get_grayscale(self) -> np.ndarray:
        """
        Get image as grayscale numpy array.

        Returns:
            Grayscale image as numpy array
        """
        if self._channels == 1:
            return self._data
        # The luminance weights are per-channel, so use the conversion that
        # matches the channel order the loader produced.
        code = cv2.COLOR_BGR2GRAY if self._is_bgr else cv2.COLOR_RGB2GRAY
        return cv2.cvtColor(self._data, code)

    def copy(self) -> np.ndarray:
        """
        Get a copy of the image data.

        Returns:
            Copy of image data as numpy array
        """
        return self._data.copy()

    def resize(self, width: int, height: int) -> np.ndarray:
        """
        Resize the image to specified dimensions.

        Args:
            width: Target width in pixels
            height: Target height in pixels

        Returns:
            Resized image as numpy array (does not modify original)
        """
        return cv2.resize(self._data, (width, height))

    def is_grayscale(self) -> bool:
        """
        Check if image is grayscale.

        Returns:
            True if image is grayscale (1 channel)
        """
        return self._channels == 1

    def is_color(self) -> bool:
        """
        Check if image is color.

        Returns:
            True if image has 3 or more channels
        """
        return self._channels >= 3

    def to_normalized_float32(self) -> np.ndarray:
        """
        Convert image data to normalized float32 in range [0, 1].

        Integer images are divided by the largest value their dtype can
        hold (255 for 8-bit, 65535 for 16-bit).
        Already float images are clipped to [0, 1].
        Any other dtype is min-max normalized.

        Returns:
            Normalized image data as float32 numpy array [0, 1]
        """
        data = self._data.astype(np.float32)
        if np.issubdtype(self._dtype, np.integer):
            # Scale by the dtype's full range, not by the image's own maximum
            data = data / float(np.iinfo(self._dtype).max)
        elif not np.issubdtype(self._dtype, np.floating):
            # Neither integer nor float: attempt min-max normalization
            min_val = data.min()
            max_val = data.max()
            if max_val > min_val:
                data = (data - min_val) / (max_val - min_val)
            else:
                data = np.zeros_like(data)
        # Float input is only clipped; the clip also removes negative values
        # produced by signed integer input.
        return np.clip(data, 0.0, 1.0)

    def __repr__(self) -> str:
        """String representation of the Image object."""
        return (
            f"Image(path='{self.path.name}', "
            f"shape=({self._width}x{self._height}x{self._channels}), "
            f"format={self._format}, "
            f"size={self.size_mb:.2f}MB)"
        )

    def __str__(self) -> str:
        """String representation of the Image object."""
        return self.__repr__()
def convert_grayscale_to_rgb_preserve_range(
    pil_image: PILImage.Image,
) -> PILImage.Image:
    """Convert a single-channel PIL image to RGB while preserving dynamic range.

    Integer images are scaled by the largest value their dtype can hold, so
    relative intensities survive the reduction to 8 bits. Other images are
    scaled by their largest finite value, or left unscaled when that value
    is below 1.0.

    Args:
        pil_image: Single-channel PIL image (e.g., 16-bit grayscale). If the
            image converts to a 3-D array, only its first channel is used.

    Returns:
        PIL Image in RGB mode with intensities normalized to 0-255. An image
        that is already in RGB mode is returned unchanged (same object).
    """
    if pil_image.mode == "RGB":
        return pil_image

    grayscale = np.array(pil_image)
    if grayscale.ndim == 3:
        grayscale = grayscale[:, :, 0]

    if grayscale.size == 0:
        return PILImage.new("RGB", pil_image.size, color=(0, 0, 0))

    original_dtype = grayscale.dtype
    grayscale = grayscale.astype(np.float32)

    if np.issubdtype(original_dtype, np.integer):
        denom = float(max(np.iinfo(original_dtype).max, 1))
    else:
        # Take the scale from finite pixels only: a single NaN or inf would
        # otherwise poison the divisor and corrupt every output pixel.
        finite = np.isfinite(grayscale)
        max_val = float(grayscale[finite].max()) if finite.any() else 0.0
        denom = max(max_val, 1.0)
        # NaN and -inf become black; +inf saturates to white after scaling.
        grayscale = np.nan_to_num(grayscale, nan=0.0, posinf=denom, neginf=0.0)

    grayscale = np.clip(grayscale / denom, 0.0, 1.0)
    grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8)
    rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
    # A uint8 HxWx3 array is inferred as RGB, so no explicit mode argument
    # is needed.
    return PILImage.fromarray(rgb_arr)