"""
Image loading and management utilities for the microscopy object detection application.
"""

from pathlib import Path
from typing import Optional, Tuple, Union

import cv2
import numpy as np
from PIL import Image as PILImage
from PySide6.QtGui import QImage

from src.utils.logger import get_logger
from src.utils.file_utils import validate_file_path, is_image_file

logger = get_logger(__name__)


class ImageLoadError(Exception):
    """Exception raised when an image cannot be loaded."""


class Image:
    """
    A class for loading and managing images from file paths.

    Supports multiple image formats: .jpg, .jpeg, .png, .tif, .tiff, .bmp
    Provides access to image data in multiple formats (OpenCV/numpy, PIL).

    The image is read once, eagerly, in the constructor; all properties
    afterwards return the cached values.

    Attributes:
        path: Path to the image file
        data: Image data as numpy array (OpenCV format, BGR)
        pil_image: Image data as PIL Image (RGB)
        width: Image width in pixels
        height: Image height in pixels
        channels: Number of color channels
        format: Image file format
        size_bytes: File size in bytes
    """

    SUPPORTED_EXTENSIONS = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]

    def __init__(self, image_path: Union[str, Path]):
        """
        Initialize an Image object by loading from a file path.

        Args:
            image_path: Path to the image file (string or Path object)

        Raises:
            ImageLoadError: If the image cannot be loaded or is invalid
        """
        self.path = Path(image_path)
        self._data: Optional[np.ndarray] = None
        self._pil_image: Optional[PILImage.Image] = None
        self._width: int = 0
        self._height: int = 0
        self._channels: int = 0
        self._format: str = ""
        self._size_bytes: int = 0
        self._dtype: Optional[np.dtype] = None

        # Load the image
        self._load()

    def _load(self) -> None:
        """
        Load the image from disk and populate the cached metadata.

        Raises:
            ImageLoadError: If the path is invalid, the extension is not in
                SUPPORTED_EXTENSIONS, OpenCV cannot decode the file, or any
                other error occurs while reading or converting the image
        """
        # Validate path
        if not validate_file_path(str(self.path), must_exist=True):
            raise ImageLoadError(f"Invalid or non-existent file path: {self.path}")

        # Check file extension
        if not is_image_file(str(self.path), self.SUPPORTED_EXTENSIONS):
            ext = self.path.suffix.lower()
            raise ImageLoadError(
                f"Unsupported image format: {ext}. "
                f"Supported formats: {', '.join(self.SUPPORTED_EXTENSIONS)}"
            )

        try:
            # IMREAD_UNCHANGED keeps the file's own channel count and bit
            # depth; color data comes back in OpenCV's BGR(A) order.
            self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)

            if self._data is None:
                raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")

            # Extract metadata
            self._height, self._width = self._data.shape[:2]
            # A 2-D array has no channel axis; report it as one channel.
            self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
            self._format = self.path.suffix.lower().lstrip(".")
            self._size_bytes = self.path.stat().st_size
            self._dtype = self._data.dtype

            # Load PIL version for compatibility (convert BGR to RGB)
            if self._channels == 3:
                rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
                self._pil_image = PILImage.fromarray(rgb_data)
            elif self._channels == 4:
                rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
                self._pil_image = PILImage.fromarray(rgba_data)
            else:
                # Grayscale
                self._pil_image = PILImage.fromarray(self._data)

            logger.info(
                f"Successfully loaded image: {self.path.name} "
                f"({self._width}x{self._height}, {self._channels} channels, "
                f"{self._format.upper()})"
            )

        except ImageLoadError as e:
            # Already a domain error with a complete message: log it and let
            # it propagate as-is instead of wrapping it a second time.
            logger.error(f"Error loading image {self.path}: {e}")
            raise
        except Exception as e:
            logger.error(f"Error loading image {self.path}: {e}")
            raise ImageLoadError(f"Failed to load image: {e}") from e

    @property
    def data(self) -> np.ndarray:
        """
        Get image data as numpy array (OpenCV format, BGR or grayscale).

        Returns:
            Image data as numpy array

        Raises:
            ImageLoadError: If no image data is held
        """
        if self._data is None:
            raise ImageLoadError("Image data not available")
        return self._data

    @property
    def pil_image(self) -> PILImage.Image:
        """
        Get image data as PIL Image (RGB, RGBA or grayscale).

        Returns:
            PIL Image object

        Raises:
            ImageLoadError: If no PIL image is held
        """
        if self._pil_image is None:
            raise ImageLoadError("PIL image not available")
        return self._pil_image

    @property
    def width(self) -> int:
        """Get image width in pixels."""
        return self._width

    @property
    def height(self) -> int:
        """Get image height in pixels."""
        return self._height

    @property
    def shape(self) -> Tuple[int, int, int]:
        """
        Get image shape as (height, width, channels).

        Always a 3-tuple: single-channel images report channels as 1.

        Returns:
            Tuple of (height, width, channels)
        """
        return (self._height, self._width, self._channels)

    @property
    def channels(self) -> int:
        """Get number of color channels."""
        return self._channels

    @property
    def format(self) -> str:
        """Get image file format (e.g., 'jpg', 'png')."""
        return self._format

    @property
    def size_bytes(self) -> int:
        """Get file size in bytes."""
        return self._size_bytes

    @property
    def size_mb(self) -> float:
        """Get file size in megabytes."""
        return self._size_bytes / (1024 * 1024)

    @property
    def dtype(self) -> np.dtype:
        """
        Get the data type of the image array.

        Raises:
            ImageLoadError: If no dtype has been recorded
        """
        if self._dtype is None:
            raise ImageLoadError("Image dtype not available")
        return self._dtype

    @property
    def qtimage_format(self) -> QImage.Format:
        """
        Get the appropriate QImage format for the image.

        The choice is based on the channel count; only single-channel
        images are additionally distinguished by dtype (uint16 vs. other).

        Returns:
            QImage.Format enum value

        Raises:
            ImageLoadError: If the channel count is not 1, 3 or 4
        """
        # NOTE(review): the 3/4-channel formats are RGB-ordered, whereas
        # `data` is BGR(A) - presumably callers pair this with get_rgb();
        # confirm at the call sites.
        if self._channels == 3:
            return QImage.Format_RGB888
        elif self._channels == 4:
            return QImage.Format_RGBA8888
        elif self._channels == 1:
            if self._dtype == np.uint16:
                return QImage.Format_Grayscale16
            else:
                return QImage.Format_Grayscale8
        else:
            raise ImageLoadError(f"Unsupported number of channels: {self._channels}")

    def get_rgb(self) -> np.ndarray:
        """
        Get image data with channels in RGB order.

        Returns:
            RGB array for 3-channel images, RGBA array for 4-channel
            images, and the stored array unchanged (not a copy) otherwise
        """
        if self._channels == 3:
            return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
        elif self._channels == 4:
            return cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
        else:
            return self._data

    def get_grayscale(self) -> np.ndarray:
        """
        Get image as grayscale numpy array.

        Returns:
            The stored array itself (not a copy) for single-channel images,
            otherwise a new array converted with OpenCV
        """
        if self._channels == 1:
            return self._data
        else:
            return cv2.cvtColor(self._data, cv2.COLOR_BGR2GRAY)

    def copy(self) -> np.ndarray:
        """
        Get a copy of the image data.

        Returns:
            Copy of image data as numpy array
        """
        return self._data.copy()

    def resize(self, width: int, height: int) -> np.ndarray:
        """
        Resize the image to specified dimensions.

        Args:
            width: Target width in pixels
            height: Target height in pixels

        Returns:
            Resized image as numpy array (does not modify original)
        """
        # cv2.resize takes the target size as (width, height).
        return cv2.resize(self._data, (width, height))

    def is_grayscale(self) -> bool:
        """
        Check if image is grayscale.

        Returns:
            True if image is grayscale (1 channel)
        """
        return self._channels == 1

    def is_color(self) -> bool:
        """
        Check if image is color.

        Returns:
            True if image has 3 or more channels
        """
        return self._channels >= 3

    def __repr__(self) -> str:
        """String representation of the Image object."""
        return (
            f"Image(path='{self.path.name}', "
            f"shape=({self._width}x{self._height}x{self._channels}), "
            f"format={self._format}, "
            f"size={self.size_mb:.2f}MB)"
        )

    def __str__(self) -> str:
        """String representation of the Image object."""
        return self.__repr__()