"""
Image loading and management utilities for the microscopy object detection application.
"""

import cv2
import numpy as np
from pathlib import Path
from typing import Optional, Tuple, Union
from PIL import Image as PILImage
import tifffile

from src.utils.logger import get_logger
from src.utils.file_utils import validate_file_path, is_image_file

from PySide6.QtGui import QImage

logger = get_logger(__name__)


class ImageLoadError(Exception):
    """Exception raised when an image cannot be loaded."""

    pass


class Image:
    """
    A class for loading and managing images from file paths.

    Supports multiple image formats: .jpg, .jpeg, .png, .tif, .tiff, .bmp
    Provides access to image data in multiple formats (OpenCV/numpy, PIL).

    Channel order of ``data`` depends on the loader: non-TIFF files are read
    with OpenCV (BGR/BGRA for colour images); TIFF files are read with
    tifffile and keep the sample order stored in the file (RGB/RGBA for
    ordinary colour TIFFs). ``get_rgb()`` and ``pil_image`` present the
    data in RGB(A) order for both loaders.

    Attributes:
        path: Path to the image file
        data: Image data as numpy array (BGR for OpenCV-loaded colour images,
            file sample order for TIFF images)
        pil_image: Image data as PIL Image (RGB/RGBA or grayscale), when PIL
            can represent the array
        width: Image width in pixels
        height: Image height in pixels
        channels: Number of color channels
        format: Image file format
        size_bytes: File size in bytes
    """

    SUPPORTED_EXTENSIONS = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
    # Extensions routed to tifffile instead of OpenCV.
    _TIFF_EXTENSIONS = (".tif", ".tiff")

    def __init__(self, image_path: Union[str, Path]):
        """
        Initialize an Image object by loading from a file path.

        Args:
            image_path: Path to the image file (string or Path object)

        Raises:
            ImageLoadError: If the image cannot be loaded or is invalid
        """
        self.path = Path(image_path)
        self._data: Optional[np.ndarray] = None
        self._pil_image: Optional[PILImage.Image] = None
        self._width: int = 0
        self._height: int = 0
        self._channels: int = 0
        self._format: str = ""
        self._size_bytes: int = 0
        self._dtype: Optional[np.dtype] = None
        # True when colour data in ``_data`` is in OpenCV's BGR(A) order.
        self._is_bgr: bool = False

        # Load the image
        self._load()

    def _load(self) -> None:
        """
        Load the image from disk and populate data, metadata and PIL image.

        Raises:
            ImageLoadError: If the path is invalid, the extension is not
                supported, or the file cannot be decoded.
        """
        # Validate path
        if not validate_file_path(str(self.path), must_exist=True):
            raise ImageLoadError(f"Invalid or non-existent file path: {self.path}")

        # Check file extension
        if not is_image_file(str(self.path), self.SUPPORTED_EXTENSIONS):
            ext = self.path.suffix.lower()
            raise ImageLoadError(
                f"Unsupported image format: {ext}. "
                f"Supported formats: {', '.join(self.SUPPORTED_EXTENSIONS)}"
            )

        is_tiff = self.path.suffix.lower() in self._TIFF_EXTENSIONS
        try:
            self._data = self._read_pixels(is_tiff)
            # OpenCV decodes colour images as BGR(A); tifffile keeps the
            # sample order stored in the file.
            self._is_bgr = not is_tiff
            self._read_metadata()
            self._pil_image = self._build_pil_image()
        except ImageLoadError as e:
            # Already a descriptive ImageLoadError: log and propagate as-is
            # instead of wrapping it in a second "Failed to load image" layer.
            logger.error(f"Error loading image {self.path}: {e}")
            raise
        except Exception as e:
            logger.error(f"Error loading image {self.path}: {e}")
            raise ImageLoadError(f"Failed to load image: {e}") from e

        if is_tiff:
            logger.info(
                f"Successfully loaded TIFF image: {self.path.name} "
                f"({self._width}x{self._height}, {self._channels} channels, "
                f"dtype={self._dtype}, {self._format.upper()})"
            )
        else:
            logger.info(
                f"Successfully loaded image: {self.path.name} "
                f"({self._width}x{self._height}, {self._channels} channels, "
                f"{self._format.upper()})"
            )

    def _read_pixels(self, is_tiff: bool) -> np.ndarray:
        """
        Decode the file into a numpy array.

        Args:
            is_tiff: Use tifffile when True, OpenCV (IMREAD_UNCHANGED) otherwise.

        Returns:
            The decoded pixel array.

        Raises:
            ImageLoadError: If the decoder returns no data.
        """
        if is_tiff:
            # TIFF files use tifffile for better support
            data = tifffile.imread(str(self.path))
            if data is None:
                raise ImageLoadError(f"Failed to load TIFF with tifffile: {self.path}")
            return data

        # OpenCV returns BGR(A) for colour images, unchanged bit depth
        data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
        if data is None:
            raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")
        return data

    def _read_metadata(self) -> None:
        """Fill dimension, channel, format, size and dtype fields from ``_data``."""
        shape = self._data.shape
        if len(shape) >= 2:
            self._height, self._width = shape[:2]
        else:
            # 1-D data: treat it as a single column of pixels.
            self._height, self._width = shape[0], 1
        # NOTE(review): any 3-D array is interpreted as (height, width,
        # channels); multi-page TIFF stacks would be described incorrectly.
        self._channels = shape[2] if len(shape) == 3 else 1
        self._format = self.path.suffix.lower().lstrip(".")
        self._size_bytes = self.path.stat().st_size
        self._dtype = self._data.dtype

    def _build_pil_image(self) -> Optional[PILImage.Image]:
        """
        Build a PIL view of the pixel data in RGB(A)/grayscale order.

        Returns:
            The PIL image, or None (with a warning logged) when PIL cannot
            represent the array, e.g. multi-channel 16-bit data. The
            ``pil_image`` property raises ImageLoadError in that case, while
            the numpy data stays usable.
        """
        pixels = self.get_rgb()
        if pixels.ndim == 3 and pixels.shape[2] == 1:
            # PIL cannot build an image from an (H, W, 1) array.
            pixels = pixels[:, :, 0]
        try:
            return PILImage.fromarray(pixels)
        except (TypeError, ValueError) as e:
            logger.warning(
                f"PIL cannot represent image {self.path.name} "
                f"(shape={pixels.shape}, dtype={pixels.dtype}): {e}"
            )
            return None

    @property
    def data(self) -> np.ndarray:
        """
        Get image data as numpy array.

        Colour images loaded through OpenCV are BGR(A); TIFF images keep the
        sample order stored in the file. Use ``get_rgb()`` for RGB order.

        Returns:
            Image data as numpy array
        """
        if self._data is None:
            raise ImageLoadError("Image data not available")
        return self._data

    @property
    def pil_image(self) -> PILImage.Image:
        """
        Get image data as PIL Image (RGB or grayscale).

        Returns:
            PIL Image object

        Raises:
            ImageLoadError: If PIL could not represent the image data.
        """
        if self._pil_image is None:
            raise ImageLoadError("PIL image not available")
        return self._pil_image

    @property
    def width(self) -> int:
        """Get image width in pixels."""
        return self._width

    @property
    def height(self) -> int:
        """Get image height in pixels."""
        return self._height

    @property
    def shape(self) -> Tuple[int, int, int]:
        """
        Get image shape as (height, width, channels).

        Returns:
            Tuple of (height, width, channels)
        """
        return (self._height, self._width, self._channels)

    @property
    def channels(self) -> int:
        """Get number of color channels."""
        return self._channels

    @property
    def format(self) -> str:
        """Get image file format (e.g., 'jpg', 'png')."""
        return self._format

    @property
    def size_bytes(self) -> int:
        """Get file size in bytes."""
        return self._size_bytes

    @property
    def size_mb(self) -> float:
        """Get file size in megabytes."""
        return self._size_bytes / (1024 * 1024)

    @property
    def dtype(self) -> np.dtype:
        """Get the data type of the image array."""
        if self._dtype is None:
            raise ImageLoadError("Image dtype not available")
        return self._dtype

    @property
    def qtimage_format(self) -> QImage.Format:
        """
        Get the appropriate QImage format for the image.

        The dtype is only consulted for single-channel images; 3- and
        4-channel images always map to 8-bit-per-channel RGB formats.

        Returns:
            QImage.Format enum value

        Raises:
            ImageLoadError: For channel counts other than 1, 3 or 4.
        """
        if self._channels == 3:
            return QImage.Format_RGB888
        elif self._channels == 4:
            return QImage.Format_RGBA8888
        elif self._channels == 1:
            if self._dtype == np.uint16:
                return QImage.Format_Grayscale16
            else:
                return QImage.Format_Grayscale8
        else:
            raise ImageLoadError(f"Unsupported number of channels: {self._channels}")

    def get_rgb(self) -> np.ndarray:
        """
        Get image data as RGB numpy array.

        OpenCV-loaded 3/4-channel images are converted from BGR(A) to RGB(A)
        into a new array. TIFF data (already in file sample order) and
        single-channel data are returned as-is, without copying.

        Returns:
            Image data in RGB (or RGBA) format as numpy array
        """
        if self._is_bgr and self._channels == 3:
            return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
        if self._is_bgr and self._channels == 4:
            return cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
        return self._data

    def get_grayscale(self) -> np.ndarray:
        """
        Get image as grayscale numpy array.

        Returns:
            Grayscale image as numpy array (the original array when the image
            already has one channel)
        """
        if self._channels == 1:
            return self._data
        # Pick the conversion matching the stored channel order.
        code = cv2.COLOR_BGR2GRAY if self._is_bgr else cv2.COLOR_RGB2GRAY
        return cv2.cvtColor(self._data, code)

    def copy(self) -> np.ndarray:
        """
        Get a copy of the image data.

        Returns:
            Copy of image data as numpy array
        """
        return self._data.copy()

    def resize(self, width: int, height: int) -> np.ndarray:
        """
        Resize the image to specified dimensions.

        Args:
            width: Target width in pixels
            height: Target height in pixels

        Returns:
            Resized image as numpy array (does not modify original)
        """
        return cv2.resize(self._data, (width, height))

    def is_grayscale(self) -> bool:
        """
        Check if image is grayscale.

        Returns:
            True if image is grayscale (1 channel)
        """
        return self._channels == 1

    def is_color(self) -> bool:
        """
        Check if image is color.

        Returns:
            True if image has 3 or more channels
        """
        return self._channels >= 3

    def to_normalized_float32(self) -> np.ndarray:
        """
        Convert image data to normalized float32 in range [0, 1].

        uint16 data is divided by 65535 and uint8 by 255. Other integer
        types are divided by their dtype maximum. Float data is clipped to
        [0, 1]. Any other dtype is min-max normalized (all zeros when
        constant).

        Returns:
            Normalized image data as float32 numpy array [0, 1]
        """
        data = self._data.astype(np.float32)

        if self._dtype == np.uint16:
            data = data / 65535.0
        elif self._dtype == np.uint8:
            data = data / 255.0
        elif np.issubdtype(self._dtype, np.floating):
            data = np.clip(data, 0.0, 1.0)
        elif np.issubdtype(self._dtype, np.integer):
            # Other integer types: scale by the dtype's maximum value
            data = data / float(np.iinfo(self._dtype).max)
        else:
            # Unknown type: attempt min-max normalization
            min_val = data.min()
            max_val = data.max()
            if max_val > min_val:
                data = (data - min_val) / (max_val - min_val)
            else:
                data = np.zeros_like(data)

        # Signed integer inputs can yield negatives; keep the [0, 1] contract.
        return np.clip(data, 0.0, 1.0)

    def __repr__(self) -> str:
        """String representation of the Image object."""
        return (
            f"Image(path='{self.path.name}', "
            f"shape=({self._width}x{self._height}x{self._channels}), "
            f"format={self._format}, "
            f"size={self.size_mb:.2f}MB)"
        )

    def __str__(self) -> str:
        """String representation of the Image object."""
        return self.__repr__()


def convert_grayscale_to_rgb_preserve_range(
    pil_image: PILImage.Image,
) -> PILImage.Image:
    """Convert a single-channel PIL image to RGB while preserving dynamic range.

    Integer data is scaled by its dtype's maximum value. Note that for 32-bit
    integer modes this is the int32 maximum. Float data is scaled by
    max(image max, 1.0). The result is clipped to [0, 1] and mapped to
    0-255. Multi-band non-RGB input uses only its first band.

    Args:
        pil_image: Single-channel PIL image (e.g., 16-bit grayscale).

    Returns:
        PIL Image in RGB mode with intensities normalized to 0-255. RGB input
        is returned unchanged.
    """
    if pil_image.mode == "RGB":
        return pil_image

    grayscale = np.array(pil_image)
    if grayscale.ndim == 3:
        grayscale = grayscale[:, :, 0]

    if grayscale.size == 0:
        return PILImage.new("RGB", pil_image.size, color=(0, 0, 0))

    original_dtype = grayscale.dtype
    values = grayscale.astype(np.float32)

    if np.issubdtype(original_dtype, np.integer):
        denom = float(max(np.iinfo(original_dtype).max, 1))
    else:
        denom = max(float(values.max()), 1.0)

    normalized = np.clip(values / denom, 0.0, 1.0)
    grayscale_u8 = (normalized * 255.0).round().astype(np.uint8)
    rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
    # A uint8 (H, W, 3) array is inferred as "RGB"; the explicit ``mode``
    # argument to fromarray is deprecated in recent Pillow releases.
    return PILImage.fromarray(rgb_arr)