diff --git a/requirements.txt b/requirements.txt index 1ddc6b2..c8072e9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,7 @@ pyqtgraph>=0.13.0 opencv-python>=4.8.0 Pillow>=10.0.0 numpy>=1.24.0 +tifffile>=2023.0.0 # Database sqlalchemy>=2.0.0 diff --git a/src/model/yolo_wrapper.py b/src/model/yolo_wrapper.py index d5e31b3..fbe71eb 100644 --- a/src/model/yolo_wrapper.py +++ b/src/model/yolo_wrapper.py @@ -9,6 +9,7 @@ from typing import Optional, List, Dict, Callable, Any import torch import tempfile import os +import numpy as np from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range from src.utils.logger import get_logger @@ -228,7 +229,14 @@ class YOLOWrapper: raise def _prepare_source(self, source): - """Convert single-channel images to RGB temporarily for inference.""" + """Convert single-channel images to RGB for inference. + + For 16-bit TIFF files, this will: + 1. Load using tifffile + 2. Normalize to float32 [0-1] (NO uint8 conversion to avoid data loss) + 3. Replicate grayscale → RGB (3 channels) + 4. Pass directly as numpy array to YOLO + """ cleanup_path = None if isinstance(source, (str, Path)): @@ -236,22 +244,61 @@ class YOLOWrapper: if source_path.is_file(): try: img_obj = Image(source_path) - pil_img = img_obj.pil_image - if len(pil_img.getbands()) == 1: - rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img) - else: - rgb_img = pil_img.convert("RGB") - suffix = source_path.suffix or ".png" - tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) - tmp_path = tmp.name - tmp.close() - rgb_img.save(tmp_path) - cleanup_path = tmp_path - logger.info( - f"Converted image {source_path} to RGB for inference at {tmp_path}" + # Check if it's a 16-bit TIFF file + is_16bit_tiff = ( + source_path.suffix.lower() in [".tif", ".tiff"] + and img_obj.dtype == np.uint16 ) - return tmp_path, cleanup_path + + if is_16bit_tiff: + # Process 16-bit TIFF: normalize to float32 [0-1] + # NO uint8 conversion - pass float32 directly to avoid data loss + normalized_float = img_obj.to_normalized_float32() + + # Convert grayscale to RGB by replicating channels + if len(normalized_float.shape) == 2: + # Grayscale: H,W → H,W,3 + rgb_float = np.stack([normalized_float] * 3, axis=-1) + elif ( + len(normalized_float.shape) == 3 + and normalized_float.shape[2] == 1 + ): + # Grayscale with channel dim: H,W,1 → H,W,3 + rgb_float = np.repeat(normalized_float, 3, axis=2) + else: + # Already multi-channel + rgb_float = normalized_float + + # Ensure contiguous array and float32 + rgb_float = np.ascontiguousarray(rgb_float, dtype=np.float32) + + logger.info( + f"Loaded 16-bit TIFF {source_path} as float32 [0-1] RGB " + f"(shape: {rgb_float.shape}, dtype: {rgb_float.dtype}, " + f"range: [{rgb_float.min():.4f}, {rgb_float.max():.4f}])" + ) + + # Return numpy array directly - YOLO can handle it + return rgb_float, cleanup_path + else: + # Standard processing for other images + pil_img = img_obj.pil_image + if len(pil_img.getbands()) == 1: + rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img) + else: + rgb_img = pil_img.convert("RGB") + + suffix = source_path.suffix or ".png" + tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) + tmp_path = tmp.name + tmp.close() + rgb_img.save(tmp_path) + cleanup_path = tmp_path + logger.info( + f"Converted image {source_path} to RGB for inference at {tmp_path}" + ) + return tmp_path, cleanup_path except Exception as convert_error: logger.warning( f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}" diff --git a/src/utils/image.py b/src/utils/image.py index 69139cd..b3bdabb 100644 --- a/src/utils/image.py +++ b/src/utils/image.py @@ -7,6 +7,7 @@ import numpy as np from pathlib import Path from typing import Optional, Tuple, Union from PIL import Image as PILImage +import tifffile from src.utils.logger import get_logger from src.utils.file_utils import validate_file_path, is_image_file @@ -85,35 +86,75 @@ class Image: ) try: - # Load with OpenCV (returns BGR format) - self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED) + # Check if it's a TIFF file - use tifffile for better support + if self.path.suffix.lower() in [".tif", ".tiff"]: + self._data = tifffile.imread(str(self.path)) - if self._data is None: - raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}") + if self._data is None: + raise ImageLoadError( + f"Failed to load TIFF with tifffile: {self.path}" + ) - # Extract metadata - self._height, self._width = self._data.shape[:2] - self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1 - self._format = self.path.suffix.lower().lstrip(".") - self._size_bytes = self.path.stat().st_size - self._dtype = self._data.dtype + # Extract metadata + self._height, self._width = ( + self._data.shape[:2] + if len(self._data.shape) >= 2 + else (self._data.shape[0], 1) + ) + self._channels = ( + self._data.shape[2] if len(self._data.shape) == 3 else 1 + ) + self._format = self.path.suffix.lower().lstrip(".") + self._size_bytes = self.path.stat().st_size + self._dtype = self._data.dtype - # Load PIL version for compatibility (convert BGR to RGB) - if self._channels == 3: - rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB) - self._pil_image = PILImage.fromarray(rgb_data) - elif self._channels == 4: - rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA) - self._pil_image = PILImage.fromarray(rgba_data) + # Load PIL version for compatibility + if self._channels == 1: + # Grayscale + self._pil_image = PILImage.fromarray(self._data) + else: + # Multi-channel (RGB or RGBA) + self._pil_image = PILImage.fromarray(self._data) + + logger.info( + f"Successfully loaded TIFF image: {self.path.name} " + f"({self._width}x{self._height}, {self._channels} channels, " + f"dtype={self._dtype}, {self._format.upper()})" + ) else: - # Grayscale - self._pil_image = PILImage.fromarray(self._data) + # Load with OpenCV (returns BGR format) for non-TIFF images + self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED) - logger.info( - f"Successfully loaded image: {self.path.name} " - f"({self._width}x{self._height}, {self._channels} channels, " - f"{self._format.upper()})" - ) + if self._data is None: + raise ImageLoadError( + f"Failed to load image with OpenCV: {self.path}" + ) + + # Extract metadata + self._height, self._width = self._data.shape[:2] + self._channels = ( + self._data.shape[2] if len(self._data.shape) == 3 else 1 + ) + self._format = self.path.suffix.lower().lstrip(".") + self._size_bytes = self.path.stat().st_size + self._dtype = self._data.dtype + + # Load PIL version for compatibility (convert BGR to RGB) + if self._channels == 3: + rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB) + self._pil_image = PILImage.fromarray(rgb_data) + elif self._channels == 4: + rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA) + self._pil_image = PILImage.fromarray(rgba_data) + else: + # Grayscale + self._pil_image = PILImage.fromarray(self._data) + + logger.info( + f"Successfully loaded image: {self.path.name} " + f"({self._width}x{self._height}, {self._channels} channels, " + f"{self._format.upper()})" + ) except Exception as e: logger.error(f"Error loading image {self.path}: {e}") @@ -277,6 +318,44 @@ class Image: """ return self._channels >= 3 + def to_normalized_float32(self) -> np.ndarray: + """ + Convert image data to normalized float32 in range [0, 1]. + + For 16-bit images, this properly scales the full dynamic range. + For 8-bit images, divides by 255. + Already float images are clipped to [0, 1]. + + Returns: + Normalized image data as float32 numpy array [0, 1] + """ + data = self._data.astype(np.float32) + + if self._dtype == np.uint16: + # 16-bit: normalize by max value (65535) + data = data / 65535.0 + elif self._dtype == np.uint8: + # 8-bit: normalize by 255 + data = data / 255.0 + elif np.issubdtype(self._dtype, np.floating): + # Already float, just clip to [0, 1] + data = np.clip(data, 0.0, 1.0) + else: + # Other integer types: use dtype info + if np.issubdtype(self._dtype, np.integer): + max_val = np.iinfo(self._dtype).max + data = data / float(max_val) + else: + # Unknown type: attempt min-max normalization + min_val = data.min() + max_val = data.max() + if max_val > min_val: + data = (data - min_val) / (max_val - min_val) + else: + data = np.zeros_like(data) + + return np.clip(data, 0.0, 1.0) + def __repr__(self) -> str: """String representation of the Image object.""" return ( diff --git a/src/utils/image_converters.py b/src/utils/image_converters.py index 17b52eb..726b59c 100644 --- a/src/utils/image_converters.py +++ b/src/utils/image_converters.py @@ -86,7 +86,7 @@ class UT: # TODO add image coordinates normalization coords = "" for x, y in roi.subpixel_coordinates: - coords += f"{x/self.width} {y/self.height}" + coords += f"{x/self.width} {y/self.height} " f.write(f"{class_index} {coords}\n") return