Implementing float32 data management

This commit is contained in:
2025-12-13 00:31:23 +02:00
parent 9c4c39fb39
commit b3b1e3acff
4 changed files with 167 additions and 40 deletions

View File

@@ -11,6 +11,7 @@ pyqtgraph>=0.13.0
opencv-python>=4.8.0 opencv-python>=4.8.0
Pillow>=10.0.0 Pillow>=10.0.0
numpy>=1.24.0 numpy>=1.24.0
tifffile>=2023.0.0
# Database # Database
sqlalchemy>=2.0.0 sqlalchemy>=2.0.0

View File

@@ -9,6 +9,7 @@ from typing import Optional, List, Dict, Callable, Any
import torch import torch
import tempfile import tempfile
import os import os
import numpy as np
from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range
from src.utils.logger import get_logger from src.utils.logger import get_logger
@@ -228,7 +229,14 @@ class YOLOWrapper:
raise raise
def _prepare_source(self, source): def _prepare_source(self, source):
"""Convert single-channel images to RGB temporarily for inference.""" """Convert single-channel images to RGB for inference.
For 16-bit TIFF files, this will:
1. Load using tifffile
2. Normalize to float32 [0-1] (NO uint8 conversion to avoid data loss)
3. Replicate grayscale → RGB (3 channels)
4. Pass directly as numpy array to YOLO
"""
cleanup_path = None cleanup_path = None
if isinstance(source, (str, Path)): if isinstance(source, (str, Path)):
@@ -236,22 +244,61 @@ class YOLOWrapper:
if source_path.is_file(): if source_path.is_file():
try: try:
img_obj = Image(source_path) img_obj = Image(source_path)
pil_img = img_obj.pil_image
if len(pil_img.getbands()) == 1:
rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img)
else:
rgb_img = pil_img.convert("RGB")
suffix = source_path.suffix or ".png" # Check if it's a 16-bit TIFF file
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) is_16bit_tiff = (
tmp_path = tmp.name source_path.suffix.lower() in [".tif", ".tiff"]
tmp.close() and img_obj.dtype == np.uint16
rgb_img.save(tmp_path)
cleanup_path = tmp_path
logger.info(
f"Converted image {source_path} to RGB for inference at {tmp_path}"
) )
return tmp_path, cleanup_path
if is_16bit_tiff:
# Process 16-bit TIFF: normalize to float32 [0-1]
# NO uint8 conversion - pass float32 directly to avoid data loss
normalized_float = img_obj.to_normalized_float32()
# Convert grayscale to RGB by replicating channels
if len(normalized_float.shape) == 2:
# Grayscale: H,W → H,W,3
rgb_float = np.stack([normalized_float] * 3, axis=-1)
elif (
len(normalized_float.shape) == 3
and normalized_float.shape[2] == 1
):
# Grayscale with channel dim: H,W,1 → H,W,3
rgb_float = np.repeat(normalized_float, 3, axis=2)
else:
# Already multi-channel
rgb_float = normalized_float
# Ensure contiguous array and float32
rgb_float = np.ascontiguousarray(rgb_float, dtype=np.float32)
logger.info(
f"Loaded 16-bit TIFF {source_path} as float32 [0-1] RGB "
f"(shape: {rgb_float.shape}, dtype: {rgb_float.dtype}, "
f"range: [{rgb_float.min():.4f}, {rgb_float.max():.4f}])"
)
# Return numpy array directly - YOLO can handle it
return rgb_float, cleanup_path
else:
# Standard processing for other images
pil_img = img_obj.pil_image
if len(pil_img.getbands()) == 1:
rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img)
else:
rgb_img = pil_img.convert("RGB")
suffix = source_path.suffix or ".png"
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
tmp_path = tmp.name
tmp.close()
rgb_img.save(tmp_path)
cleanup_path = tmp_path
logger.info(
f"Converted image {source_path} to RGB for inference at {tmp_path}"
)
return tmp_path, cleanup_path
except Exception as convert_error: except Exception as convert_error:
logger.warning( logger.warning(
f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}" f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"

View File

@@ -7,6 +7,7 @@ import numpy as np
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
from PIL import Image as PILImage from PIL import Image as PILImage
import tifffile
from src.utils.logger import get_logger from src.utils.logger import get_logger
from src.utils.file_utils import validate_file_path, is_image_file from src.utils.file_utils import validate_file_path, is_image_file
@@ -85,35 +86,75 @@ class Image:
) )
try: try:
# Load with OpenCV (returns BGR format) # Check if it's a TIFF file - use tifffile for better support
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED) if self.path.suffix.lower() in [".tif", ".tiff"]:
self._data = tifffile.imread(str(self.path))
if self._data is None: if self._data is None:
raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}") raise ImageLoadError(
f"Failed to load TIFF with tifffile: {self.path}"
)
# Extract metadata # Extract metadata
self._height, self._width = self._data.shape[:2] self._height, self._width = (
self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1 self._data.shape[:2]
self._format = self.path.suffix.lower().lstrip(".") if len(self._data.shape) >= 2
self._size_bytes = self.path.stat().st_size else (self._data.shape[0], 1)
self._dtype = self._data.dtype )
self._channels = (
self._data.shape[2] if len(self._data.shape) == 3 else 1
)
self._format = self.path.suffix.lower().lstrip(".")
self._size_bytes = self.path.stat().st_size
self._dtype = self._data.dtype
# Load PIL version for compatibility (convert BGR to RGB) # Load PIL version for compatibility
if self._channels == 3: if self._channels == 1:
rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB) # Grayscale
self._pil_image = PILImage.fromarray(rgb_data) self._pil_image = PILImage.fromarray(self._data)
elif self._channels == 4: else:
rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA) # Multi-channel (RGB or RGBA)
self._pil_image = PILImage.fromarray(rgba_data) self._pil_image = PILImage.fromarray(self._data)
logger.info(
f"Successfully loaded TIFF image: {self.path.name} "
f"({self._width}x{self._height}, {self._channels} channels, "
f"dtype={self._dtype}, {self._format.upper()})"
)
else: else:
# Grayscale # Load with OpenCV (returns BGR format) for non-TIFF images
self._pil_image = PILImage.fromarray(self._data) self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
logger.info( if self._data is None:
f"Successfully loaded image: {self.path.name} " raise ImageLoadError(
f"({self._width}x{self._height}, {self._channels} channels, " f"Failed to load image with OpenCV: {self.path}"
f"{self._format.upper()})" )
)
# Extract metadata
self._height, self._width = self._data.shape[:2]
self._channels = (
self._data.shape[2] if len(self._data.shape) == 3 else 1
)
self._format = self.path.suffix.lower().lstrip(".")
self._size_bytes = self.path.stat().st_size
self._dtype = self._data.dtype
# Load PIL version for compatibility (convert BGR to RGB)
if self._channels == 3:
rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
self._pil_image = PILImage.fromarray(rgb_data)
elif self._channels == 4:
rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
self._pil_image = PILImage.fromarray(rgba_data)
else:
# Grayscale
self._pil_image = PILImage.fromarray(self._data)
logger.info(
f"Successfully loaded image: {self.path.name} "
f"({self._width}x{self._height}, {self._channels} channels, "
f"{self._format.upper()})"
)
except Exception as e: except Exception as e:
logger.error(f"Error loading image {self.path}: {e}") logger.error(f"Error loading image {self.path}: {e}")
@@ -277,6 +318,44 @@ class Image:
""" """
return self._channels >= 3 return self._channels >= 3
def to_normalized_float32(self) -> np.ndarray:
    """
    Convert image data to normalized float32 in range [0, 1].

    Scaling rules, chosen from the dtype recorded in ``self._dtype``:

    * Integer types (uint8, uint16, ...): divided by the dtype's maximum
      representable value (255, 65535, ...). Negative values of signed
      types end up clipped to 0.
    * Floating types: values are only clipped to [0, 1].
    * Boolean data: False -> 0.0 and True -> 1.0. This is NOT min-max
      scaled, so a constant all-True image stays 1.0 instead of
      collapsing to 0.
    * Any other dtype: min-max normalization; constant data becomes all
      zeros and zero-size arrays are returned unchanged (empty).

    Returns:
        Normalized image data as float32 numpy array [0, 1]
    """
    dtype = self._dtype
    data = self._data.astype(np.float32)

    if np.issubdtype(dtype, np.integer):
        # One rule for every integer width: scale by the dtype's max value.
        data = data / float(np.iinfo(dtype).max)
    elif (
        dtype != np.bool_
        and not np.issubdtype(dtype, np.floating)
        and data.size
    ):
        # Unknown type: attempt min-max normalization. The size guard
        # avoids calling min()/max() on an empty array, which would raise.
        low = data.min()
        high = data.max()
        if high > low:
            data = (data - low) / (high - low)
        else:
            data = np.zeros_like(data)
    # Floats and booleans need no rescaling: the float32 cast already gives
    # the final values, and the clip below enforces the [0, 1] bounds.

    return np.clip(data, 0.0, 1.0)
def __repr__(self) -> str: def __repr__(self) -> str:
"""String representation of the Image object.""" """String representation of the Image object."""
return ( return (

View File

@@ -86,7 +86,7 @@ class UT:
# TODO add image coordinates normalization # TODO add image coordinates normalization
coords = "" coords = ""
for x, y in roi.subpixel_coordinates: for x, y in roi.subpixel_coordinates:
coords += f"{x/self.width} {y/self.height}" coords += f"{x/self.width} {y/self.height} "
f.write(f"{class_index} {coords}\n") f.write(f"{class_index} {coords}\n")
return return