Adding standalone training script and update

This commit is contained in:
2025-12-13 09:28:24 +02:00
parent 908e9a5b82
commit aec0fbf83c
8 changed files with 1434 additions and 290 deletions

View File

@@ -0,0 +1,211 @@
"""
Test script for Float32 on-the-fly loading for 16-bit TIFFs.
This test verifies that:
1. Float32YOLODataset can load 16-bit TIFF files
2. Images are converted to float32 [0-1] in memory
3. Grayscale is replicated to 3 channels (RGB)
4. No disk caching is used
5. Full 16-bit precision is preserved
"""
import tempfile
import numpy as np
import tifffile
from pathlib import Path
import yaml
def create_test_dataset():
    """Build a throwaway YOLO-style dataset holding a single 16-bit TIFF.

    A fresh temporary directory is created containing ``train/images`` and
    ``train/labels``, one random 100x100 ``uint16`` image, a one-line label
    file, and a ``data.yaml`` whose ``val`` split reuses the training images.
    The directory is NOT removed here; cleanup is left to the caller.

    Returns:
        tuple: ``(dataset_dir, img_path, img_16bit)`` - the dataset root, the
        path of the TIFF that was written, and the array that was written.
    """
    dataset_dir = Path(tempfile.mkdtemp()) / "test_dataset"

    # Lay out the train/images and train/labels folders
    split_root = dataset_dir / "train"
    train_images = split_root / "images"
    train_labels = split_root / "labels"
    for folder in (train_images, train_labels):
        folder.mkdir(parents=True, exist_ok=True)

    # Random image drawn from the full uint16 range (upper bound is exclusive)
    img_16bit = np.random.randint(0, 65536, (100, 100), dtype=np.uint16)
    img_path = train_images / "test_image.tif"
    tifffile.imwrite(str(img_path), img_16bit)

    # Dummy annotation: class_id x_center y_center width height
    (train_labels / "test_image.txt").write_text("0 0.5 0.5 0.2 0.2\n")

    # Dataset description; validation points at the training images
    yaml_path = dataset_dir / "data.yaml"
    config = dict(
        path=str(dataset_dir),
        train="train/images",
        val="train/images",
        names={0: "object"},
        nc=1,
    )
    with yaml_path.open("w") as stream:
        yaml.safe_dump(config, stream)

    print(f"✓ Created test dataset at: {dataset_dir}")
    print(f" - Image: {img_path} (shape={img_16bit.shape}, dtype={img_16bit.dtype})")
    print(f" - Min value: {img_16bit.min()}, Max value: {img_16bit.max()}")
    print(f" - data.yaml: {yaml_path}")

    return dataset_dir, img_path, img_16bit
def test_float32_dataset():
    """Exercise Float32YOLODataset against a generated 16-bit TIFF dataset.

    The first dataset item is checked for:
      * dtype ``torch.float32``
      * a 3-dimensional tensor whose first dimension is 3, i.e. (3, H, W)
      * values inside [0, 1]
    The number of unique values is also reported; 256 or fewer only produces
    a warning, it does not fail the test.

    The result is returned as a bool (rather than asserted) because ``main``
    aggregates the outcomes into a summary.

    Returns:
        bool: True when every check passes. False when the imports fail, the
        dataset is empty, a check fails, or any exception is raised while
        testing.
    """
    import shutil

    print("\n=== Testing Float32YOLODataset ===\n")

    # torch is imported here because the dtype/unique checks below need it and
    # the module does not import it at top level; without this the function
    # only works when the file is executed as a script.
    try:
        import torch
        from src.utils.train_ultralytics_float import Float32YOLODataset
    except ImportError as e:
        print(f"✗ Failed to import Float32YOLODataset: {e}")
        return False
    print("✓ Successfully imported Float32YOLODataset")

    # Create test dataset
    dataset_dir, _, _ = create_test_dataset()

    try:
        # Initialize the dataset
        print("\nInitializing Float32YOLODataset...")
        dataset = Float32YOLODataset(
            images_dir=str(dataset_dir / "train" / "images"),
            labels_dir=str(dataset_dir / "train" / "labels"),
            img_size=640,
        )
        print(f"✓ Float32YOLODataset initialized with {len(dataset)} images")

        if len(dataset) == 0:
            print("✗ No items in dataset")
            return False

        # Get an item
        print("\nGetting first item...")
        img_tensor, labels, filename = dataset[0]
        print("✓ Item retrieved successfully")
        print(f" - Image tensor shape: {img_tensor.shape}")
        print(f" - Image tensor dtype: {img_tensor.dtype}")
        print(f" - Value range: [{img_tensor.min():.6f}, {img_tensor.max():.6f}]")
        print(f" - Filename: {filename}")
        print(f" - Labels: {len(labels)} annotations")
        # len() instead of truthiness: a multi-element tensor/array has no
        # unambiguous boolean value and would raise here.
        if len(labels) > 0:
            print(f" - First label shape: {labels[0].shape}")

        # Verify it's float32
        if img_tensor.dtype != torch.float32:
            print(f"✗ Wrong dtype: {img_tensor.dtype} (expected float32)")
            return False
        print("✓ Correct dtype: float32")

        # Verify it's 3-channel in correct format (C, H, W)
        if len(img_tensor.shape) != 3 or img_tensor.shape[0] != 3:
            print(f"✗ Wrong shape: {img_tensor.shape} (expected (3, H, W))")
            return False
        print(f"✓ Correct format: (C, H, W) = {img_tensor.shape} with 3 channels")

        # Verify it's in [0, 1] range
        if not (0.0 <= img_tensor.min() and img_tensor.max() <= 1.0):
            print(f"✗ Values out of range: [{img_tensor.min()}, {img_tensor.max()}]")
            return False
        print("✓ Values in correct range: [0, 1]")

        # Verify precision (should have many unique values)
        unique_values = len(torch.unique(img_tensor))
        print(f" - Unique values: {unique_values}")
        if unique_values > 256:
            print(f"✓ High precision maintained ({unique_values} > 256 levels)")
        else:
            print(f"⚠ Low precision: only {unique_values} unique values")

        print("\n✓ All Float32YOLODataset tests passed!")
        return True
    except Exception as e:
        # Top-level boundary of the check: report the failure with a traceback
        # instead of propagating, so the caller still gets a bool.
        print(f"✗ Error during testing: {e}")
        import traceback

        traceback.print_exc()
        return False
    finally:
        # Remove the generated dataset so repeated runs do not pile up
        # temporary directories.
        shutil.rmtree(dataset_dir, ignore_errors=True)
        try:
            # rmdir only succeeds on an empty directory, so this can never
            # delete anything other than the now-empty temp parent.
            dataset_dir.parent.rmdir()
        except OSError:
            pass
def test_integration():
    """Prepare a test dataset and print how to launch a full training run.

    No training is performed here: the function only generates the dataset
    and prints a ``train_with_float32_loader`` call that points at the
    generated ``data.yaml``. The dataset is left on disk so that the printed
    snippet stays usable.

    Returns:
        bool: Always True.
    """
    print("\n=== Testing Integration with train_with_float32_loader ===\n")

    # Only the dataset root is needed; the image path and array are ignored
    dataset_root, _image_path, _image = create_test_dataset()
    data_yaml = dataset_root / "data.yaml"

    print(f"\nTest dataset ready at: {data_yaml}")
    print("\nTo test full training, run:")

    snippet_lines = (
        " from src.utils.train_ultralytics_float import train_with_float32_loader",
        " results = train_with_float32_loader(",
        " model_path='yolov8n-seg.pt',",
        f" data_yaml='{data_yaml}',",
        " epochs=1,",
        " batch=1,",
        " imgsz=640",
        " )",
    )
    for line in snippet_lines:
        print(line)

    print("\nThis will use custom training loop with Float32YOLODataset")
    return True
def main():
    """Run every check in this script and print a pass/fail summary.

    Returns:
        bool: True only if all checks reported success.
    """
    # Local import: raises ImportError right away if torch is not installed.
    import torch

    rule = "=" * 70

    print(rule)
    print("Float32 Training Loader Test Suite")
    print(rule)

    # Checks run in listing order: dataset class first, then integration
    results = [
        ("Float32YOLODataset", test_float32_dataset()),
        ("Integration Check", test_integration()),
    ]

    print("\n" + rule)
    print("Test Summary")
    print(rule)

    for test_name, passed in results:
        if passed:
            status = "✓ PASSED"
        else:
            status = "✗ FAILED"
        print(f"{status}: {test_name}")

    all_passed = all(outcome for _name, outcome in results)

    print(rule)
    print("✓ All tests passed!" if all_passed else "✗ Some tests failed")
    print(rule)

    return all_passed
if __name__ == "__main__":
    # Script entry point: run the suite and map its result to an exit status
    # (0 when everything passed, 1 otherwise).
    import sys
    import torch  # Binds torch as a module-level name when run as a script

    exit_code = 0 if main() else 1
    sys.exit(exit_code)

View File

@@ -18,8 +18,8 @@ from src.utils.image import Image
def test_float32_3ch_conversion():
"""Test conversion of 16-bit TIFF to float32 3-channel TIFF."""
print("\n=== Testing Float32 3-Channel Conversion ===")
"""Test conversion of 16-bit TIFF to 16-bit RGB PNG."""
print("\n=== Testing 16-bit RGB PNG Conversion ===")
# Create temporary directory structure
with tempfile.TemporaryDirectory() as tmpdir:
@@ -42,39 +42,65 @@ def test_float32_3ch_conversion():
print(f" Dtype: {test_data.dtype}")
print(f" Range: [{test_data.min()}, {test_data.max()}]")
# Simulate the conversion process
print("\nConverting to float32 3-channel...")
# Simulate the conversion process (matching training_tab.py)
print("\nConverting to 16-bit RGB PNG using PIL merge...")
img_obj = Image(test_file)
from PIL import Image as PILImage
# Convert to float32 [0-1]
float_data = img_obj.to_normalized_float32()
# Get uint16 data
uint16_data = img_obj.data
# Replicate to 3 channels
if len(float_data.shape) == 2:
float_3ch = np.stack([float_data] * 3, axis=-1)
# Use PIL's merge method with 'I;16' channels (proper way for 16-bit RGB)
if len(uint16_data.shape) == 2:
# Grayscale - replicate to RGB
r_img = PILImage.fromarray(uint16_data, mode="I;16")
g_img = PILImage.fromarray(uint16_data, mode="I;16")
b_img = PILImage.fromarray(uint16_data, mode="I;16")
else:
float_3ch = float_data
r_img = PILImage.fromarray(uint16_data[:, :, 0], mode="I;16")
g_img = PILImage.fromarray(
(
uint16_data[:, :, 1]
if uint16_data.shape[2] > 1
else uint16_data[:, :, 0]
),
mode="I;16",
)
b_img = PILImage.fromarray(
(
uint16_data[:, :, 2]
if uint16_data.shape[2] > 2
else uint16_data[:, :, 0]
),
mode="I;16",
)
# Save as float32 TIFF
output_file = dst_dir / "test_float32_3ch.tif"
tifffile.imwrite(output_file, float_3ch.astype(np.float32))
print(f"Saved float32 3-channel TIFF: {output_file}")
# Merge channels into RGB
rgb_img = PILImage.merge("RGB", (r_img, g_img, b_img))
# Verify the output
loaded = tifffile.imread(output_file)
print(f"\nVerifying output:")
# Save as PNG
output_file = dst_dir / "test_16bit_rgb.png"
rgb_img.save(output_file)
print(f"Saved 16-bit RGB PNG: {output_file}")
print(f" PIL mode after merge: {rgb_img.mode}")
# Verify the output - Load with OpenCV (as YOLO does)
import cv2
loaded = cv2.imread(str(output_file), cv2.IMREAD_UNCHANGED)
print(f"\nVerifying output (loaded with OpenCV):")
print(f" Shape: {loaded.shape}")
print(f" Dtype: {loaded.dtype}")
print(f" Channels: {loaded.shape[2] if len(loaded.shape) == 3 else 1}")
print(f" Range: [{loaded.min():.6f}, {loaded.max():.6f}]")
print(f" Range: [{loaded.min()}, {loaded.max()}]")
print(f" Unique values: {len(np.unique(loaded[:,:,0]))}")
# Assertions
assert loaded.dtype == np.float32, f"Expected float32, got {loaded.dtype}"
assert loaded.dtype == np.uint16, f"Expected uint16, got {loaded.dtype}"
assert loaded.shape[2] == 3, f"Expected 3 channels, got {loaded.shape[2]}"
assert (
0.0 <= loaded.min() <= loaded.max() <= 1.0
), f"Expected [0,1] range, got [{loaded.min()}, {loaded.max()}]"
loaded.min() >= 0 and loaded.max() <= 65535
), f"Expected [0,65535] range, got [{loaded.min()}, {loaded.max()}]"
# Verify all channels are identical (replicated grayscale)
assert np.array_equal(
@@ -84,21 +110,20 @@ def test_float32_3ch_conversion():
loaded[:, :, 0], loaded[:, :, 2]
), "Channel 0 and 2 should be identical"
# Verify float32 precision (not quantized to uint8 steps)
# Verify no data loss
unique_vals = len(np.unique(loaded[:, :, 0]))
print(f"\n Precision check:")
print(f" Unique values in channel: {unique_vals}")
print(f" Source unique values: {len(np.unique(test_data))}")
# The final unique values should match source (no loss from conversion)
assert unique_vals == len(
np.unique(test_data)
), f"Expected {len(np.unique(test_data))} unique values, got {unique_vals}"
print("\n✓ All conversion tests passed!")
print(" - Float32 dtype preserved")
print(" - uint16 dtype preserved")
print(" - 3 channels created")
print(" - Range [0-1] maintained")
print(" - Range [0-65535] maintained")
print(" - No precision loss from conversion")
print(" - Channels properly replicated")