Bug fix
This commit is contained in:
@@ -2,14 +2,15 @@
|
||||
|
||||
## Overview
|
||||
|
||||
This document describes the implementation of 16-bit grayscale TIFF support for YOLO object detection. The system properly loads 16-bit TIFF images, normalizes them to float32 [0-1], and passes them directly to YOLO **without uint8 conversion** to preserve the full dynamic range and avoid data loss.
|
||||
This document describes the implementation of 16-bit grayscale TIFF support for YOLO object detection. The system properly loads 16-bit TIFF images, normalizes them to float32 [0-1], and handles them appropriately for both **inference** and **training** **without uint8 conversion** to preserve the full dynamic range and avoid data loss.
|
||||
|
||||
## Key Features
|
||||
|
||||
✅ Reads 16-bit or float32 images using tifffile
|
||||
✅ Converts to float32 [0-1] (NO uint8 conversion)
|
||||
✅ Replicates grayscale → RGB (3 channels)
|
||||
✅ Passes numpy arrays directly to YOLO (no file I/O)
|
||||
✅ **Inference**: Passes numpy arrays directly to YOLO (no file I/O)
|
||||
✅ **Training**: Creates float32 3-channel TIFF dataset cache
|
||||
✅ Uses Ultralytics YOLOv8/v11 models
|
||||
✅ Works with segmentation models
|
||||
✅ No data loss, no double normalization, no silent clipping
|
||||
@@ -46,7 +47,9 @@ Enhanced [`YOLOWrapper._prepare_source()`](../src/model/yolo_wrapper.py:231) to:
|
||||
|
||||
## Processing Pipeline
|
||||
|
||||
For 16-bit TIFF files:
|
||||
### For Inference (predict)
|
||||
|
||||
For 16-bit TIFF files during inference:
|
||||
|
||||
1. **Load**: File loaded using `tifffile` → preserves 16-bit uint16 data
|
||||
2. **Normalize**: Convert to float32 and scale to [0, 1]
|
||||
@@ -60,12 +63,28 @@ For 16-bit TIFF files:
|
||||
4. **Pass to YOLO**: Return float32 array directly (no uint8, no file I/O)
|
||||
5. **Inference**: YOLO processes the float32 [0-1] RGB array
|
||||
|
||||
### For Training (train)
|
||||
|
||||
During training, YOLO's internal dataloader loads images from disk, so we create a cached 3-channel dataset:
|
||||
|
||||
1. **Detect**: Check if dataset contains 16-bit TIFF files
|
||||
2. **Create Cache**: Build float32 3-channel TIFF dataset in `data/datasets/_float32_cache/`
|
||||
3. **Convert Each Image**:
|
||||
- Load 16-bit TIFF using `tifffile`
|
||||
- Normalize to float32 [0-1]
|
||||
- Replicate to 3 channels
|
||||
- Save as float32 TIFF (preserves precision)
|
||||
4. **Copy Labels**: Copy label files unchanged
|
||||
5. **Generate data.yaml**: Points to cached 3-channel dataset
|
||||
6. **Train**: YOLO trains on float32 3-channel TIFFs
|
||||
|
||||
### No Data Loss!
|
||||
|
||||
Unlike the previous approach that converted to uint8 (256 levels), the new implementation:
|
||||
Unlike approaches that convert to uint8 (256 levels), this implementation:
|
||||
- Preserves full 16-bit dynamic range (65536 levels)
|
||||
- Maintains precision with float32 representation
|
||||
- Passes data directly without intermediate file conversions
|
||||
- For inference: passes data directly without file conversions
|
||||
- For training: uses float32 TIFFs (not uint8 PNGs)
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -188,14 +207,16 @@ This test shows the old behavior (uint8 conversion) - kept for comparison.
|
||||
|
||||
For a 2048×2048 single-channel image:
|
||||
|
||||
| Format | Memory | Notes |
|
||||
|--------|--------|-------|
|
||||
| Original 16-bit | 8 MB | uint16 grayscale |
|
||||
| Float32 grayscale | 16 MB | Intermediate |
|
||||
| Float32 RGB | 48 MB | Final (3 channels) |
|
||||
| uint8 RGB (old) | 12 MB | OLD approach with data loss |
|
||||
| Format | Memory | Disk Space | Notes |
|
||||
|--------|--------|------------|-------|
|
||||
| Original 16-bit | 8 MB | ~8 MB | uint16 grayscale TIFF |
|
||||
| Float32 grayscale | 16 MB | - | Intermediate |
|
||||
| Float32 3-channel | 48 MB | ~48 MB | Training cache |
|
||||
| uint8 RGB (old) | 12 MB | ~12 MB | OLD approach with data loss |
|
||||
|
||||
The float32 approach uses ~4× more memory than uint8 but preserves **all information**.
|
||||
The float32 approach uses ~4× more memory and disk space than uint8 but preserves **all information**.
|
||||
|
||||
**Cache Directory**: Training creates cached datasets in `data/datasets/_float32_cache/<dataset>_<hash>/`
|
||||
|
||||
### Why Direct Numpy Array?
|
||||
|
||||
|
||||
@@ -3,11 +3,14 @@ Training tab for the microscopy object detection application.
|
||||
Handles model training with YOLO.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import tifffile
|
||||
import yaml
|
||||
from PySide6.QtCore import Qt, QThread, Signal
|
||||
from PySide6.QtWidgets import (
|
||||
@@ -946,6 +949,9 @@ class TrainingTab(QWidget):
|
||||
for msg in split_messages:
|
||||
self._append_training_log(msg)
|
||||
|
||||
if dataset_yaml:
|
||||
self._clear_rgb_cache_for_dataset(dataset_yaml)
|
||||
|
||||
def _export_labels_for_split(
|
||||
self,
|
||||
split_name: str,
|
||||
@@ -1165,11 +1171,44 @@ class TrainingTab(QWidget):
|
||||
) -> Path:
|
||||
"""Prepare dataset for training.
|
||||
|
||||
Note: With proper 16-bit TIFF support in YOLOWrapper, we no longer need
|
||||
to create RGB-converted copies of the dataset. Images are handled directly.
|
||||
For 16-bit TIFF files: creates 3-channel float32 TIFF versions for training.
|
||||
This is necessary because YOLO's training dataloader loads images directly
|
||||
from disk and expects 3 channels.
|
||||
"""
|
||||
dataset_info = dataset_info or (
|
||||
self.selected_dataset
|
||||
if self.selected_dataset
|
||||
and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
|
||||
else self._parse_dataset_yaml(dataset_yaml)
|
||||
)
|
||||
|
||||
train_split = dataset_info.get("splits", {}).get("train") or {}
|
||||
images_path_str = train_split.get("path")
|
||||
if not images_path_str:
|
||||
return dataset_yaml
|
||||
|
||||
images_path = Path(images_path_str)
|
||||
if not images_path.exists():
|
||||
return dataset_yaml
|
||||
|
||||
# Check if dataset has 16-bit TIFF files that need 3-channel conversion
|
||||
if not self._dataset_has_16bit_tiff(images_path):
|
||||
return dataset_yaml
|
||||
|
||||
cache_root = self._get_float32_cache_root(dataset_yaml)
|
||||
float32_yaml = cache_root / "data.yaml"
|
||||
if float32_yaml.exists():
|
||||
self._append_training_log(
|
||||
f"Detected 16-bit TIFF dataset; reusing float32 3-channel cache at {cache_root}"
|
||||
)
|
||||
return float32_yaml
|
||||
|
||||
self._append_training_log(
|
||||
f"Detected 16-bit TIFF dataset; creating float32 3-channel cache at {cache_root}"
|
||||
)
|
||||
self._build_float32_dataset(cache_root, dataset_info)
|
||||
return float32_yaml
|
||||
|
||||
def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
two_stage = params.get("two_stage") or {}
|
||||
base_stage = {
|
||||
@@ -1254,6 +1293,132 @@ class TrainingTab(QWidget):
|
||||
f" • {stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
|
||||
)
|
||||
|
||||
def _get_float32_cache_root(self, dataset_yaml: Path) -> Path:
|
||||
"""Get cache directory for float32 3-channel converted datasets."""
|
||||
cache_base = Path("data/datasets/_float32_cache")
|
||||
cache_base.mkdir(parents=True, exist_ok=True)
|
||||
key = hashlib.md5(str(dataset_yaml.parent.resolve()).encode()).hexdigest()[:8]
|
||||
return cache_base / f"{dataset_yaml.parent.name}_{key}"
|
||||
|
||||
def _clear_rgb_cache_for_dataset(self, dataset_yaml: Path):
|
||||
"""Clear float32 cache for dataset."""
|
||||
cache_root = self._get_float32_cache_root(dataset_yaml)
|
||||
if cache_root.exists():
|
||||
try:
|
||||
shutil.rmtree(cache_root)
|
||||
logger.debug(f"Removed float32 cache at {cache_root}")
|
||||
except OSError as exc:
|
||||
logger.warning(f"Failed to remove float32 cache {cache_root}: {exc}")
|
||||
|
||||
def _dataset_has_16bit_tiff(self, images_dir: Path) -> bool:
|
||||
"""Check if dataset contains 16-bit TIFF files."""
|
||||
sample_image = self._find_first_image(images_dir)
|
||||
if not sample_image:
|
||||
return False
|
||||
try:
|
||||
if sample_image.suffix.lower() not in [".tif", ".tiff"]:
|
||||
return False
|
||||
img = Image(sample_image)
|
||||
return img.dtype == np.uint16
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to inspect image {sample_image}: {exc}")
|
||||
return False
|
||||
|
||||
def _find_first_image(self, directory: Path) -> Optional[Path]:
|
||||
"""Find first image in directory."""
|
||||
if not directory.exists():
|
||||
return None
|
||||
for path in directory.rglob("*"):
|
||||
if path.is_file() and path.suffix.lower() in self.allowed_extensions:
|
||||
return path
|
||||
return None
|
||||
|
||||
def _build_float32_dataset(self, cache_root: Path, dataset_info: Dict[str, Any]):
|
||||
"""Build float32 3-channel version of 16-bit TIFF dataset."""
|
||||
if cache_root.exists():
|
||||
shutil.rmtree(cache_root)
|
||||
cache_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
splits = dataset_info.get("splits", {})
|
||||
for split_name in ("train", "val", "test"):
|
||||
split_entry = splits.get(split_name)
|
||||
if not split_entry:
|
||||
continue
|
||||
images_src = Path(split_entry.get("path", ""))
|
||||
if not images_src.exists():
|
||||
continue
|
||||
images_dst = cache_root / split_name / "images"
|
||||
self._convert_16bit_to_float32_3ch(images_src, images_dst)
|
||||
|
||||
labels_src = self._infer_labels_dir(images_src)
|
||||
if labels_src.exists():
|
||||
labels_dst = cache_root / split_name / "labels"
|
||||
self._copy_labels(labels_src, labels_dst)
|
||||
|
||||
class_names = dataset_info.get("class_names") or []
|
||||
names_map = {idx: name for idx, name in enumerate(class_names)}
|
||||
num_classes = dataset_info.get("num_classes") or len(class_names)
|
||||
|
||||
yaml_payload: Dict[str, Any] = {
|
||||
"path": cache_root.as_posix(),
|
||||
"names": names_map,
|
||||
"nc": num_classes,
|
||||
}
|
||||
|
||||
for split_name in ("train", "val", "test"):
|
||||
images_dir = cache_root / split_name / "images"
|
||||
if images_dir.exists():
|
||||
yaml_payload[split_name] = f"{split_name}/images"
|
||||
|
||||
with open(cache_root / "data.yaml", "w", encoding="utf-8") as handle:
|
||||
yaml.safe_dump(yaml_payload, handle, sort_keys=False)
|
||||
|
||||
def _convert_16bit_to_float32_3ch(self, src_dir: Path, dst_dir: Path):
|
||||
"""Convert 16-bit TIFF images to float32 [0-1] 3-channel TIFFs.
|
||||
|
||||
This preserves the full dynamic range (no uint8 conversion) while
|
||||
creating the 3-channel format that YOLO training expects.
|
||||
"""
|
||||
for src in src_dir.rglob("*"):
|
||||
if not src.is_file() or src.suffix.lower() not in self.allowed_extensions:
|
||||
continue
|
||||
relative = src.relative_to(src_dir)
|
||||
dst = dst_dir / relative.with_suffix(".tif")
|
||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
img_obj = Image(src)
|
||||
|
||||
# Check if it's a 16-bit TIFF
|
||||
is_16bit_tiff = (
|
||||
src.suffix.lower() in [".tif", ".tiff"]
|
||||
and img_obj.dtype == np.uint16
|
||||
)
|
||||
|
||||
if is_16bit_tiff:
|
||||
# Convert to float32 [0-1]
|
||||
float_data = img_obj.to_normalized_float32()
|
||||
|
||||
# Replicate to 3 channels
|
||||
if len(float_data.shape) == 2:
|
||||
# H,W → H,W,3
|
||||
float_3ch = np.stack([float_data] * 3, axis=-1)
|
||||
elif len(float_data.shape) == 3 and float_data.shape[2] == 1:
|
||||
# H,W,1 → H,W,3
|
||||
float_3ch = np.repeat(float_data, 3, axis=2)
|
||||
else:
|
||||
# Already multi-channel
|
||||
float_3ch = float_data
|
||||
|
||||
# Save as float32 TIFF (preserves full precision)
|
||||
tifffile.imwrite(dst, float_3ch.astype(np.float32))
|
||||
logger.debug(f"Converted {src} to float32 3-channel TIFF at {dst}")
|
||||
else:
|
||||
# For non-16-bit images, just copy
|
||||
shutil.copy2(src, dst)
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to convert {src}: {exc}")
|
||||
|
||||
def _copy_labels(self, labels_src: Path, labels_dst: Path):
|
||||
label_files = list(labels_src.rglob("*.txt"))
|
||||
for label_file in label_files:
|
||||
@@ -1350,6 +1515,10 @@ class TrainingTab(QWidget):
|
||||
self._export_labels_from_database(dataset_info)
|
||||
|
||||
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
|
||||
if dataset_to_use != dataset_path:
|
||||
self._append_training_log(
|
||||
f"Using float32 3-channel dataset at {dataset_to_use.parent}"
|
||||
)
|
||||
|
||||
params = self._collect_training_params()
|
||||
stage_plan = self._compose_stage_plan(params)
|
||||
|
||||
1774
tests/test_pyside_freehand_tool
Normal file
1774
tests/test_pyside_freehand_tool
Normal file
File diff suppressed because it is too large
Load Diff
117
tests/test_training_dataset_prep.py
Normal file
117
tests/test_training_dataset_prep.py
Normal file
@@ -0,0 +1,117 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for training dataset preparation with 16-bit TIFFs.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import tifffile
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import sys
|
||||
import os
|
||||
import shutil
|
||||
|
||||
# Add parent directory to path to import modules
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from src.utils.image import Image
|
||||
|
||||
|
||||
def test_float32_3ch_conversion():
|
||||
"""Test conversion of 16-bit TIFF to float32 3-channel TIFF."""
|
||||
print("\n=== Testing Float32 3-Channel Conversion ===")
|
||||
|
||||
# Create temporary directory structure
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir = Path(tmpdir)
|
||||
src_dir = tmpdir / "original"
|
||||
dst_dir = tmpdir / "converted"
|
||||
src_dir.mkdir()
|
||||
dst_dir.mkdir()
|
||||
|
||||
# Create test 16-bit TIFF
|
||||
test_data = np.zeros((100, 100), dtype=np.uint16)
|
||||
for i in range(100):
|
||||
for j in range(100):
|
||||
test_data[i, j] = int((i + j) / 198 * 65535)
|
||||
|
||||
test_file = src_dir / "test_16bit.tif"
|
||||
tifffile.imwrite(test_file, test_data)
|
||||
print(f"Created test 16-bit TIFF: {test_file}")
|
||||
print(f" Shape: {test_data.shape}")
|
||||
print(f" Dtype: {test_data.dtype}")
|
||||
print(f" Range: [{test_data.min()}, {test_data.max()}]")
|
||||
|
||||
# Simulate the conversion process
|
||||
print("\nConverting to float32 3-channel...")
|
||||
img_obj = Image(test_file)
|
||||
|
||||
# Convert to float32 [0-1]
|
||||
float_data = img_obj.to_normalized_float32()
|
||||
|
||||
# Replicate to 3 channels
|
||||
if len(float_data.shape) == 2:
|
||||
float_3ch = np.stack([float_data] * 3, axis=-1)
|
||||
else:
|
||||
float_3ch = float_data
|
||||
|
||||
# Save as float32 TIFF
|
||||
output_file = dst_dir / "test_float32_3ch.tif"
|
||||
tifffile.imwrite(output_file, float_3ch.astype(np.float32))
|
||||
print(f"Saved float32 3-channel TIFF: {output_file}")
|
||||
|
||||
# Verify the output
|
||||
loaded = tifffile.imread(output_file)
|
||||
print(f"\nVerifying output:")
|
||||
print(f" Shape: {loaded.shape}")
|
||||
print(f" Dtype: {loaded.dtype}")
|
||||
print(f" Channels: {loaded.shape[2] if len(loaded.shape) == 3 else 1}")
|
||||
print(f" Range: [{loaded.min():.6f}, {loaded.max():.6f}]")
|
||||
print(f" Unique values: {len(np.unique(loaded[:,:,0]))}")
|
||||
|
||||
# Assertions
|
||||
assert loaded.dtype == np.float32, f"Expected float32, got {loaded.dtype}"
|
||||
assert loaded.shape[2] == 3, f"Expected 3 channels, got {loaded.shape[2]}"
|
||||
assert (
|
||||
0.0 <= loaded.min() <= loaded.max() <= 1.0
|
||||
), f"Expected [0,1] range, got [{loaded.min()}, {loaded.max()}]"
|
||||
|
||||
# Verify all channels are identical (replicated grayscale)
|
||||
assert np.array_equal(
|
||||
loaded[:, :, 0], loaded[:, :, 1]
|
||||
), "Channel 0 and 1 should be identical"
|
||||
assert np.array_equal(
|
||||
loaded[:, :, 0], loaded[:, :, 2]
|
||||
), "Channel 0 and 2 should be identical"
|
||||
|
||||
# Verify float32 precision (not quantized to uint8 steps)
|
||||
unique_vals = len(np.unique(loaded[:, :, 0]))
|
||||
print(f"\n Precision check:")
|
||||
print(f" Unique values in channel: {unique_vals}")
|
||||
print(f" Source unique values: {len(np.unique(test_data))}")
|
||||
|
||||
# The final unique values should match source (no loss from conversion)
|
||||
assert unique_vals == len(
|
||||
np.unique(test_data)
|
||||
), f"Expected {len(np.unique(test_data))} unique values, got {unique_vals}"
|
||||
|
||||
print("\n✓ All conversion tests passed!")
|
||||
print(" - Float32 dtype preserved")
|
||||
print(" - 3 channels created")
|
||||
print(" - Range [0-1] maintained")
|
||||
print(" - No precision loss from conversion")
|
||||
print(" - Channels properly replicated")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
success = test_float32_3ch_conversion()
|
||||
sys.exit(0 if success else 1)
|
||||
except Exception as e:
|
||||
print(f"\n✗ Test failed with error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
Reference in New Issue
Block a user