# Files
# object-segmentation/src/utils/image.py
#
# 292 lines
# 8.6 KiB
# Python

"""
Image loading and management utilities for the microscopy object detection application.
"""
import cv2
import numpy as np
from pathlib import Path
from typing import Optional, Tuple, Union
from PIL import Image as PILImage
from src.utils.logger import get_logger
from src.utils.file_utils import validate_file_path, is_image_file
from PySide6.QtGui import QImage
# Module-level logger, created through the project's get_logger helper and
# named after this module; used by Image._load to report load results.
logger = get_logger(__name__)
class ImageLoadError(Exception):
    """Exception raised when an image cannot be loaded.

    Raised by ``Image`` when the path is invalid, the file extension is not
    supported, decoding fails, or image data is requested but not available.
    """
class Image:
    """
    A class for loading and managing images from file paths.

    Supports multiple image formats: .jpg, .jpeg, .png, .tif, .tiff, .bmp
    The file is decoded once, at construction time, with OpenCV
    (``cv2.IMREAD_UNCHANGED``) and mirrored as a PIL image.

    Attributes:
        path: Path to the image file
        data: Image data as numpy array (OpenCV channel order: BGR / BGRA,
            or a 2-D array for single-channel images)
        pil_image: Image data as PIL Image (RGB / RGBA or grayscale)
        width: Image width in pixels
        height: Image height in pixels
        channels: Number of color channels
        format: Image file format (lower-case extension without the dot)
        size_bytes: File size in bytes
        dtype: numpy dtype of the decoded array
    """

    SUPPORTED_EXTENSIONS = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]

    def __init__(self, image_path: Union[str, Path]):
        """
        Initialize an Image object by loading from a file path.

        Args:
            image_path: Path to the image file (string or Path object)

        Raises:
            ImageLoadError: If the image cannot be loaded or is invalid
        """
        self.path = Path(image_path)
        self._data: Optional[np.ndarray] = None
        self._pil_image: Optional[PILImage.Image] = None
        self._width: int = 0
        self._height: int = 0
        self._channels: int = 0
        self._format: str = ""
        self._size_bytes: int = 0
        self._dtype: Optional[np.dtype] = None

        # Load the image
        self._load()

    def _load(self) -> None:
        """
        Load the image from disk and populate the cached metadata.

        Raises:
            ImageLoadError: If the path is invalid, the extension is not
                supported, or decoding / conversion fails
        """
        # Validate path
        if not validate_file_path(str(self.path), must_exist=True):
            raise ImageLoadError(f"Invalid or non-existent file path: {self.path}")

        # Check file extension
        if not is_image_file(str(self.path), self.SUPPORTED_EXTENSIONS):
            ext = self.path.suffix.lower()
            raise ImageLoadError(
                f"Unsupported image format: {ext}. "
                f"Supported formats: {', '.join(self.SUPPORTED_EXTENSIONS)}"
            )

        try:
            # Load with OpenCV (color images come back in BGR / BGRA order)
            self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
            if self._data is None:
                raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")

            # Extract metadata
            self._height, self._width = self._data.shape[:2]
            # A 2-D array has no channel axis and is treated as one channel.
            self._channels = self._data.shape[2] if self._data.ndim == 3 else 1
            self._format = self.path.suffix.lower().lstrip(".")
            self._size_bytes = self.path.stat().st_size
            self._dtype = self._data.dtype

            # Load PIL version for compatibility (convert BGR to RGB)
            if self._channels == 3:
                rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
                self._pil_image = PILImage.fromarray(rgb_data)
            elif self._channels == 4:
                rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
                self._pil_image = PILImage.fromarray(rgba_data)
            else:
                # Grayscale
                self._pil_image = PILImage.fromarray(self._data)

            logger.info(
                f"Successfully loaded image: {self.path.name} "
                f"({self._width}x{self._height}, {self._channels} channels, "
                f"{self._format.upper()})"
            )
        except ImageLoadError as e:
            # Already the right exception type: log it and let it propagate
            # unchanged instead of wrapping it in a second ImageLoadError.
            logger.error(f"Error loading image {self.path}: {e}")
            raise
        except Exception as e:
            logger.error(f"Error loading image {self.path}: {e}")
            raise ImageLoadError(f"Failed to load image: {e}") from e

    @property
    def data(self) -> np.ndarray:
        """
        Get image data as numpy array (OpenCV format, BGR or grayscale).

        Returns:
            Image data as numpy array (the internal array, not a copy)

        Raises:
            ImageLoadError: If no image data is available
        """
        if self._data is None:
            raise ImageLoadError("Image data not available")
        return self._data

    @property
    def pil_image(self) -> "PILImage.Image":
        """
        Get image data as PIL Image (RGB, RGBA or grayscale).

        Returns:
            PIL Image object

        Raises:
            ImageLoadError: If no PIL image is available
        """
        if self._pil_image is None:
            raise ImageLoadError("PIL image not available")
        return self._pil_image

    @property
    def width(self) -> int:
        """Get image width in pixels."""
        return self._width

    @property
    def height(self) -> int:
        """Get image height in pixels."""
        return self._height

    @property
    def shape(self) -> Tuple[int, int, int]:
        """
        Get image shape as (height, width, channels).

        Single-channel images report ``channels == 1`` even though the
        underlying array is 2-D.

        Returns:
            Tuple of (height, width, channels)
        """
        return (self._height, self._width, self._channels)

    @property
    def channels(self) -> int:
        """Get number of color channels."""
        return self._channels

    @property
    def format(self) -> str:
        """Get image file format (e.g., 'jpg', 'png')."""
        return self._format

    @property
    def size_bytes(self) -> int:
        """Get file size in bytes."""
        return self._size_bytes

    @property
    def size_mb(self) -> float:
        """Get file size in megabytes (1 MB = 1024 * 1024 bytes)."""
        return self._size_bytes / (1024 * 1024)

    @property
    def dtype(self) -> np.dtype:
        """
        Get the data type of the image array.

        Raises:
            ImageLoadError: If the dtype has not been recorded
        """
        if self._dtype is None:
            raise ImageLoadError("Image dtype not available")
        return self._dtype

    @property
    def qtimage_format(self) -> "QImage.Format":
        """
        Get the appropriate QImage format for the image.

        Returns:
            QImage.Format enum value

        Raises:
            ImageLoadError: If the channel count is not 1, 3 or 4
        """
        # NOTE(review): the dtype is only consulted for single-channel
        # images; 3- and 4-channel images always map to the 8-bit-per-channel
        # formats below - confirm that is intended for 16-bit color input.
        if self._channels == 3:
            return QImage.Format_RGB888
        if self._channels == 4:
            return QImage.Format_RGBA8888
        if self._channels == 1:
            if self._dtype == np.uint16:
                return QImage.Format_Grayscale16
            return QImage.Format_Grayscale8
        raise ImageLoadError(f"Unsupported number of channels: {self._channels}")

    def get_rgb(self) -> np.ndarray:
        """
        Get image data with the channels in RGB order.

        Returns:
            A new RGB array for 3-channel images, a new RGBA array for
            4-channel images; for any other channel count the internal
            array is returned unchanged (not a copy)

        Raises:
            ImageLoadError: If no image data is available
        """
        pixels = self.data
        if self._channels == 3:
            return cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)
        if self._channels == 4:
            return cv2.cvtColor(pixels, cv2.COLOR_BGRA2RGBA)
        return pixels

    def get_grayscale(self) -> np.ndarray:
        """
        Get image as grayscale numpy array.

        Returns:
            The internal array itself for single-channel images (not a
            copy); otherwise a newly converted grayscale array

        Raises:
            ImageLoadError: If no image data is available
        """
        pixels = self.data
        if self._channels == 1:
            return pixels
        # Name the alpha-aware conversion explicitly for 4-channel input.
        conversion = cv2.COLOR_BGRA2GRAY if self._channels == 4 else cv2.COLOR_BGR2GRAY
        return cv2.cvtColor(pixels, conversion)

    def copy(self) -> np.ndarray:
        """
        Get a copy of the image data.

        Returns:
            Copy of image data as numpy array

        Raises:
            ImageLoadError: If no image data is available
        """
        return self.data.copy()

    def resize(self, width: int, height: int) -> np.ndarray:
        """
        Resize the image to specified dimensions.

        Args:
            width: Target width in pixels
            height: Target height in pixels

        Returns:
            Resized image as numpy array (does not modify original)

        Raises:
            ImageLoadError: If no image data is available
        """
        # cv2.resize takes the target size as (width, height).
        return cv2.resize(self.data, (width, height))

    def is_grayscale(self) -> bool:
        """
        Check if image is grayscale.

        Returns:
            True if image is grayscale (1 channel)
        """
        return self._channels == 1

    def is_color(self) -> bool:
        """
        Check if image is color.

        Returns:
            True if image has 3 or more channels
        """
        return self._channels >= 3

    def __repr__(self) -> str:
        """String representation of the Image object."""
        return (
            f"Image(path='{self.path.name}', "
            f"shape=({self._width}x{self._height}x{self._channels}), "
            f"format={self._format}, "
            f"size={self.size_mb:.2f}MB)"
        )

    def __str__(self) -> str:
        """String representation of the Image object (same as ``repr``)."""
        return self.__repr__()