143 lines
4.9 KiB
Python
143 lines
4.9 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test script for training dataset preparation with 16-bit TIFFs.
|
|
"""
|
|
|
|
import numpy as np
|
|
import tifffile
|
|
from pathlib import Path
|
|
import tempfile
|
|
import sys
|
|
import os
|
|
import shutil
|
|
|
|
# Put the directory one level above this script's folder on sys.path so the `src` package can be imported
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
from src.utils.image import Image
|
|
|
|
|
|
def test_float32_3ch_conversion():
    """Round-trip a synthetic 16-bit grayscale TIFF through the RGB PNG conversion.

    Despite the historical function name, no float32 data is involved: the
    test writes a 100x100 ``uint16`` diagonal gradient to a TIFF, loads it
    through the project ``Image`` wrapper, replicates the plane into three
    ``I;16`` PIL bands, merges and saves them as a PNG, then reloads the PNG
    with OpenCV and checks dtype, channel count, value range, channel
    equality and pixel-exact agreement with the source gradient.

    All files live in a temporary directory that is removed on exit.

    Returns:
        bool: ``True`` when every check passes.

    Raises:
        AssertionError: if the PNG cannot be read back or any check fails.
    """
    # Third-party imaging imports stay function-local, as in the original
    # script, so merely importing this module does not require them.
    from PIL import Image as PILImage
    import cv2

    print("\n=== Testing 16-bit RGB PNG Conversion ===")

    # Create temporary directory structure
    with tempfile.TemporaryDirectory() as tmpdir:
        root = Path(tmpdir)
        src_dir = root / "original"
        dst_dir = root / "converted"
        src_dir.mkdir()
        dst_dir.mkdir()

        # Create test 16-bit TIFF: value grows with i + j, from 0 at (0, 0)
        # to 65535 at (99, 99).  The float -> uint16 cast truncates exactly
        # like the int() call of the former per-pixel loop.
        size = 100
        ramp = np.arange(size)
        test_data = (
            np.add.outer(ramp, ramp) / (2 * (size - 1)) * 65535
        ).astype(np.uint16)

        test_file = src_dir / "test_16bit.tif"
        tifffile.imwrite(test_file, test_data)
        print(f"Created test 16-bit TIFF: {test_file}")
        print(f" Shape: {test_data.shape}")
        print(f" Dtype: {test_data.dtype}")
        print(f" Range: [{test_data.min()}, {test_data.max()}]")

        # Simulate the conversion process (matching training_tab.py)
        print("\nConverting to 16-bit RGB PNG using PIL merge...")
        img_obj = Image(test_file)

        # Get uint16 data
        # NOTE(review): assumes Image.data is the unmodified uint16 array
        # read from the TIFF -- confirm against src/utils/image.py.
        uint16_data = img_obj.data

        # Pick one source plane per output channel; any channel the input
        # does not have falls back to plane 0 (grayscale -> replicated RGB).
        if uint16_data.ndim == 2:
            planes = [uint16_data] * 3
        else:
            available = uint16_data.shape[2]
            planes = [
                uint16_data[:, :, idx if idx < available else 0]
                for idx in range(3)
            ]
        r_img, g_img, b_img = (
            PILImage.fromarray(plane, mode="I;16") for plane in planes
        )

        # Merge channels into RGB
        # NOTE(review): Pillow documents Image.merge as taking single-band
        # images of the target mode's band type (8-bit "L" for "RGB").
        # Confirm the installed Pillow accepts "I;16" bands here; if it does
        # not, this call raises ValueError and the production code this test
        # mirrors needs a different 16-bit writer.
        rgb_img = PILImage.merge("RGB", (r_img, g_img, b_img))

        # Save as PNG
        output_file = dst_dir / "test_16bit_rgb.png"
        rgb_img.save(output_file)
        print(f"Saved 16-bit RGB PNG: {output_file}")
        print(f" PIL mode after merge: {rgb_img.mode}")

        # Verify the output - Load with OpenCV (as YOLO does)
        loaded = cv2.imread(str(output_file), cv2.IMREAD_UNCHANGED)
        # cv2.imread signals failure by returning None instead of raising.
        assert loaded is not None, f"OpenCV could not read {output_file}"

        channels = loaded.shape[2] if loaded.ndim == 3 else 1
        print("\nVerifying output (loaded with OpenCV):")
        print(f" Shape: {loaded.shape}")
        print(f" Dtype: {loaded.dtype}")
        print(f" Channels: {channels}")
        print(f" Range: [{loaded.min()}, {loaded.max()}]")

        # Checked before any [:, :, k] indexing so that a 2-D result fails
        # with this message rather than an IndexError.
        assert channels == 3, f"Expected 3 channels, got {channels}"

        first_channel = loaded[:, :, 0]
        unique_vals = len(np.unique(first_channel))
        source_unique = len(np.unique(test_data))
        print(f" Unique values: {unique_vals}")

        # Assertions
        assert loaded.dtype == np.uint16, f"Expected uint16, got {loaded.dtype}"
        assert (
            loaded.min() >= 0 and loaded.max() <= 65535
        ), f"Expected [0,65535] range, got [{loaded.min()}, {loaded.max()}]"

        # Verify all channels are identical (replicated grayscale)
        assert np.array_equal(
            first_channel, loaded[:, :, 1]
        ), "Channel 0 and 1 should be identical"
        assert np.array_equal(
            first_channel, loaded[:, :, 2]
        ), "Channel 0 and 2 should be identical"

        # Verify no data loss
        print("\n Precision check:")
        print(f" Unique values in channel: {unique_vals}")
        print(f" Source unique values: {source_unique}")

        assert (
            unique_vals == source_unique
        ), f"Expected {source_unique} unique values, got {unique_vals}"
        # Equal unique counts cannot detect shifted or permuted values, so
        # also require a pixel-exact match with the source gradient.
        assert np.array_equal(
            first_channel, test_data
        ), "Pixel values changed during conversion"

        print("\n✓ All conversion tests passed!")
        print(" - uint16 dtype preserved")
        print(" - 3 channels created")
        print(" - Range [0-65535] maintained")
        print(" - No precision loss from conversion")
        print(" - Channels properly replicated")

    return True
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the conversion test and turn its outcome into
    # a process exit status (0 = passed, 1 = failed or raised).
    exit_status = 1
    try:
        if test_float32_3ch_conversion():
            exit_status = 0
    except Exception as e:
        # Top-level boundary: report the failure together with its traceback.
        print(f"\n✗ Test failed with error: {e}")
        import traceback

        traceback.print_exc()
    sys.exit(exit_status)
|