Files
object-segmentation/tests/test_yolo_16bit_float32.py

151 lines
5.4 KiB
Python
Raw Normal View History

2025-12-13 00:32:32 +02:00
#!/usr/bin/env python3
"""
Test script for YOLO preprocessing of 16-bit TIFF images with float32 passthrough.
Verifies that no uint8 conversion occurs and data is preserved.
"""
import numpy as np
import tifffile
from pathlib import Path
import tempfile
import sys
import os
# Add parent directory to path to import modules
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.model.yolo_wrapper import YOLOWrapper
def create_test_16bit_tiff(output_path: str) -> str:
    """Create a test 16-bit grayscale TIFF file.

    The image is a 200x200 diagonal gradient: the pixel at row ``i`` and
    column ``j`` holds ``int((i + j) / 398 * 65535)``, so values run from 0 in
    the top-left corner to 65535 in the bottom-right corner.

    Args:
        output_path: Path where to save the test TIFF

    Returns:
        Path to the created TIFF file
    """
    height, width = 200, 200

    # Row index + column index for every pixel, built by broadcasting a column
    # vector against a row vector instead of looping over 40,000 pixels.
    index_sum = np.arange(height)[:, np.newaxis] + np.arange(width)[np.newaxis, :]

    # Same float64 arithmetic, in the same order, as the per-pixel formula;
    # astype() truncates toward zero exactly like int() does for these
    # non-negative values, so the stored uint16 values are unchanged.
    test_data = (index_sum / (height + width - 2) * 65535).astype(np.uint16)

    # Save as TIFF
    tifffile.imwrite(output_path, test_data)

    print(f"Created test 16-bit TIFF: {output_path}")
    print(f" Shape: {test_data.shape}")
    print(f" Dtype: {test_data.dtype}")
    print(f" Min value: {test_data.min()}")
    print(f" Max value: {test_data.max()}")
    print(
        f" Sample values: {test_data[50, 50]}, {test_data[100, 100]}, {test_data[150, 150]}"
    )
    return output_path
def test_float32_passthrough():
    """Test that 16-bit TIFF preprocessing passes float32 directly without uint8 conversion.

    Writes a temporary 16-bit gradient TIFF, feeds it to
    ``YOLOWrapper._prepare_source()`` and checks that the result is an
    in-memory array which:

    * has dtype float32 with all values inside [0, 1],
    * is 3-dimensional with 3 channels on the last axis,
    * contains more than 256 distinct values (a uint8 round trip could not
      produce more than 256),
    * comes with a ``None`` cleanup path.

    Returns:
        True if every check passed, False otherwise. Failures (including
        failed assertions and unexpected exceptions) are printed rather than
        raised, so the ``__main__`` runner can map the result to an exit code.

    NOTE(review): pytest expects test functions to return None, so the bool
    result is only meaningful to the ``__main__`` runner - confirm how this
    file is meant to be run.
    """
    print("\n=== Testing Float32 Passthrough (NO uint8) ===")

    # delete=False: only a unique ".tif" path is needed here. The file itself
    # is rewritten by create_test_16bit_tiff() and removed in `finally`.
    with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as tmp:
        test_path = tmp.name

    # Tracked outside the try block so `finally` can remove any file the
    # wrapper asked us to clean up, whichever way the test ends.
    cleanup_path = None

    try:
        # Create test image
        create_test_16bit_tiff(test_path)

        print("\nTesting YOLOWrapper._prepare_source() for float32 passthrough...")
        wrapper = YOLOWrapper()

        # Call _prepare_source to preprocess the image
        prepared_source, cleanup_path = wrapper._prepare_source(test_path)

        print("\nPreprocessing result:")
        print(f" Original path: {test_path}")
        print(f" Prepared source type: {type(prepared_source)}")

        # Anything other than an ndarray means the data was not passed through
        # in memory.
        if not isinstance(prepared_source, np.ndarray):
            print(f"\n✗ FAILED: Prepared source is a file path: {prepared_source}")
            print(" This means data was saved to disk, not passed as float32 array")
            return False

        print("\n✓ SUCCESS: Prepared source is a numpy array (float32 passthrough)")
        print(f" Shape: {prepared_source.shape}")
        print(f" Dtype: {prepared_source.dtype}")
        print(f" Min value: {prepared_source.min():.6f}")
        print(f" Max value: {prepared_source.max():.6f}")
        print(f" Mean value: {prepared_source.mean():.6f}")

        # Verify it's float32 in [0, 1] range
        assert (
            prepared_source.dtype == np.float32
        ), f"Expected float32, got {prepared_source.dtype}"
        assert (
            0.0 <= prepared_source.min() <= prepared_source.max() <= 1.0
        ), f"Expected values in [0, 1], got [{prepared_source.min()}, {prepared_source.max()}]"

        # Verify it has 3 channels (RGB). The rank is checked first so a 2-D
        # result is reported with its shape instead of an IndexError.
        assert (
            prepared_source.ndim == 3 and prepared_source.shape[2] == 3
        ), f"Expected 3 channels (RGB), got shape {prepared_source.shape}"

        # Verify no quantization to 256 levels (would happen with uint8
        # conversion). This is a hard failure: a result with at most 256
        # distinct values cannot be told apart from a uint8 round trip.
        unique_values = len(np.unique(prepared_source))
        print(f" Unique values: {unique_values}")
        if unique_values <= 256:
            print(f"\n✗ FAILED: Data has only {unique_values} unique values")
            print(" This indicates uint8 quantization happened")
            return False
        print(f"\n✓ SUCCESS: Data has {unique_values} unique values (> 256)")
        print(" This confirms NO uint8 quantization occurred!")

        # Sample some values to show precision
        print("\n Sample normalized values:")
        print(f" [50, 50]: {prepared_source[50, 50, 0]:.8f}")
        print(f" [100, 100]: {prepared_source[100, 100, 0]:.8f}")
        print(f" [150, 150]: {prepared_source[150, 150, 0]:.8f}")

        # No cleanup needed since we returned array directly
        assert (
            cleanup_path is None
        ), "Cleanup path should be None for float32 pass through"

        print("\n✓ All float32 passthrough tests passed!")
        return True
    except Exception as e:
        # Top-level boundary of the test: report the failure (including failed
        # assertions) and signal it through the return value.
        print(f"\n✗ Test failed with error: {e}")
        import traceback

        traceback.print_exc()
        return False
    finally:
        # Cleanup: remove any intermediate file reported by the wrapper, then
        # the temporary input image.
        if cleanup_path and os.path.exists(cleanup_path):
            os.remove(cleanup_path)
        if os.path.exists(test_path):
            os.remove(test_path)
            print(f"\nCleaned up test file: {test_path}")
if __name__ == "__main__":
    # Script entry point: run the passthrough test and report the outcome as
    # the process exit status (0 = passed, 1 = failed).
    exit_status = 0 if test_float32_passthrough() else 1
    sys.exit(exit_status)