Adding standalone training script and update
211  tests/test_float32_training_loader.py  Normal file
@@ -0,0 +1,211 @@
"""
Test script for Float32 on-the-fly loading for 16-bit TIFFs.

This test verifies that:
1. Float32YOLODataset can load 16-bit TIFF files
2. Images are converted to float32 [0-1] in memory
3. Grayscale is replicated to 3 channels (RGB)
4. No disk caching is used
5. Full 16-bit precision is preserved
"""
import tempfile
from pathlib import Path

import numpy as np
import tifffile
import torch  # needed at module level: test_float32_dataset() references torch.float32 / torch.unique
import yaml
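
# Illustrative sketch only (not called by the tests below): assuming Float32YOLODataset
# normalizes uint16 pixel values by 65535 and replicates the single grayscale channel
# three times, the conversion described in the module docstring would look like this.
def _reference_uint16_to_float32_rgb(img_16bit: np.ndarray) -> np.ndarray:
    """Convert an (H, W) uint16 image to a (3, H, W) float32 array in [0, 1]."""
    img = img_16bit.astype(np.float32) / 65535.0  # keep full 16-bit precision, no 8-bit rounding
    return np.stack([img, img, img], axis=0)  # replicate grayscale to 3 channels
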
def create_test_dataset():
    """Create a minimal test dataset with 16-bit TIFF images."""
    temp_dir = Path(tempfile.mkdtemp())
    dataset_dir = temp_dir / "test_dataset"

    # Create directory structure
    train_images = dataset_dir / "train" / "images"
    train_labels = dataset_dir / "train" / "labels"
    train_images.mkdir(parents=True, exist_ok=True)
    train_labels.mkdir(parents=True, exist_ok=True)

    # Create a 16-bit TIFF test image
    img_16bit = np.random.randint(0, 65536, (100, 100), dtype=np.uint16)
    img_path = train_images / "test_image.tif"
    tifffile.imwrite(str(img_path), img_16bit)

    # Create a dummy label file
    label_path = train_labels / "test_image.txt"
    with open(label_path, "w") as f:
        f.write("0 0.5 0.5 0.2 0.2\n")  # class_id x_center y_center width height

    # Create data.yaml
    data_yaml = {
        "path": str(dataset_dir),
        "train": "train/images",
        "val": "train/images",  # Use same for val in test
        "names": {0: "object"},
        "nc": 1,
    }
    yaml_path = dataset_dir / "data.yaml"
    with open(yaml_path, "w") as f:
        yaml.safe_dump(data_yaml, f)

    print(f"✓ Created test dataset at: {dataset_dir}")
    print(f"  - Image: {img_path} (shape={img_16bit.shape}, dtype={img_16bit.dtype})")
    print(f"  - Min value: {img_16bit.min()}, Max value: {img_16bit.max()}")
    print(f"  - data.yaml: {yaml_path}")

    return dataset_dir, img_path, img_16bit

def test_float32_dataset():
    """Test the Float32YOLODataset class directly."""
    print("\n=== Testing Float32YOLODataset ===\n")

    try:
        from src.utils.train_ultralytics_float import Float32YOLODataset

        print("✓ Successfully imported Float32YOLODataset")
    except ImportError as e:
        print(f"✗ Failed to import Float32YOLODataset: {e}")
        return False

    # Create test dataset
    dataset_dir, img_path, original_img = create_test_dataset()

    try:
        # Initialize the dataset
        print("\nInitializing Float32YOLODataset...")
        dataset = Float32YOLODataset(
            images_dir=str(dataset_dir / "train" / "images"),
            labels_dir=str(dataset_dir / "train" / "labels"),
            img_size=640,
        )
        print(f"✓ Float32YOLODataset initialized with {len(dataset)} images")

        # Get an item
        if len(dataset) > 0:
            print("\nGetting first item...")
            img_tensor, labels, filename = dataset[0]

            print("✓ Item retrieved successfully")
            print(f"  - Image tensor shape: {img_tensor.shape}")
            print(f"  - Image tensor dtype: {img_tensor.dtype}")
            print(f"  - Value range: [{img_tensor.min():.6f}, {img_tensor.max():.6f}]")
            print(f"  - Filename: {filename}")
            print(f"  - Labels: {len(labels)} annotations")
            if labels:
                print(f"  - First label shape: {labels[0].shape}")

            # Verify it's float32
            if img_tensor.dtype == torch.float32:
                print("✓ Correct dtype: float32")
            else:
                print(f"✗ Wrong dtype: {img_tensor.dtype} (expected float32)")
                return False

            # Verify it's 3-channel in the correct (C, H, W) format
            if len(img_tensor.shape) == 3 and img_tensor.shape[0] == 3:
                print(f"✓ Correct format: (C, H, W) = {img_tensor.shape} with 3 channels")
            else:
                print(f"✗ Wrong shape: {img_tensor.shape} (expected (3, H, W))")
                return False

            # Verify it's in [0, 1] range
            if 0.0 <= img_tensor.min() and img_tensor.max() <= 1.0:
                print("✓ Values in correct range: [0, 1]")
            else:
                print(f"✗ Values out of range: [{img_tensor.min()}, {img_tensor.max()}]")
                return False

            # Verify precision (should have many more than 256 unique values)
            unique_values = len(torch.unique(img_tensor))
            print(f"  - Unique values: {unique_values}")
            if unique_values > 256:
                print(f"✓ High precision maintained ({unique_values} > 256 levels)")
            else:
                print(f"⚠ Low precision: only {unique_values} unique values")

            print("\n✓ All Float32YOLODataset tests passed!")
            return True
        else:
            print("✗ No items in dataset")
            return False

    except Exception as e:
        print(f"✗ Error during testing: {e}")
        import traceback

        traceback.print_exc()
        return False

def test_integration():
    """Test integration with train_with_float32_loader."""
    print("\n=== Testing Integration with train_with_float32_loader ===\n")

    # Create test dataset
    dataset_dir, img_path, original_img = create_test_dataset()
    data_yaml = dataset_dir / "data.yaml"

    print(f"\nTest dataset ready at: {data_yaml}")
    print("\nTo test full training, run:")
    print("  from src.utils.train_ultralytics_float import train_with_float32_loader")
    print("  results = train_with_float32_loader(")
    print("      model_path='yolov8n-seg.pt',")
    print(f"      data_yaml='{data_yaml}',")
    print("      epochs=1,")
    print("      batch=1,")
    print("      imgsz=640")
    print("  )")
    print("\nThis will use a custom training loop with Float32YOLODataset")

    return True

def main():
    """Run all tests."""
    print("=" * 70)
    print("Float32 Training Loader Test Suite")
    print("=" * 70)

    results = []

    # Test 1: Float32YOLODataset
    results.append(("Float32YOLODataset", test_float32_dataset()))

    # Test 2: Integration check
    results.append(("Integration Check", test_integration()))

    # Summary
    print("\n" + "=" * 70)
    print("Test Summary")
    print("=" * 70)
    for test_name, passed in results:
        status = "✓ PASSED" if passed else "✗ FAILED"
        print(f"{status}: {test_name}")

    all_passed = all(passed for _, passed in results)
    print("=" * 70)
    if all_passed:
        print("✓ All tests passed!")
    else:
        print("✗ Some tests failed")
    print("=" * 70)

    return all_passed

if __name__ == "__main__":
    import sys

    success = main()
    sys.exit(0 if success else 1)
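
# Usage note (assuming the repository root is on PYTHONPATH so that `src` is importable):
#   python tests/test_float32_training_loader.py
# The script exits with status 0 if all checks pass and 1 otherwise.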