Implementing float32 data management
@@ -11,6 +11,7 @@ pyqtgraph>=0.13.0
opencv-python>=4.8.0
Pillow>=10.0.0
numpy>=1.24.0
tifffile>=2023.0.0

# Database
sqlalchemy>=2.0.0
@@ -9,6 +9,7 @@ from typing import Optional, List, Dict, Callable, Any
import torch
import tempfile
import os
import numpy as np
from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range
from src.utils.logger import get_logger
@@ -228,7 +229,14 @@ class YOLOWrapper:
raise

def _prepare_source(self, source):
"""Convert single-channel images to RGB temporarily for inference."""
"""Convert single-channel images to RGB for inference.

For 16-bit TIFF files, this will:
1. Load using tifffile
2. Normalize to float32 [0-1] (NO uint8 conversion to avoid data loss)
3. Replicate grayscale → RGB (3 channels)
4. Pass directly as numpy array to YOLO
"""
cleanup_path = None

if isinstance(source, (str, Path)):
@@ -236,22 +244,61 @@ class YOLOWrapper:
if source_path.is_file():
try:
img_obj = Image(source_path)
pil_img = img_obj.pil_image
if len(pil_img.getbands()) == 1:
rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img)
else:
rgb_img = pil_img.convert("RGB")

suffix = source_path.suffix or ".png"
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
tmp_path = tmp.name
tmp.close()
rgb_img.save(tmp_path)
cleanup_path = tmp_path
logger.info(
f"Converted image {source_path} to RGB for inference at {tmp_path}"
# Check if it's a 16-bit TIFF file
is_16bit_tiff = (
source_path.suffix.lower() in [".tif", ".tiff"]
and img_obj.dtype == np.uint16
)
return tmp_path, cleanup_path

if is_16bit_tiff:
# Process 16-bit TIFF: normalize to float32 [0-1]
# NO uint8 conversion - pass float32 directly to avoid data loss
normalized_float = img_obj.to_normalized_float32()

# Convert grayscale to RGB by replicating channels
if len(normalized_float.shape) == 2:
# Grayscale: H,W → H,W,3
rgb_float = np.stack([normalized_float] * 3, axis=-1)
elif (
len(normalized_float.shape) == 3
and normalized_float.shape[2] == 1
):
# Grayscale with channel dim: H,W,1 → H,W,3
rgb_float = np.repeat(normalized_float, 3, axis=2)
else:
# Already multi-channel
rgb_float = normalized_float

# Ensure contiguous array and float32
rgb_float = np.ascontiguousarray(rgb_float, dtype=np.float32)

logger.info(
f"Loaded 16-bit TIFF {source_path} as float32 [0-1] RGB "
f"(shape: {rgb_float.shape}, dtype: {rgb_float.dtype}, "
f"range: [{rgb_float.min():.4f}, {rgb_float.max():.4f}])"
)

# Return numpy array directly - YOLO can handle it
return rgb_float, cleanup_path
else:
# Standard processing for other images
pil_img = img_obj.pil_image
if len(pil_img.getbands()) == 1:
rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img)
else:
rgb_img = pil_img.convert("RGB")

suffix = source_path.suffix or ".png"
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
tmp_path = tmp.name
tmp.close()
rgb_img.save(tmp_path)
cleanup_path = tmp_path
logger.info(
f"Converted image {source_path} to RGB for inference at {tmp_path}"
)
return tmp_path, cleanup_path
except Exception as convert_error:
logger.warning(
f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"
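As a reference for the new code path above, here is a minimal standalone sketch of the float32 pipeline the docstring describes (load with tifffile, scale uint16 to [0, 1], replicate grayscale to three channels). It assumes tifffile and numpy are installed; tiff_to_float32_rgb and the model.predict call are illustrative, not part of this commit.

import numpy as np
import tifffile

def tiff_to_float32_rgb(path: str) -> np.ndarray:
    """Load a 16-bit grayscale TIFF as an H x W x 3 float32 array in [0, 1]."""
    data = tifffile.imread(path)                      # keeps the uint16 dtype
    data = data.astype(np.float32) / 65535.0          # full 16-bit range -> [0, 1]
    if data.ndim == 2:                                # grayscale -> 3 identical channels
        data = np.stack([data] * 3, axis=-1)
    return np.ascontiguousarray(data, dtype=np.float32)

# results = model.predict(tiff_to_float32_rgb("scan.tif"))  # hypothetical usage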
@@ -7,6 +7,7 @@ import numpy as np
from pathlib import Path
from typing import Optional, Tuple, Union
from PIL import Image as PILImage
import tifffile

from src.utils.logger import get_logger
from src.utils.file_utils import validate_file_path, is_image_file
@@ -85,35 +86,75 @@ class Image:
)

try:
# Load with OpenCV (returns BGR format)
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
# Check if it's a TIFF file - use tifffile for better support
if self.path.suffix.lower() in [".tif", ".tiff"]:
self._data = tifffile.imread(str(self.path))

if self._data is None:
raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")
if self._data is None:
raise ImageLoadError(
f"Failed to load TIFF with tifffile: {self.path}"
)

# Extract metadata
self._height, self._width = self._data.shape[:2]
self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
self._format = self.path.suffix.lower().lstrip(".")
self._size_bytes = self.path.stat().st_size
self._dtype = self._data.dtype
# Extract metadata
self._height, self._width = (
self._data.shape[:2]
if len(self._data.shape) >= 2
else (self._data.shape[0], 1)
)
self._channels = (
self._data.shape[2] if len(self._data.shape) == 3 else 1
)
self._format = self.path.suffix.lower().lstrip(".")
self._size_bytes = self.path.stat().st_size
self._dtype = self._data.dtype

# Load PIL version for compatibility (convert BGR to RGB)
if self._channels == 3:
rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
self._pil_image = PILImage.fromarray(rgb_data)
elif self._channels == 4:
rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
self._pil_image = PILImage.fromarray(rgba_data)
# Load PIL version for compatibility
if self._channels == 1:
# Grayscale
self._pil_image = PILImage.fromarray(self._data)
else:
# Multi-channel (RGB or RGBA)
self._pil_image = PILImage.fromarray(self._data)

logger.info(
f"Successfully loaded TIFF image: {self.path.name} "
f"({self._width}x{self._height}, {self._channels} channels, "
f"dtype={self._dtype}, {self._format.upper()})"
)
else:
# Grayscale
self._pil_image = PILImage.fromarray(self._data)
# Load with OpenCV (returns BGR format) for non-TIFF images
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)

logger.info(
f"Successfully loaded image: {self.path.name} "
f"({self._width}x{self._height}, {self._channels} channels, "
f"{self._format.upper()})"
)
if self._data is None:
raise ImageLoadError(
f"Failed to load image with OpenCV: {self.path}"
)

# Extract metadata
self._height, self._width = self._data.shape[:2]
self._channels = (
self._data.shape[2] if len(self._data.shape) == 3 else 1
)
self._format = self.path.suffix.lower().lstrip(".")
self._size_bytes = self.path.stat().st_size
self._dtype = self._data.dtype

# Load PIL version for compatibility (convert BGR to RGB)
if self._channels == 3:
rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
self._pil_image = PILImage.fromarray(rgb_data)
elif self._channels == 4:
rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
self._pil_image = PILImage.fromarray(rgba_data)
else:
# Grayscale
self._pil_image = PILImage.fromarray(self._data)

logger.info(
f"Successfully loaded image: {self.path.name} "
f"({self._width}x{self._height}, {self._channels} channels, "
f"{self._format.upper()})"
)

except Exception as e:
logger.error(f"Error loading image {self.path}: {e}")
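The branching introduced above boils down to: TIFFs go through tifffile so 16-bit data keeps its uint16 dtype, everything else still goes through OpenCV with IMREAD_UNCHANGED. A minimal sketch of that decision, with load_raw as a hypothetical stand-in for the loader:

from pathlib import Path
import cv2
import numpy as np
import tifffile

def load_raw(path: Path) -> np.ndarray:
    if path.suffix.lower() in (".tif", ".tiff"):
        data = tifffile.imread(str(path))   # preserves uint16 for 16-bit TIFFs
    else:
        # cv2.imread returns None on failure and BGR/BGRA channel order on success
        data = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    if data is None:
        raise IOError(f"Failed to load image: {path}")
    return data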
@@ -277,6 +318,44 @@ class Image:
"""
return self._channels >= 3

def to_normalized_float32(self) -> np.ndarray:
"""
Convert image data to normalized float32 in range [0, 1].

For 16-bit images, this properly scales the full dynamic range.
For 8-bit images, divides by 255.
Already float images are clipped to [0, 1].

Returns:
Normalized image data as float32 numpy array [0, 1]
"""
data = self._data.astype(np.float32)

if self._dtype == np.uint16:
# 16-bit: normalize by max value (65535)
data = data / 65535.0
elif self._dtype == np.uint8:
# 8-bit: normalize by 255
data = data / 255.0
elif np.issubdtype(self._dtype, np.floating):
# Already float, just clip to [0, 1]
data = np.clip(data, 0.0, 1.0)
else:
# Other integer types: use dtype info
if np.issubdtype(self._dtype, np.integer):
max_val = np.iinfo(self._dtype).max
data = data / float(max_val)
else:
# Unknown type: attempt min-max normalization
min_val = data.min()
max_val = data.max()
if max_val > min_val:
data = (data - min_val) / (max_val - min_val)
else:
data = np.zeros_like(data)

return np.clip(data, 0.0, 1.0)

def __repr__(self) -> str:
"""String representation of the Image object."""
return (
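The scaling rules of to_normalized_float32 can be checked in isolation; a small sketch with the same dtype-aware divisors (illustrative, using raw numpy arrays rather than the Image class):

import numpy as np

def normalize_to_unit_float32(data: np.ndarray) -> np.ndarray:
    out = data.astype(np.float32)
    if np.issubdtype(data.dtype, np.integer):
        out /= float(np.iinfo(data.dtype).max)   # uint8 -> /255, uint16 -> /65535
    return np.clip(out, 0.0, 1.0)

assert normalize_to_unit_float32(np.array([0, 65535], dtype=np.uint16)).max() == 1.0
assert normalize_to_unit_float32(np.array([0, 255], dtype=np.uint8)).max() == 1.0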
@@ -86,7 +86,7 @@ class UT:
# TODO add image coordinates normalization
coords = ""
for x, y in roi.subpixel_coordinates:
coords += f"{x/self.width} {y/self.height}"
coords += f"{x/self.width} {y/self.height} "
f.write(f"{class_index} {coords}\n")

return
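The one-character change above adds the trailing space that keeps consecutive coordinate pairs from running together in the exported label line. A sketch of the intended output with hypothetical values:

width, height = 640, 480
subpixel_coordinates = [(64.0, 48.0), (128.0, 96.0), (320.0, 240.0)]

coords = ""
for x, y in subpixel_coordinates:
    coords += f"{x / width} {y / height} "   # trailing space separates the pairs

print(f"0 {coords}")   # -> "0 0.1 0.1 0.2 0.2 0.5 0.5 " (class index, then x y pairs)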