Implementing float32 data management
This commit is contained in:
@@ -7,6 +7,7 @@ import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
from PIL import Image as PILImage
|
||||
import tifffile
|
||||
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.file_utils import validate_file_path, is_image_file
|
||||
@@ -85,35 +86,75 @@ class Image:
|
||||
)
|
||||
|
||||
try:
|
||||
# Load with OpenCV (returns BGR format)
|
||||
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
|
||||
# Check if it's a TIFF file - use tifffile for better support
|
||||
if self.path.suffix.lower() in [".tif", ".tiff"]:
|
||||
self._data = tifffile.imread(str(self.path))
|
||||
|
||||
if self._data is None:
|
||||
raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")
|
||||
if self._data is None:
|
||||
raise ImageLoadError(
|
||||
f"Failed to load TIFF with tifffile: {self.path}"
|
||||
)
|
||||
|
||||
# Extract metadata
|
||||
self._height, self._width = self._data.shape[:2]
|
||||
self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
|
||||
self._format = self.path.suffix.lower().lstrip(".")
|
||||
self._size_bytes = self.path.stat().st_size
|
||||
self._dtype = self._data.dtype
|
||||
# Extract metadata
|
||||
self._height, self._width = (
|
||||
self._data.shape[:2]
|
||||
if len(self._data.shape) >= 2
|
||||
else (self._data.shape[0], 1)
|
||||
)
|
||||
self._channels = (
|
||||
self._data.shape[2] if len(self._data.shape) == 3 else 1
|
||||
)
|
||||
self._format = self.path.suffix.lower().lstrip(".")
|
||||
self._size_bytes = self.path.stat().st_size
|
||||
self._dtype = self._data.dtype
|
||||
|
||||
# Load PIL version for compatibility (convert BGR to RGB)
|
||||
if self._channels == 3:
|
||||
rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
|
||||
self._pil_image = PILImage.fromarray(rgb_data)
|
||||
elif self._channels == 4:
|
||||
rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
|
||||
self._pil_image = PILImage.fromarray(rgba_data)
|
||||
# Load PIL version for compatibility
|
||||
if self._channels == 1:
|
||||
# Grayscale
|
||||
self._pil_image = PILImage.fromarray(self._data)
|
||||
else:
|
||||
# Multi-channel (RGB or RGBA)
|
||||
self._pil_image = PILImage.fromarray(self._data)
|
||||
|
||||
logger.info(
|
||||
f"Successfully loaded TIFF image: {self.path.name} "
|
||||
f"({self._width}x{self._height}, {self._channels} channels, "
|
||||
f"dtype={self._dtype}, {self._format.upper()})"
|
||||
)
|
||||
else:
|
||||
# Grayscale
|
||||
self._pil_image = PILImage.fromarray(self._data)
|
||||
# Load with OpenCV (returns BGR format) for non-TIFF images
|
||||
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
|
||||
|
||||
logger.info(
|
||||
f"Successfully loaded image: {self.path.name} "
|
||||
f"({self._width}x{self._height}, {self._channels} channels, "
|
||||
f"{self._format.upper()})"
|
||||
)
|
||||
if self._data is None:
|
||||
raise ImageLoadError(
|
||||
f"Failed to load image with OpenCV: {self.path}"
|
||||
)
|
||||
|
||||
# Extract metadata
|
||||
self._height, self._width = self._data.shape[:2]
|
||||
self._channels = (
|
||||
self._data.shape[2] if len(self._data.shape) == 3 else 1
|
||||
)
|
||||
self._format = self.path.suffix.lower().lstrip(".")
|
||||
self._size_bytes = self.path.stat().st_size
|
||||
self._dtype = self._data.dtype
|
||||
|
||||
# Load PIL version for compatibility (convert BGR to RGB)
|
||||
if self._channels == 3:
|
||||
rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
|
||||
self._pil_image = PILImage.fromarray(rgb_data)
|
||||
elif self._channels == 4:
|
||||
rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
|
||||
self._pil_image = PILImage.fromarray(rgba_data)
|
||||
else:
|
||||
# Grayscale
|
||||
self._pil_image = PILImage.fromarray(self._data)
|
||||
|
||||
logger.info(
|
||||
f"Successfully loaded image: {self.path.name} "
|
||||
f"({self._width}x{self._height}, {self._channels} channels, "
|
||||
f"{self._format.upper()})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading image {self.path}: {e}")
|
||||
@@ -277,6 +318,44 @@ class Image:
|
||||
"""
|
||||
return self._channels >= 3
|
||||
|
||||
def to_normalized_float32(self) -> np.ndarray:
    """
    Return the pixel data as a float32 array scaled into [0, 1].

    The scaling rule is picked from the dtype recorded for the image:

    * ``uint16`` values are divided by 65535.
    * ``uint8`` values are divided by 255.
    * Floating-point values are left unscaled and only clipped.
    * Any other integer dtype is divided by that dtype's maximum
      representable value (negative values end up clipped to 0).
    * Anything else is min-max normalized; a constant array becomes
      all zeros.

    Returns:
        A new float32 numpy array with every finite value in [0, 1].
    """
    pixels = self._data.astype(np.float32)
    source_dtype = self._dtype

    if source_dtype == np.uint16:
        # Full 16-bit range maps onto [0, 1].
        scaled = pixels / 65535.0
    elif source_dtype == np.uint8:
        # Full 8-bit range maps onto [0, 1].
        scaled = pixels / 255.0
    elif np.issubdtype(source_dtype, np.floating):
        # Float input is not rescaled; the final clip bounds it.
        scaled = pixels
    elif np.issubdtype(source_dtype, np.integer):
        # Remaining integer widths: scale by the dtype's own maximum.
        scaled = pixels / float(np.iinfo(source_dtype).max)
    else:
        # Unknown dtype: stretch the observed value range onto [0, 1].
        lowest = pixels.min()
        highest = pixels.max()
        if highest > lowest:
            scaled = (pixels - lowest) / (highest - lowest)
        else:
            scaled = np.zeros_like(pixels)

    return np.clip(scaled, 0.0, 1.0)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation of the Image object."""
|
||||
return (
|
||||
|
||||
Reference in New Issue
Block a user