Implementing float 32 data managent
This commit is contained in:
@@ -11,6 +11,7 @@ pyqtgraph>=0.13.0
|
|||||||
opencv-python>=4.8.0
|
opencv-python>=4.8.0
|
||||||
Pillow>=10.0.0
|
Pillow>=10.0.0
|
||||||
numpy>=1.24.0
|
numpy>=1.24.0
|
||||||
|
tifffile>=2023.0.0
|
||||||
|
|
||||||
# Database
|
# Database
|
||||||
sqlalchemy>=2.0.0
|
sqlalchemy>=2.0.0
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from typing import Optional, List, Dict, Callable, Any
|
|||||||
import torch
|
import torch
|
||||||
import tempfile
|
import tempfile
|
||||||
import os
|
import os
|
||||||
|
import numpy as np
|
||||||
from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range
|
from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
|
|
||||||
@@ -228,7 +229,14 @@ class YOLOWrapper:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def _prepare_source(self, source):
|
def _prepare_source(self, source):
|
||||||
"""Convert single-channel images to RGB temporarily for inference."""
|
"""Convert single-channel images to RGB for inference.
|
||||||
|
|
||||||
|
For 16-bit TIFF files, this will:
|
||||||
|
1. Load using tifffile
|
||||||
|
2. Normalize to float32 [0-1] (NO uint8 conversion to avoid data loss)
|
||||||
|
3. Replicate grayscale → RGB (3 channels)
|
||||||
|
4. Pass directly as numpy array to YOLO
|
||||||
|
"""
|
||||||
cleanup_path = None
|
cleanup_path = None
|
||||||
|
|
||||||
if isinstance(source, (str, Path)):
|
if isinstance(source, (str, Path)):
|
||||||
@@ -236,6 +244,45 @@ class YOLOWrapper:
|
|||||||
if source_path.is_file():
|
if source_path.is_file():
|
||||||
try:
|
try:
|
||||||
img_obj = Image(source_path)
|
img_obj = Image(source_path)
|
||||||
|
|
||||||
|
# Check if it's a 16-bit TIFF file
|
||||||
|
is_16bit_tiff = (
|
||||||
|
source_path.suffix.lower() in [".tif", ".tiff"]
|
||||||
|
and img_obj.dtype == np.uint16
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_16bit_tiff:
|
||||||
|
# Process 16-bit TIFF: normalize to float32 [0-1]
|
||||||
|
# NO uint8 conversion - pass float32 directly to avoid data loss
|
||||||
|
normalized_float = img_obj.to_normalized_float32()
|
||||||
|
|
||||||
|
# Convert grayscale to RGB by replicating channels
|
||||||
|
if len(normalized_float.shape) == 2:
|
||||||
|
# Grayscale: H,W → H,W,3
|
||||||
|
rgb_float = np.stack([normalized_float] * 3, axis=-1)
|
||||||
|
elif (
|
||||||
|
len(normalized_float.shape) == 3
|
||||||
|
and normalized_float.shape[2] == 1
|
||||||
|
):
|
||||||
|
# Grayscale with channel dim: H,W,1 → H,W,3
|
||||||
|
rgb_float = np.repeat(normalized_float, 3, axis=2)
|
||||||
|
else:
|
||||||
|
# Already multi-channel
|
||||||
|
rgb_float = normalized_float
|
||||||
|
|
||||||
|
# Ensure contiguous array and float32
|
||||||
|
rgb_float = np.ascontiguousarray(rgb_float, dtype=np.float32)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Loaded 16-bit TIFF {source_path} as float32 [0-1] RGB "
|
||||||
|
f"(shape: {rgb_float.shape}, dtype: {rgb_float.dtype}, "
|
||||||
|
f"range: [{rgb_float.min():.4f}, {rgb_float.max():.4f}])"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return numpy array directly - YOLO can handle it
|
||||||
|
return rgb_float, cleanup_path
|
||||||
|
else:
|
||||||
|
# Standard processing for other images
|
||||||
pil_img = img_obj.pil_image
|
pil_img = img_obj.pil_image
|
||||||
if len(pil_img.getbands()) == 1:
|
if len(pil_img.getbands()) == 1:
|
||||||
rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img)
|
rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import numpy as np
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
|
import tifffile
|
||||||
|
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
from src.utils.file_utils import validate_file_path, is_image_file
|
from src.utils.file_utils import validate_file_path, is_image_file
|
||||||
@@ -85,15 +86,55 @@ class Image:
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load with OpenCV (returns BGR format)
|
# Check if it's a TIFF file - use tifffile for better support
|
||||||
|
if self.path.suffix.lower() in [".tif", ".tiff"]:
|
||||||
|
self._data = tifffile.imread(str(self.path))
|
||||||
|
|
||||||
|
if self._data is None:
|
||||||
|
raise ImageLoadError(
|
||||||
|
f"Failed to load TIFF with tifffile: {self.path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract metadata
|
||||||
|
self._height, self._width = (
|
||||||
|
self._data.shape[:2]
|
||||||
|
if len(self._data.shape) >= 2
|
||||||
|
else (self._data.shape[0], 1)
|
||||||
|
)
|
||||||
|
self._channels = (
|
||||||
|
self._data.shape[2] if len(self._data.shape) == 3 else 1
|
||||||
|
)
|
||||||
|
self._format = self.path.suffix.lower().lstrip(".")
|
||||||
|
self._size_bytes = self.path.stat().st_size
|
||||||
|
self._dtype = self._data.dtype
|
||||||
|
|
||||||
|
# Load PIL version for compatibility
|
||||||
|
if self._channels == 1:
|
||||||
|
# Grayscale
|
||||||
|
self._pil_image = PILImage.fromarray(self._data)
|
||||||
|
else:
|
||||||
|
# Multi-channel (RGB or RGBA)
|
||||||
|
self._pil_image = PILImage.fromarray(self._data)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Successfully loaded TIFF image: {self.path.name} "
|
||||||
|
f"({self._width}x{self._height}, {self._channels} channels, "
|
||||||
|
f"dtype={self._dtype}, {self._format.upper()})"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Load with OpenCV (returns BGR format) for non-TIFF images
|
||||||
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
|
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
|
||||||
|
|
||||||
if self._data is None:
|
if self._data is None:
|
||||||
raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")
|
raise ImageLoadError(
|
||||||
|
f"Failed to load image with OpenCV: {self.path}"
|
||||||
|
)
|
||||||
|
|
||||||
# Extract metadata
|
# Extract metadata
|
||||||
self._height, self._width = self._data.shape[:2]
|
self._height, self._width = self._data.shape[:2]
|
||||||
self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
|
self._channels = (
|
||||||
|
self._data.shape[2] if len(self._data.shape) == 3 else 1
|
||||||
|
)
|
||||||
self._format = self.path.suffix.lower().lstrip(".")
|
self._format = self.path.suffix.lower().lstrip(".")
|
||||||
self._size_bytes = self.path.stat().st_size
|
self._size_bytes = self.path.stat().st_size
|
||||||
self._dtype = self._data.dtype
|
self._dtype = self._data.dtype
|
||||||
@@ -277,6 +318,44 @@ class Image:
|
|||||||
"""
|
"""
|
||||||
return self._channels >= 3
|
return self._channels >= 3
|
||||||
|
|
||||||
|
def to_normalized_float32(self) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Convert image data to normalized float32 in range [0, 1].
|
||||||
|
|
||||||
|
For 16-bit images, this properly scales the full dynamic range.
|
||||||
|
For 8-bit images, divides by 255.
|
||||||
|
Already float images are clipped to [0, 1].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized image data as float32 numpy array [0, 1]
|
||||||
|
"""
|
||||||
|
data = self._data.astype(np.float32)
|
||||||
|
|
||||||
|
if self._dtype == np.uint16:
|
||||||
|
# 16-bit: normalize by max value (65535)
|
||||||
|
data = data / 65535.0
|
||||||
|
elif self._dtype == np.uint8:
|
||||||
|
# 8-bit: normalize by 255
|
||||||
|
data = data / 255.0
|
||||||
|
elif np.issubdtype(self._dtype, np.floating):
|
||||||
|
# Already float, just clip to [0, 1]
|
||||||
|
data = np.clip(data, 0.0, 1.0)
|
||||||
|
else:
|
||||||
|
# Other integer types: use dtype info
|
||||||
|
if np.issubdtype(self._dtype, np.integer):
|
||||||
|
max_val = np.iinfo(self._dtype).max
|
||||||
|
data = data / float(max_val)
|
||||||
|
else:
|
||||||
|
# Unknown type: attempt min-max normalization
|
||||||
|
min_val = data.min()
|
||||||
|
max_val = data.max()
|
||||||
|
if max_val > min_val:
|
||||||
|
data = (data - min_val) / (max_val - min_val)
|
||||||
|
else:
|
||||||
|
data = np.zeros_like(data)
|
||||||
|
|
||||||
|
return np.clip(data, 0.0, 1.0)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""String representation of the Image object."""
|
"""String representation of the Image object."""
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ class UT:
|
|||||||
# TODO add image coordinates normalization
|
# TODO add image coordinates normalization
|
||||||
coords = ""
|
coords = ""
|
||||||
for x, y in roi.subpixel_coordinates:
|
for x, y in roi.subpixel_coordinates:
|
||||||
coords += f"{x/self.width} {y/self.height}"
|
coords += f"{x/self.width} {y/self.height} "
|
||||||
f.write(f"{class_index} {coords}\n")
|
f.write(f"{class_index} {coords}\n")
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user