Adding standalone training script and update
This commit is contained in:
@@ -10,7 +10,7 @@ This document describes the implementation of 16-bit grayscale TIFF support for
|
|||||||
✅ Converts to float32 [0-1] (NO uint8 conversion)
|
✅ Converts to float32 [0-1] (NO uint8 conversion)
|
||||||
✅ Replicates grayscale → RGB (3 channels)
|
✅ Replicates grayscale → RGB (3 channels)
|
||||||
✅ **Inference**: Passes numpy arrays directly to YOLO (no file I/O)
|
✅ **Inference**: Passes numpy arrays directly to YOLO (no file I/O)
|
||||||
✅ **Training**: Creates float32 3-channel TIFF dataset cache
|
✅ **Training**: On-the-fly float32 conversion (NO disk caching)
|
||||||
✅ Uses Ultralytics YOLOv8/v11 models
|
✅ Uses Ultralytics YOLOv8/v11 models
|
||||||
✅ Works with segmentation models
|
✅ Works with segmentation models
|
||||||
✅ No data loss, no double normalization, no silent clipping
|
✅ No data loss, no double normalization, no silent clipping
|
||||||
@@ -65,18 +65,18 @@ For 16-bit TIFF files during inference:
|
|||||||
|
|
||||||
### For Training (train)
|
### For Training (train)
|
||||||
|
|
||||||
During training, YOLO's internal dataloader loads images from disk, so we create a cached 3-channel dataset:
|
Training now uses a custom dataset loader with on-the-fly conversion (NO disk caching):
|
||||||
|
|
||||||
1. **Detect**: Check if dataset contains 16-bit TIFF files
|
1. **Custom Dataset**: Uses `Float32Dataset` class that extends Ultralytics' `YOLODataset`
|
||||||
2. **Create Cache**: Build float32 3-channel TIFF dataset in `data/datasets/_float32_cache/`
|
2. **Load On-The-Fly**: Each image is loaded and converted during training:
|
||||||
3. **Convert Each Image**:
|
- Detect 16-bit TIFF files automatically
|
||||||
- Load 16-bit TIFF using `tifffile`
|
- Load with `tifffile` (preserves uint16)
|
||||||
- Normalize to float32 [0-1]
|
- Convert to float32 [0-1] in memory
|
||||||
- Replicate to 3 channels
|
- Replicate to 3 channels (RGB)
|
||||||
- Save as float32 TIFF (preserves precision)
|
3. **No Disk Cache**: Conversion happens in memory, no files written
|
||||||
4. **Copy Labels**: Copy label files unchanged
|
4. **Train**: YOLO trains on float32 [0-1] RGB arrays directly
|
||||||
5. **Generate data.yaml**: Points to cached 3-channel dataset
|
|
||||||
6. **Train**: YOLO trains on float32 3-channel TIFFs
|
See [`src/utils/train_ultralytics_float.py`](../src/utils/train_ultralytics_float.py) for implementation.
|
||||||
|
|
||||||
### No Data Loss!
|
### No Data Loss!
|
||||||
|
|
||||||
@@ -214,9 +214,9 @@ For a 2048×2048 single-channel image:
|
|||||||
| Float32 3-channel | 48 MB | ~48 MB | Training cache |
|
| Float32 3-channel | 48 MB | ~48 MB | Training cache |
|
||||||
| uint8 RGB (old) | 12 MB | ~12 MB | OLD approach with data loss |
|
| uint8 RGB (old) | 12 MB | ~12 MB | OLD approach with data loss |
|
||||||
|
|
||||||
The float32 approach uses ~4× more memory and disk space than uint8 but preserves **all information**.
|
The float32 approach uses ~3× more memory than uint8 during training but preserves **all information**.
|
||||||
|
|
||||||
**Cache Directory**: Training creates cached datasets in `data/datasets/_float32_cache/<dataset>_<hash>/`
|
**No Disk Cache**: The new on-the-fly approach eliminates the need for cached datasets on disk.
|
||||||
|
|
||||||
### Why Direct Numpy Array?
|
### Why Direct Numpy Array?
|
||||||
|
|
||||||
@@ -233,50 +233,31 @@ Ultralytics YOLO supports various input types:
|
|||||||
- PIL Images: `PIL.Image`
|
- PIL Images: `PIL.Image`
|
||||||
- Torch tensors: `torch.Tensor`
|
- Torch tensors: `torch.Tensor`
|
||||||
|
|
||||||
## For Training with Custom Dataset
|
## Training with Float32 Dataset Loader
|
||||||
|
|
||||||
If you need to train YOLO on 16-bit TIFF images, you should create a custom dataset loader similar to the example provided by the user:
|
The system now includes a custom dataset loader for 16-bit TIFF training:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import torch
|
from src.utils.train_ultralytics_float import train_with_float32_loader
|
||||||
import numpy as np
|
|
||||||
import tifffile as tiff
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
class FloatYoloSegDataset(torch.utils.data.Dataset):
|
# Train with on-the-fly float32 conversion
|
||||||
def __init__(self, img_dir, label_dir, img_size=640):
|
results = train_with_float32_loader(
|
||||||
self.img_paths = sorted(Path(img_dir).glob('*'))
|
model_path="yolov8s-seg.pt",
|
||||||
self.label_dir = Path(label_dir)
|
data_yaml="data/my_dataset/data.yaml",
|
||||||
self.img_size = img_size
|
epochs=100,
|
||||||
|
batch=16,
|
||||||
def __len__(self):
|
imgsz=640,
|
||||||
return len(self.img_paths)
|
)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
img_path = self.img_paths[idx]
|
|
||||||
|
|
||||||
# Load 16-bit TIFF
|
|
||||||
img = tiff.imread(img_path)
|
|
||||||
|
|
||||||
# Convert to float32 [0-1]
|
|
||||||
img = img.astype(np.float32)
|
|
||||||
if img.max() > 1.5: # Assume 16-bit if max > 1.5
|
|
||||||
img /= 65535.0
|
|
||||||
|
|
||||||
# Grayscale → RGB
|
|
||||||
if img.ndim == 2:
|
|
||||||
img = np.repeat(img[..., None], 3, axis=2)
|
|
||||||
|
|
||||||
# HWC → CHW for PyTorch
|
|
||||||
img = torch.from_numpy(img).permute(2, 0, 1).contiguous()
|
|
||||||
|
|
||||||
# Load labels...
|
|
||||||
# (implementation depends on your label format)
|
|
||||||
|
|
||||||
return img, labels
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Then use this dataset with Ultralytics training API or custom training loop.
|
The `Float32Dataset` class automatically:
|
||||||
|
- Detects 16-bit TIFF files
|
||||||
|
- Loads with `tifffile` (not PIL/cv2)
|
||||||
|
- Converts to float32 [0-1] on-the-fly
|
||||||
|
- Replicates to 3 channels
|
||||||
|
- Integrates seamlessly with Ultralytics training pipeline
|
||||||
|
|
||||||
|
This is used automatically by the training tab in the GUI.
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
|
|||||||
179
scripts/README_FLOAT32_TRAINING.md
Normal file
179
scripts/README_FLOAT32_TRAINING.md
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
# Standalone Float32 Training Script for 16-bit TIFFs
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This standalone script (`train_float32_standalone.py`) trains YOLO models on 16-bit grayscale TIFF datasets with **no data loss**.
|
||||||
|
|
||||||
|
- Loads 16-bit TIFFs with `tifffile` (not PIL/cv2)
|
||||||
|
- Converts to float32 [0-1] on-the-fly (preserves full 16-bit precision)
|
||||||
|
- Replicates grayscale → 3-channel RGB in memory
|
||||||
|
- **No disk caching required**
|
||||||
|
- Uses custom PyTorch Dataset + training loop
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Activate virtual environment
|
||||||
|
source venv/bin/activate
|
||||||
|
|
||||||
|
# Train on your 16-bit TIFF dataset
|
||||||
|
python scripts/train_float32_standalone.py \
|
||||||
|
--data data/my_dataset/data.yaml \
|
||||||
|
--weights yolov8s-seg.pt \
|
||||||
|
--epochs 100 \
|
||||||
|
--batch 16 \
|
||||||
|
--imgsz 640 \
|
||||||
|
--lr 0.0001 \
|
||||||
|
--save-dir runs/my_training \
|
||||||
|
--device cuda
|
||||||
|
```
|
||||||
|
|
||||||
|
## Arguments
|
||||||
|
|
||||||
|
| Argument | Required | Default | Description |
|
||||||
|
|----------|----------|---------|-------------|
|
||||||
|
| `--data` | Yes | - | Path to YOLO data.yaml file |
|
||||||
|
| `--weights` | No | yolov8s-seg.pt | Pretrained model weights |
|
||||||
|
| `--epochs` | No | 100 | Number of training epochs |
|
||||||
|
| `--batch` | No | 16 | Batch size |
|
||||||
|
| `--imgsz` | No | 640 | Input image size |
|
||||||
|
| `--lr` | No | 0.0001 | Learning rate |
|
||||||
|
| `--save-dir` | No | runs/train | Directory to save checkpoints |
|
||||||
|
| `--device` | No | cuda/cpu | Training device (auto-detected) |
|
||||||
|
|
||||||
|
## Dataset Format
|
||||||
|
|
||||||
|
Your data.yaml should follow standard YOLO format:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
path: /path/to/dataset
|
||||||
|
train: train/images
|
||||||
|
val: val/images
|
||||||
|
test: test/images # optional
|
||||||
|
|
||||||
|
names:
|
||||||
|
0: class1
|
||||||
|
1: class2
|
||||||
|
|
||||||
|
nc: 2
|
||||||
|
```
|
||||||
|
|
||||||
|
Directory structure:
|
||||||
|
```
|
||||||
|
dataset/
|
||||||
|
├── train/
|
||||||
|
│ ├── images/
|
||||||
|
│ │ ├── img1.tif (16-bit grayscale TIFF)
|
||||||
|
│ │ └── img2.tif
|
||||||
|
│ └── labels/
|
||||||
|
│ ├── img1.txt (YOLO format)
|
||||||
|
│ └── img2.txt
|
||||||
|
├── val/
|
||||||
|
│ ├── images/
|
||||||
|
│ └── labels/
|
||||||
|
└── data.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
## Output
|
||||||
|
|
||||||
|
The script saves:
|
||||||
|
- `epoch{N}.pt`: Checkpoint after each epoch
|
||||||
|
- `best.pt`: Best model weights (lowest loss)
|
||||||
|
- Training logs to console
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
✅ **16-bit precision preserved**: Float32 [0-1] maintains full dynamic range
|
||||||
|
✅ **No disk caching**: Conversion happens in memory
|
||||||
|
✅ **No PIL/cv2**: Direct tifffile loading
|
||||||
|
✅ **Variable-length labels**: Handles segmentation polygons
|
||||||
|
✅ **Checkpoint saving**: Resume training if interrupted
|
||||||
|
✅ **Best model tracking**: Automatically saves best weights
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
Train a segmentation model on microscopy data:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/train_float32_standalone.py \
|
||||||
|
--data data/microscopy/data.yaml \
|
||||||
|
--weights yolov11s-seg.pt \
|
||||||
|
--epochs 150 \
|
||||||
|
--batch 8 \
|
||||||
|
--imgsz 1024 \
|
||||||
|
--lr 0.0003 \
|
||||||
|
--save-dir data/models/microscopy_v1
|
||||||
|
```
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Out of Memory (OOM)
|
||||||
|
Reduce batch size:
|
||||||
|
```bash
|
||||||
|
--batch 4
|
||||||
|
```
|
||||||
|
|
||||||
|
### Slow Loading
|
||||||
|
Reduce num_workers (edit script line 208):
|
||||||
|
```python
|
||||||
|
num_workers=2 # instead of 4
|
||||||
|
```
|
||||||
|
|
||||||
|
### Different Image Sizes
|
||||||
|
The script expects all images to have the same dimensions. For variable sizes:
|
||||||
|
1. Implement letterbox/resize in dataset's `_read_image()`
|
||||||
|
2. Or preprocess images to same size
|
||||||
|
|
||||||
|
### Loss Computation Errors
|
||||||
|
If you see "Cannot determine loss", the script may need adjustment for your Ultralytics version. Check:
|
||||||
|
```python
|
||||||
|
# In train() function, the preds format may vary
|
||||||
|
# Current script assumes: preds is tuple with loss OR dict with 'loss' key
|
||||||
|
```
|
||||||
|
|
||||||
|
## vs GUI Training
|
||||||
|
|
||||||
|
| Feature | Standalone Script | GUI Training Tab |
|
||||||
|
|---------|------------------|------------------|
|
||||||
|
| Float32 conversion | ✓ Yes | ✓ Yes (automatic) |
|
||||||
|
| Disk caching | ✗ None | ✗ None |
|
||||||
|
| Progress UI | ✗ Console only | ✓ Visual progress bar |
|
||||||
|
| Dataset selection | Manual CLI args | ✓ GUI browsing |
|
||||||
|
| Multi-stage training | Manual runs | ✓ Built-in |
|
||||||
|
| Use case | Advanced users | General users |
|
||||||
|
|
||||||
|
## Technical Details
|
||||||
|
|
||||||
|
### Data Loading Pipeline
|
||||||
|
|
||||||
|
```
|
||||||
|
16-bit TIFF file
|
||||||
|
↓ (tifffile.imread)
|
||||||
|
uint16 [0-65535]
|
||||||
|
↓ (/ 65535.0)
|
||||||
|
float32 [0-1]
|
||||||
|
↓ (replicate channels)
|
||||||
|
float32 RGB (H,W,3) [0-1]
|
||||||
|
↓ (permute to C,H,W)
|
||||||
|
torch.Tensor (3,H,W) float32
|
||||||
|
↓ (DataLoader stack)
|
||||||
|
Batch (B,3,H,W) float32
|
||||||
|
↓
|
||||||
|
YOLO Model
|
||||||
|
```
|
||||||
|
|
||||||
|
### Precision Comparison
|
||||||
|
|
||||||
|
| Method | Unique Values | Data Loss |
|
||||||
|
|--------|---------------|-----------|
|
||||||
|
| **float32 [0-1]** | ~65,536 | None ✓ |
|
||||||
|
| uint16 RGB | 65,536 | None ✓ |
|
||||||
|
| uint8 | 256 | 99.6% ✗ |
|
||||||
|
|
||||||
|
Example: Pixel value 32,768 (middle intensity)
|
||||||
|
- Float32: 32768 / 65535.0 = 0.50000763 (exact)
|
||||||
|
- uint8: 32768 → 128 → many values collapse!
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
Same as main project.
|
||||||
349
scripts/train_float32_standalone.py
Executable file
349
scripts/train_float32_standalone.py
Executable file
@@ -0,0 +1,349 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Standalone training script for YOLO with 16-bit TIFF float32 support.
|
||||||
|
|
||||||
|
This script trains YOLO models on 16-bit grayscale TIFF datasets without data loss.
|
||||||
|
Converts images to float32 [0-1] on-the-fly using tifffile (no PIL/cv2).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python scripts/train_float32_standalone.py \\
|
||||||
|
--data path/to/data.yaml \\
|
||||||
|
--weights yolov8s-seg.pt \\
|
||||||
|
--epochs 100 \\
|
||||||
|
--batch 16 \\
|
||||||
|
--imgsz 640
|
||||||
|
|
||||||
|
Based on the custom dataset approach to avoid Ultralytics' channel conversion issues.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import tifffile
|
||||||
|
import yaml
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
# Add project root to path
|
||||||
|
project_root = Path(__file__).parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
from src.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ===================== Dataset =====================
|
||||||
|
|
||||||
|
|
||||||
|
class Float32YOLODataset(Dataset):
|
||||||
|
"""PyTorch dataset for 16-bit TIFF images with float32 conversion."""
|
||||||
|
|
||||||
|
def __init__(self, images_dir, labels_dir, img_size=640):
|
||||||
|
self.images_dir = Path(images_dir)
|
||||||
|
self.labels_dir = Path(labels_dir)
|
||||||
|
self.img_size = img_size
|
||||||
|
|
||||||
|
# Find images
|
||||||
|
extensions = {".tif", ".tiff", ".png", ".jpg", ".jpeg", ".bmp"}
|
||||||
|
self.paths = sorted(
|
||||||
|
[
|
||||||
|
p
|
||||||
|
for p in self.images_dir.rglob("*")
|
||||||
|
if p.is_file() and p.suffix.lower() in extensions
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.paths:
|
||||||
|
raise ValueError(f"No images found in {images_dir}")
|
||||||
|
|
||||||
|
logger.info(f"Dataset: {len(self.paths)} images from {images_dir}")
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.paths)
|
||||||
|
|
||||||
|
def _read_image(self, path: Path) -> np.ndarray:
|
||||||
|
"""Load image as float32 [0-1] RGB."""
|
||||||
|
# Load with tifffile
|
||||||
|
img = tifffile.imread(str(path))
|
||||||
|
|
||||||
|
# Convert to float32
|
||||||
|
img = img.astype(np.float32)
|
||||||
|
|
||||||
|
# Normalize 16-bit→[0,1]
|
||||||
|
if img.max() > 1.5:
|
||||||
|
img = img / 65535.0
|
||||||
|
|
||||||
|
img = np.clip(img, 0.0, 1.0)
|
||||||
|
|
||||||
|
# Grayscale→RGB
|
||||||
|
if img.ndim == 2:
|
||||||
|
img = np.repeat(img[..., None], 3, axis=2)
|
||||||
|
elif img.ndim == 3 and img.shape[2] == 1:
|
||||||
|
img = np.repeat(img, 3, axis=2)
|
||||||
|
|
||||||
|
return img # float32 (H,W,3) [0,1]
|
||||||
|
|
||||||
|
def _parse_label(self, path: Path) -> np.ndarray:
|
||||||
|
"""Parse YOLO label with variable-length rows."""
|
||||||
|
if not path.exists():
|
||||||
|
return np.zeros((0, 5), dtype=np.float32)
|
||||||
|
|
||||||
|
labels = []
|
||||||
|
with open(path, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
vals = line.strip().split()
|
||||||
|
if len(vals) >= 5:
|
||||||
|
labels.append([float(v) for v in vals])
|
||||||
|
|
||||||
|
return (
|
||||||
|
np.array(labels, dtype=np.float32)
|
||||||
|
if labels
|
||||||
|
else np.zeros((0, 5), dtype=np.float32)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
img_path = self.paths[idx]
|
||||||
|
label_path = self.labels_dir / f"{img_path.stem}.txt"
|
||||||
|
|
||||||
|
# Load & convert to tensor (C,H,W)
|
||||||
|
img = self._read_image(img_path)
|
||||||
|
img_t = torch.from_numpy(img).permute(2, 0, 1).contiguous()
|
||||||
|
|
||||||
|
# Load labels
|
||||||
|
labels = self._parse_label(label_path)
|
||||||
|
|
||||||
|
return img_t, labels, str(img_path.name)
|
||||||
|
|
||||||
|
|
||||||
|
# ===================== Collate =====================
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(batch):
|
||||||
|
"""Stack images, keep labels as list."""
|
||||||
|
imgs = torch.stack([b[0] for b in batch], dim=0)
|
||||||
|
labels = [b[1] for b in batch]
|
||||||
|
names = [b[2] for b in batch]
|
||||||
|
return imgs, labels, names
|
||||||
|
|
||||||
|
|
||||||
|
# ===================== Training =====================
|
||||||
|
|
||||||
|
|
||||||
|
def get_pytorch_model(ul_model):
|
||||||
|
"""Extract PyTorch model and loss from Ultralytics wrapper."""
|
||||||
|
pt_model = None
|
||||||
|
loss_fn = None
|
||||||
|
|
||||||
|
# Try common patterns
|
||||||
|
if hasattr(ul_model, "model"):
|
||||||
|
pt_model = ul_model.model
|
||||||
|
if pt_model and hasattr(pt_model, "model"):
|
||||||
|
pt_model = pt_model.model
|
||||||
|
|
||||||
|
# Find loss
|
||||||
|
if pt_model and hasattr(pt_model, "loss"):
|
||||||
|
loss_fn = pt_model.loss
|
||||||
|
elif pt_model and hasattr(pt_model, "compute_loss"):
|
||||||
|
loss_fn = pt_model.compute_loss
|
||||||
|
|
||||||
|
if pt_model is None:
|
||||||
|
raise RuntimeError("Could not extract PyTorch model")
|
||||||
|
|
||||||
|
return pt_model, loss_fn
|
||||||
|
|
||||||
|
|
||||||
|
def train(args):
|
||||||
|
"""Main training function."""
|
||||||
|
device = args.device
|
||||||
|
logger.info(f"Device: {device}")
|
||||||
|
|
||||||
|
# Parse data.yaml
|
||||||
|
with open(args.data, "r") as f:
|
||||||
|
data_config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
dataset_root = Path(data_config.get("path", Path(args.data).parent))
|
||||||
|
train_img = dataset_root / data_config.get("train", "train/images")
|
||||||
|
val_img = dataset_root / data_config.get("val", "val/images")
|
||||||
|
train_lbl = train_img.parent / "labels"
|
||||||
|
val_lbl = val_img.parent / "labels"
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
logger.info(f"Loading {args.weights}")
|
||||||
|
ul_model = YOLO(args.weights)
|
||||||
|
pt_model, loss_fn = get_pytorch_model(ul_model)
|
||||||
|
|
||||||
|
# Configure model args
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
if not hasattr(pt_model, "args"):
|
||||||
|
pt_model.args = SimpleNamespace()
|
||||||
|
if isinstance(pt_model.args, dict):
|
||||||
|
pt_model.args = SimpleNamespace(**pt_model.args)
|
||||||
|
|
||||||
|
# Set segmentation loss args
|
||||||
|
pt_model.args.overlap_mask = getattr(pt_model.args, "overlap_mask", True)
|
||||||
|
pt_model.args.mask_ratio = getattr(pt_model.args, "mask_ratio", 4)
|
||||||
|
pt_model.args.task = "segment"
|
||||||
|
|
||||||
|
pt_model.to(device)
|
||||||
|
pt_model.train()
|
||||||
|
|
||||||
|
# Create datasets
|
||||||
|
train_ds = Float32YOLODataset(str(train_img), str(train_lbl), args.imgsz)
|
||||||
|
val_ds = Float32YOLODataset(str(val_img), str(val_lbl), args.imgsz)
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_ds,
|
||||||
|
batch_size=args.batch,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=4,
|
||||||
|
pin_memory=(device == "cuda"),
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
)
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_ds,
|
||||||
|
batch_size=args.batch,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=2,
|
||||||
|
pin_memory=(device == "cuda"),
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optimizer
|
||||||
|
optimizer = torch.optim.AdamW(pt_model.parameters(), lr=args.lr)
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
os.makedirs(args.save_dir, exist_ok=True)
|
||||||
|
best_loss = float("inf")
|
||||||
|
|
||||||
|
for epoch in range(args.epochs):
|
||||||
|
t0 = time.time()
|
||||||
|
running_loss = 0.0
|
||||||
|
num_batches = 0
|
||||||
|
|
||||||
|
for imgs, labels_list, names in train_loader:
|
||||||
|
imgs = imgs.to(device)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
num_batches += 1
|
||||||
|
|
||||||
|
# Forward (simple approach - just use preds)
|
||||||
|
preds = pt_model(imgs)
|
||||||
|
|
||||||
|
# Try to compute loss
|
||||||
|
# Simplest fallback: if preds is tuple/list, assume last element is loss
|
||||||
|
if isinstance(preds, (tuple, list)):
|
||||||
|
# Often YOLO forward returns (preds, loss) in training mode
|
||||||
|
if (
|
||||||
|
len(preds) >= 2
|
||||||
|
and isinstance(preds[-1], dict)
|
||||||
|
and "loss" in preds[-1]
|
||||||
|
):
|
||||||
|
loss = preds[-1]["loss"]
|
||||||
|
elif len(preds) >= 2 and isinstance(preds[-1], torch.Tensor):
|
||||||
|
loss = preds[-1]
|
||||||
|
else:
|
||||||
|
# Manually compute using loss_fn if available
|
||||||
|
if loss_fn:
|
||||||
|
# This may fail - see logs
|
||||||
|
try:
|
||||||
|
loss_out = loss_fn(preds, labels_list)
|
||||||
|
loss = (
|
||||||
|
loss_out[0]
|
||||||
|
if isinstance(loss_out, (tuple, list))
|
||||||
|
else loss_out
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Loss computation failed: {e}")
|
||||||
|
logger.error(
|
||||||
|
"Consider using Ultralytics .train() or check model/loss compatibility"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Cannot determine loss from model output")
|
||||||
|
elif isinstance(preds, dict) and "loss" in preds:
|
||||||
|
loss = preds["loss"]
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unexpected preds format: {type(preds)}")
|
||||||
|
|
||||||
|
# Backward
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
running_loss += loss.item()
|
||||||
|
|
||||||
|
if (num_batches % 10) == 0:
|
||||||
|
logger.info(
|
||||||
|
f"Epoch {epoch+1} Batch {num_batches} Loss: {loss.item():.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
epoch_loss = running_loss / max(1, num_batches)
|
||||||
|
epoch_time = time.time() - t0
|
||||||
|
logger.info(
|
||||||
|
f"Epoch {epoch+1}/{args.epochs} - Loss: {epoch_loss:.4f}, Time: {epoch_time:.1f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save checkpoint
|
||||||
|
ckpt = Path(args.save_dir) / f"epoch{epoch+1}.pt"
|
||||||
|
torch.save(
|
||||||
|
{
|
||||||
|
"epoch": epoch + 1,
|
||||||
|
"model_state_dict": pt_model.state_dict(),
|
||||||
|
"optimizer_state_dict": optimizer.state_dict(),
|
||||||
|
"loss": epoch_loss,
|
||||||
|
},
|
||||||
|
ckpt,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save best
|
||||||
|
if epoch_loss < best_loss:
|
||||||
|
best_loss = epoch_loss
|
||||||
|
best_ckpt = Path(args.save_dir) / "best.pt"
|
||||||
|
torch.save(pt_model.state_dict(), best_ckpt)
|
||||||
|
logger.info(f"New best: {best_ckpt}")
|
||||||
|
|
||||||
|
logger.info("Training complete")
|
||||||
|
|
||||||
|
|
||||||
|
# ===================== Main =====================
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Train YOLO on 16-bit TIFF with float32"
|
||||||
|
)
|
||||||
|
parser.add_argument("--data", type=str, required=True, help="Path to data.yaml")
|
||||||
|
parser.add_argument(
|
||||||
|
"--weights", type=str, default="yolov8s-seg.pt", help="Pretrained weights"
|
||||||
|
)
|
||||||
|
parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
|
||||||
|
parser.add_argument("--batch", type=int, default=16, help="Batch size")
|
||||||
|
parser.add_argument("--imgsz", type=int, default=640, help="Image size")
|
||||||
|
parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-dir", type=str, default="runs/train", help="Save directory"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parse_args()
|
||||||
|
logger.info("=" * 70)
|
||||||
|
logger.info("Float32 16-bit TIFF Training - Standalone Script")
|
||||||
|
logger.info("=" * 70)
|
||||||
|
logger.info(f"Data: {args.data}")
|
||||||
|
logger.info(f"Weights: {args.weights}")
|
||||||
|
logger.info(f"Epochs: {args.epochs}, Batch: {args.batch}, ImgSz: {args.imgsz}")
|
||||||
|
logger.info(f"LR: {args.lr}, Device: {args.device}")
|
||||||
|
logger.info("=" * 70)
|
||||||
|
|
||||||
|
train(args)
|
||||||
@@ -3,14 +3,11 @@ Training tab for the microscopy object detection application.
|
|||||||
Handles model training with YOLO.
|
Handles model training with YOLO.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import shutil
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tifffile
|
|
||||||
import yaml
|
import yaml
|
||||||
from PySide6.QtCore import Qt, QThread, Signal
|
from PySide6.QtCore import Qt, QThread, Signal
|
||||||
from PySide6.QtWidgets import (
|
from PySide6.QtWidgets import (
|
||||||
@@ -949,9 +946,6 @@ class TrainingTab(QWidget):
|
|||||||
for msg in split_messages:
|
for msg in split_messages:
|
||||||
self._append_training_log(msg)
|
self._append_training_log(msg)
|
||||||
|
|
||||||
if dataset_yaml:
|
|
||||||
self._clear_rgb_cache_for_dataset(dataset_yaml)
|
|
||||||
|
|
||||||
def _export_labels_for_split(
|
def _export_labels_for_split(
|
||||||
self,
|
self,
|
||||||
split_name: str,
|
split_name: str,
|
||||||
@@ -1166,49 +1160,6 @@ class TrainingTab(QWidget):
|
|||||||
return 1.0
|
return 1.0
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def _prepare_dataset_for_training(
|
|
||||||
self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
|
|
||||||
) -> Path:
|
|
||||||
"""Prepare dataset for training.
|
|
||||||
|
|
||||||
For 16-bit TIFF files: creates 3-channel float32 TIFF versions for training.
|
|
||||||
This is necessary because YOLO's training dataloader loads images directly
|
|
||||||
from disk and expects 3 channels.
|
|
||||||
"""
|
|
||||||
dataset_info = dataset_info or (
|
|
||||||
self.selected_dataset
|
|
||||||
if self.selected_dataset
|
|
||||||
and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
|
|
||||||
else self._parse_dataset_yaml(dataset_yaml)
|
|
||||||
)
|
|
||||||
|
|
||||||
train_split = dataset_info.get("splits", {}).get("train") or {}
|
|
||||||
images_path_str = train_split.get("path")
|
|
||||||
if not images_path_str:
|
|
||||||
return dataset_yaml
|
|
||||||
|
|
||||||
images_path = Path(images_path_str)
|
|
||||||
if not images_path.exists():
|
|
||||||
return dataset_yaml
|
|
||||||
|
|
||||||
# Check if dataset has 16-bit TIFF files that need 3-channel conversion
|
|
||||||
if not self._dataset_has_16bit_tiff(images_path):
|
|
||||||
return dataset_yaml
|
|
||||||
|
|
||||||
cache_root = self._get_float32_cache_root(dataset_yaml)
|
|
||||||
float32_yaml = cache_root / "data.yaml"
|
|
||||||
if float32_yaml.exists():
|
|
||||||
self._append_training_log(
|
|
||||||
f"Detected 16-bit TIFF dataset; reusing float32 3-channel cache at {cache_root}"
|
|
||||||
)
|
|
||||||
return float32_yaml
|
|
||||||
|
|
||||||
self._append_training_log(
|
|
||||||
f"Detected 16-bit TIFF dataset; creating float32 3-channel cache at {cache_root}"
|
|
||||||
)
|
|
||||||
self._build_float32_dataset(cache_root, dataset_info)
|
|
||||||
return float32_yaml
|
|
||||||
|
|
||||||
def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
|
def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||||
two_stage = params.get("two_stage") or {}
|
two_stage = params.get("two_stage") or {}
|
||||||
base_stage = {
|
base_stage = {
|
||||||
@@ -1293,140 +1244,6 @@ class TrainingTab(QWidget):
|
|||||||
f" • {stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
|
f" • {stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_float32_cache_root(self, dataset_yaml: Path) -> Path:
|
|
||||||
"""Get cache directory for float32 3-channel converted datasets."""
|
|
||||||
cache_base = Path("data/datasets/_float32_cache")
|
|
||||||
cache_base.mkdir(parents=True, exist_ok=True)
|
|
||||||
key = hashlib.md5(str(dataset_yaml.parent.resolve()).encode()).hexdigest()[:8]
|
|
||||||
return cache_base / f"{dataset_yaml.parent.name}_{key}"
|
|
||||||
|
|
||||||
def _clear_rgb_cache_for_dataset(self, dataset_yaml: Path):
|
|
||||||
"""Clear float32 cache for dataset."""
|
|
||||||
cache_root = self._get_float32_cache_root(dataset_yaml)
|
|
||||||
if cache_root.exists():
|
|
||||||
try:
|
|
||||||
shutil.rmtree(cache_root)
|
|
||||||
logger.debug(f"Removed float32 cache at {cache_root}")
|
|
||||||
except OSError as exc:
|
|
||||||
logger.warning(f"Failed to remove float32 cache {cache_root}: {exc}")
|
|
||||||
|
|
||||||
def _dataset_has_16bit_tiff(self, images_dir: Path) -> bool:
|
|
||||||
"""Check if dataset contains 16-bit TIFF files."""
|
|
||||||
sample_image = self._find_first_image(images_dir)
|
|
||||||
if not sample_image:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
if sample_image.suffix.lower() not in [".tif", ".tiff"]:
|
|
||||||
return False
|
|
||||||
img = Image(sample_image)
|
|
||||||
return img.dtype == np.uint16
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(f"Failed to inspect image {sample_image}: {exc}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _find_first_image(self, directory: Path) -> Optional[Path]:
|
|
||||||
"""Find first image in directory."""
|
|
||||||
if not directory.exists():
|
|
||||||
return None
|
|
||||||
for path in directory.rglob("*"):
|
|
||||||
if path.is_file() and path.suffix.lower() in self.allowed_extensions:
|
|
||||||
return path
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _build_float32_dataset(self, cache_root: Path, dataset_info: Dict[str, Any]):
|
|
||||||
"""Build float32 3-channel version of 16-bit TIFF dataset."""
|
|
||||||
if cache_root.exists():
|
|
||||||
shutil.rmtree(cache_root)
|
|
||||||
cache_root.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
splits = dataset_info.get("splits", {})
|
|
||||||
for split_name in ("train", "val", "test"):
|
|
||||||
split_entry = splits.get(split_name)
|
|
||||||
if not split_entry:
|
|
||||||
continue
|
|
||||||
images_src = Path(split_entry.get("path", ""))
|
|
||||||
if not images_src.exists():
|
|
||||||
continue
|
|
||||||
images_dst = cache_root / split_name / "images"
|
|
||||||
self._convert_16bit_to_float32_3ch(images_src, images_dst)
|
|
||||||
|
|
||||||
labels_src = self._infer_labels_dir(images_src)
|
|
||||||
if labels_src.exists():
|
|
||||||
labels_dst = cache_root / split_name / "labels"
|
|
||||||
self._copy_labels(labels_src, labels_dst)
|
|
||||||
|
|
||||||
class_names = dataset_info.get("class_names") or []
|
|
||||||
names_map = {idx: name for idx, name in enumerate(class_names)}
|
|
||||||
num_classes = dataset_info.get("num_classes") or len(class_names)
|
|
||||||
|
|
||||||
yaml_payload: Dict[str, Any] = {
|
|
||||||
"path": cache_root.as_posix(),
|
|
||||||
"names": names_map,
|
|
||||||
"nc": num_classes,
|
|
||||||
}
|
|
||||||
|
|
||||||
for split_name in ("train", "val", "test"):
|
|
||||||
images_dir = cache_root / split_name / "images"
|
|
||||||
if images_dir.exists():
|
|
||||||
yaml_payload[split_name] = f"{split_name}/images"
|
|
||||||
|
|
||||||
with open(cache_root / "data.yaml", "w", encoding="utf-8") as handle:
|
|
||||||
yaml.safe_dump(yaml_payload, handle, sort_keys=False)
|
|
||||||
|
|
||||||
def _convert_16bit_to_float32_3ch(self, src_dir: Path, dst_dir: Path):
|
|
||||||
"""Convert 16-bit TIFF images to float32 [0-1] 3-channel TIFFs.
|
|
||||||
|
|
||||||
This preserves the full dynamic range (no uint8 conversion) while
|
|
||||||
creating the 3-channel format that YOLO training expects.
|
|
||||||
"""
|
|
||||||
for src in src_dir.rglob("*"):
|
|
||||||
if not src.is_file() or src.suffix.lower() not in self.allowed_extensions:
|
|
||||||
continue
|
|
||||||
relative = src.relative_to(src_dir)
|
|
||||||
dst = dst_dir / relative.with_suffix(".tif")
|
|
||||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
try:
|
|
||||||
img_obj = Image(src)
|
|
||||||
|
|
||||||
# Check if it's a 16-bit TIFF
|
|
||||||
is_16bit_tiff = (
|
|
||||||
src.suffix.lower() in [".tif", ".tiff"]
|
|
||||||
and img_obj.dtype == np.uint16
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_16bit_tiff:
|
|
||||||
# Convert to float32 [0-1]
|
|
||||||
float_data = img_obj.to_normalized_float32()
|
|
||||||
|
|
||||||
# Replicate to 3 channels
|
|
||||||
if len(float_data.shape) == 2:
|
|
||||||
# H,W → H,W,3
|
|
||||||
float_3ch = np.stack([float_data] * 3, axis=-1)
|
|
||||||
elif len(float_data.shape) == 3 and float_data.shape[2] == 1:
|
|
||||||
# H,W,1 → H,W,3
|
|
||||||
float_3ch = np.repeat(float_data, 3, axis=2)
|
|
||||||
else:
|
|
||||||
# Already multi-channel
|
|
||||||
float_3ch = float_data
|
|
||||||
|
|
||||||
# Save as float32 TIFF (preserves full precision)
|
|
||||||
tifffile.imwrite(dst, float_3ch.astype(np.float32))
|
|
||||||
logger.debug(f"Converted {src} to float32 3-channel TIFF at {dst}")
|
|
||||||
else:
|
|
||||||
# For non-16-bit images, just copy
|
|
||||||
shutil.copy2(src, dst)
|
|
||||||
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(f"Failed to convert {src}: {exc}")
|
|
||||||
|
|
||||||
def _copy_labels(self, labels_src: Path, labels_dst: Path):
|
|
||||||
label_files = list(labels_src.rglob("*.txt"))
|
|
||||||
for label_file in label_files:
|
|
||||||
relative = label_file.relative_to(labels_src)
|
|
||||||
dst = labels_dst / relative
|
|
||||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
shutil.copy2(label_file, dst)
|
|
||||||
|
|
||||||
def _infer_labels_dir(self, images_dir: Path) -> Path:
|
def _infer_labels_dir(self, images_dir: Path) -> Path:
|
||||||
return images_dir.parent / "labels"
|
return images_dir.parent / "labels"
|
||||||
|
|
||||||
@@ -1514,11 +1331,9 @@ class TrainingTab(QWidget):
|
|||||||
self.training_log.clear()
|
self.training_log.clear()
|
||||||
self._export_labels_from_database(dataset_info)
|
self._export_labels_from_database(dataset_info)
|
||||||
|
|
||||||
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
|
self._append_training_log(
|
||||||
if dataset_to_use != dataset_path:
|
"Using Float32 on-the-fly loader for 16-bit TIFF support (no disk caching)"
|
||||||
self._append_training_log(
|
)
|
||||||
f"Using float32 3-channel dataset at {dataset_to_use.parent}"
|
|
||||||
)
|
|
||||||
|
|
||||||
params = self._collect_training_params()
|
params = self._collect_training_params()
|
||||||
stage_plan = self._compose_stage_plan(params)
|
stage_plan = self._compose_stage_plan(params)
|
||||||
@@ -1544,7 +1359,7 @@ class TrainingTab(QWidget):
|
|||||||
self._set_training_state(True)
|
self._set_training_state(True)
|
||||||
|
|
||||||
self.training_worker = TrainingWorker(
|
self.training_worker = TrainingWorker(
|
||||||
data_yaml=dataset_to_use.as_posix(),
|
data_yaml=dataset_path.as_posix(),
|
||||||
base_model=params["base_model"],
|
base_model=params["base_model"],
|
||||||
epochs=params["epochs"],
|
epochs=params["epochs"],
|
||||||
batch=params["batch"],
|
batch=params["batch"],
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import os
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range
|
from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
|
from src.utils.train_ultralytics_float import train_with_float32_loader
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -60,10 +61,11 @@ class YOLOWrapper:
|
|||||||
name: str = "custom_model",
|
name: str = "custom_model",
|
||||||
resume: bool = False,
|
resume: bool = False,
|
||||||
callbacks: Optional[Dict[str, Callable]] = None,
|
callbacks: Optional[Dict[str, Callable]] = None,
|
||||||
|
use_float32_loader: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Train the YOLO model.
|
Train the YOLO model with optional float32 loader for 16-bit TIFFs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data_yaml: Path to data.yaml configuration file
|
data_yaml: Path to data.yaml configuration file
|
||||||
@@ -75,41 +77,62 @@ class YOLOWrapper:
|
|||||||
name: Name for the training run
|
name: Name for the training run
|
||||||
resume: Resume training from last checkpoint
|
resume: Resume training from last checkpoint
|
||||||
callbacks: Optional Ultralytics callback dictionary
|
callbacks: Optional Ultralytics callback dictionary
|
||||||
|
use_float32_loader: Use custom Float32Dataset for 16-bit TIFFs (default: True)
|
||||||
**kwargs: Additional training arguments
|
**kwargs: Additional training arguments
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with training results
|
Dictionary with training results
|
||||||
"""
|
"""
|
||||||
if self.model is None:
|
if 1:
|
||||||
if not self.load_model():
|
|
||||||
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info(f"Starting training: {name}")
|
logger.info(f"Starting training: {name}")
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
|
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Train the model
|
# Check if dataset has 16-bit TIFFs and use float32 loader
|
||||||
results = self.model.train(
|
if use_float32_loader:
|
||||||
data=data_yaml,
|
logger.info("Using Float32Dataset loader for 16-bit TIFF support")
|
||||||
epochs=epochs,
|
return train_with_float32_loader(
|
||||||
imgsz=imgsz,
|
model_path=self.model_path,
|
||||||
batch=batch,
|
data_yaml=data_yaml,
|
||||||
patience=patience,
|
epochs=epochs,
|
||||||
project=save_dir,
|
imgsz=imgsz,
|
||||||
name=name,
|
batch=batch,
|
||||||
device=self.device,
|
patience=patience,
|
||||||
resume=resume,
|
save_dir=save_dir,
|
||||||
**kwargs,
|
name=name,
|
||||||
)
|
callbacks=callbacks,
|
||||||
|
device=self.device,
|
||||||
|
resume=resume,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Standard training (old behavior)
|
||||||
|
if self.model is None:
|
||||||
|
if not self.load_model():
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to load model from {self.model_path}"
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Training completed successfully")
|
results = self.model.train(
|
||||||
return self._format_training_results(results)
|
data=data_yaml,
|
||||||
|
epochs=epochs,
|
||||||
|
imgsz=imgsz,
|
||||||
|
batch=batch,
|
||||||
|
patience=patience,
|
||||||
|
project=save_dir,
|
||||||
|
name=name,
|
||||||
|
device=self.device,
|
||||||
|
resume=resume,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
logger.info("Training completed successfully")
|
||||||
logger.error(f"Error during training: {e}")
|
return self._format_training_results(results)
|
||||||
raise
|
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"Error during training: {e}")
|
||||||
|
# raise
|
||||||
|
|
||||||
def validate(self, data_yaml: str, split: str = "val", **kwargs) -> Dict[str, Any]:
|
def validate(self, data_yaml: str, split: str = "val", **kwargs) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
561
src/utils/train_ultralytics_float.py
Normal file
561
src/utils/train_ultralytics_float.py
Normal file
@@ -0,0 +1,561 @@
|
|||||||
|
"""
|
||||||
|
Custom YOLO training with on-the-fly float32 conversion for 16-bit grayscale images.
|
||||||
|
|
||||||
|
This module provides a custom dataset class and training function that:
|
||||||
|
1. Load 16-bit TIFF images directly with tifffile (no PIL/cv2)
|
||||||
|
2. Convert to float32 [0-1] on-the-fly (no data loss)
|
||||||
|
3. Replicate grayscale to 3-channel RGB in memory
|
||||||
|
4. Use custom training loop to bypass Ultralytics' dataset infrastructure
|
||||||
|
5. No disk caching required
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tifffile
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Dict, Any, List, Tuple
|
||||||
|
from ultralytics import YOLO
|
||||||
|
import yaml
|
||||||
|
import time
|
||||||
|
|
||||||
|
from src.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Float32YOLODataset(Dataset):
|
||||||
|
"""
|
||||||
|
Custom PyTorch dataset for YOLO that loads 16-bit grayscale TIFFs as float32 RGB.
|
||||||
|
|
||||||
|
This dataset:
|
||||||
|
- Loads with tifffile (not PIL/cv2)
|
||||||
|
- Converts uint16 → float32 [0-1] (preserves full dynamic range)
|
||||||
|
- Replicates grayscale to 3 channels
|
||||||
|
- Returns torch tensors in (C, H, W) format
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, images_dir: str, labels_dir: str, img_size: int = 640):
|
||||||
|
"""
|
||||||
|
Initialize dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images_dir: Directory containing images
|
||||||
|
labels_dir: Directory containing YOLO label files (.txt)
|
||||||
|
img_size: Target image size (for reference, actual resizing done by model)
|
||||||
|
"""
|
||||||
|
self.images_dir = Path(images_dir)
|
||||||
|
self.labels_dir = Path(labels_dir)
|
||||||
|
self.img_size = img_size
|
||||||
|
|
||||||
|
# Find all image files
|
||||||
|
extensions = {".tif", ".tiff", ".png", ".jpg", ".jpeg", ".bmp"}
|
||||||
|
self.image_paths = sorted(
|
||||||
|
[
|
||||||
|
p
|
||||||
|
for p in self.images_dir.rglob("*")
|
||||||
|
if p.is_file() and p.suffix.lower() in extensions
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.image_paths:
|
||||||
|
raise ValueError(f"No images found in {images_dir}")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Float32YOLODataset initialized with {len(self.image_paths)} images from {images_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.image_paths)
|
||||||
|
|
||||||
|
def _read_image(self, img_path: Path) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Read image and convert to float32 [0-1] RGB.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
numpy array, shape (H, W, 3), dtype float32, range [0, 1]
|
||||||
|
"""
|
||||||
|
# Load image with tifffile
|
||||||
|
img = tifffile.imread(str(img_path))
|
||||||
|
|
||||||
|
# Convert to float32
|
||||||
|
img = img.astype(np.float32)
|
||||||
|
|
||||||
|
# Normalize if 16-bit (values > 1.5 indicates uint16)
|
||||||
|
if img.max() > 1.5:
|
||||||
|
img = img / 65535.0
|
||||||
|
|
||||||
|
# Ensure [0, 1] range
|
||||||
|
img = np.clip(img, 0.0, 1.0)
|
||||||
|
|
||||||
|
# Convert grayscale to RGB
|
||||||
|
if img.ndim == 2:
|
||||||
|
# H,W → H,W,3
|
||||||
|
img = np.repeat(img[..., None], 3, axis=2)
|
||||||
|
elif img.ndim == 3 and img.shape[2] == 1:
|
||||||
|
# H,W,1 → H,W,3
|
||||||
|
img = np.repeat(img, 3, axis=2)
|
||||||
|
|
||||||
|
return img # float32 (H, W, 3) in [0, 1]
|
||||||
|
|
||||||
|
def _parse_label(self, label_path: Path) -> List[np.ndarray]:
|
||||||
|
"""
|
||||||
|
Parse YOLO label file with variable-length rows (segmentation polygons).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of numpy arrays, one per annotation
|
||||||
|
"""
|
||||||
|
if not label_path.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
labels = []
|
||||||
|
try:
|
||||||
|
with open(label_path, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
# Parse space-separated values
|
||||||
|
values = line.split()
|
||||||
|
if len(values) >= 5: # At minimum: class_id x y w h
|
||||||
|
labels.append(
|
||||||
|
np.array([float(v) for v in values], dtype=np.float32)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error parsing label {label_path}: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
return labels
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, List[np.ndarray], str]:
|
||||||
|
"""
|
||||||
|
Get a single training sample.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (image_tensor, labels, filename)
|
||||||
|
- image_tensor: shape (3, H, W), dtype float32, range [0, 1]
|
||||||
|
- labels: list of numpy arrays with YOLO format labels (variable length for segmentation)
|
||||||
|
- filename: image filename
|
||||||
|
"""
|
||||||
|
img_path = self.image_paths[idx]
|
||||||
|
label_path = self.labels_dir / f"{img_path.stem}.txt"
|
||||||
|
|
||||||
|
# Load image as float32 RGB
|
||||||
|
img = self._read_image(img_path)
|
||||||
|
|
||||||
|
# Convert to tensor: (H, W, 3) → (3, H, W)
|
||||||
|
img_tensor = torch.from_numpy(img).permute(2, 0, 1).contiguous()
|
||||||
|
|
||||||
|
# Load labels (list of variable-length arrays for segmentation)
|
||||||
|
labels = self._parse_label(label_path)
|
||||||
|
|
||||||
|
return img_tensor, labels, img_path.name
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(
|
||||||
|
batch: List[Tuple[torch.Tensor, List[np.ndarray], str]],
|
||||||
|
) -> Tuple[torch.Tensor, List[List[np.ndarray]], List[str]]:
|
||||||
|
"""
|
||||||
|
Collate function for DataLoader.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: List of (img_tensor, labels_list, filename) tuples
|
||||||
|
where labels_list is a list of variable-length numpy arrays
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (stacked_images, list_of_labels_lists, list_of_filenames)
|
||||||
|
"""
|
||||||
|
imgs = [b[0] for b in batch]
|
||||||
|
labels = [b[1] for b in batch] # Each element is a list of arrays
|
||||||
|
names = [b[2] for b in batch]
|
||||||
|
|
||||||
|
# Stack images - requires same H,W
|
||||||
|
# For different sizes, implement letterbox/resize in dataset
|
||||||
|
imgs_batch = torch.stack(imgs, dim=0)
|
||||||
|
|
||||||
|
return imgs_batch, labels, names
|
||||||
|
|
||||||
|
|
||||||
|
def train_with_float32_loader(
|
||||||
|
model_path: str,
|
||||||
|
data_yaml: str,
|
||||||
|
epochs: int = 100,
|
||||||
|
imgsz: int = 640,
|
||||||
|
batch: int = 16,
|
||||||
|
patience: int = 50,
|
||||||
|
save_dir: str = "data/models",
|
||||||
|
name: str = "custom_model",
|
||||||
|
callbacks: Optional[Dict] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Train YOLO model with custom Float32 dataset for 16-bit TIFF support.
|
||||||
|
|
||||||
|
Uses a custom training loop to bypass Ultralytics' dataset pipeline,
|
||||||
|
avoiding channel conversion issues.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to base model weights (.pt file)
|
||||||
|
data_yaml: Path to dataset YAML configuration
|
||||||
|
epochs: Number of training epochs
|
||||||
|
imgsz: Input image size
|
||||||
|
batch: Batch size
|
||||||
|
patience: Early stopping patience
|
||||||
|
save_dir: Directory to save trained model
|
||||||
|
name: Name for the training run
|
||||||
|
callbacks: Optional callback dictionary (for progress reporting)
|
||||||
|
**kwargs: Additional training arguments (lr0, freeze, device, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with training results including model paths and metrics
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"Starting Float32 custom training: {name}")
|
||||||
|
logger.info(
|
||||||
|
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse data.yaml to get dataset paths
|
||||||
|
with open(data_yaml, "r") as f:
|
||||||
|
data_config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
dataset_root = Path(data_config.get("path", Path(data_yaml).parent))
|
||||||
|
train_images = dataset_root / data_config.get("train", "train/images")
|
||||||
|
val_images = dataset_root / data_config.get("val", "val/images")
|
||||||
|
|
||||||
|
# Infer label directories
|
||||||
|
train_labels = train_images.parent / "labels"
|
||||||
|
val_labels = val_images.parent / "labels"
|
||||||
|
|
||||||
|
logger.info(f"Train images: {train_images}")
|
||||||
|
logger.info(f"Train labels: {train_labels}")
|
||||||
|
logger.info(f"Val images: {val_images}")
|
||||||
|
logger.info(f"Val labels: {val_labels}")
|
||||||
|
|
||||||
|
# Create datasets
|
||||||
|
train_dataset = Float32YOLODataset(
|
||||||
|
str(train_images), str(train_labels), img_size=imgsz
|
||||||
|
)
|
||||||
|
val_dataset = Float32YOLODataset(
|
||||||
|
str(val_images), str(val_labels), img_size=imgsz
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create data loaders
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=batch,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=4,
|
||||||
|
pin_memory=True,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_size=batch,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=2,
|
||||||
|
pin_memory=True,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
logger.info(f"Loading model from {model_path}")
|
||||||
|
ul_model = YOLO(model_path)
|
||||||
|
|
||||||
|
# Get PyTorch model
|
||||||
|
pt_model, loss_fn = _get_pytorch_model(ul_model)
|
||||||
|
|
||||||
|
# Setup device
|
||||||
|
device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
# Configure model args for loss function
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
# Required args for segmentation loss
|
||||||
|
required_args = {
|
||||||
|
"overlap_mask": True,
|
||||||
|
"mask_ratio": 4,
|
||||||
|
"task": "segment",
|
||||||
|
"single_cls": False,
|
||||||
|
"box": 7.5,
|
||||||
|
"cls": 0.5,
|
||||||
|
"dfl": 1.5,
|
||||||
|
}
|
||||||
|
|
||||||
|
if not hasattr(pt_model, "args"):
|
||||||
|
# No args - create SimpleNamespace
|
||||||
|
pt_model.args = SimpleNamespace(**required_args)
|
||||||
|
elif isinstance(pt_model.args, dict):
|
||||||
|
# Args is dict - MUST convert to SimpleNamespace for attribute access
|
||||||
|
# The loss function uses model.args.overlap_mask (attribute access)
|
||||||
|
merged = {**pt_model.args, **required_args}
|
||||||
|
pt_model.args = SimpleNamespace(**merged)
|
||||||
|
logger.info(
|
||||||
|
"Converted model.args from dict to SimpleNamespace for loss function compatibility"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Args is SimpleNamespace or other - set attributes
|
||||||
|
for key, value in required_args.items():
|
||||||
|
if not hasattr(pt_model.args, key):
|
||||||
|
setattr(pt_model.args, key, value)
|
||||||
|
|
||||||
|
pt_model.to(device)
|
||||||
|
pt_model.train()
|
||||||
|
|
||||||
|
logger.info(f"Training on device: {device}")
|
||||||
|
logger.info(f"PyTorch model type: {type(pt_model)}")
|
||||||
|
logger.info(f"Model args configured for segmentation loss")
|
||||||
|
|
||||||
|
# Setup optimizer
|
||||||
|
lr0 = kwargs.get("lr0", 0.01)
|
||||||
|
optimizer = torch.optim.AdamW(pt_model.parameters(), lr=lr0)
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
save_path = Path(save_dir) / name
|
||||||
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
weights_dir = save_path / "weights"
|
||||||
|
weights_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
best_loss = float("inf")
|
||||||
|
patience_counter = 0
|
||||||
|
|
||||||
|
for epoch in range(epochs):
|
||||||
|
epoch_start = time.time()
|
||||||
|
running_loss = 0.0
|
||||||
|
num_batches = 0
|
||||||
|
|
||||||
|
logger.info(f"Epoch {epoch+1}/{epochs} starting...")
|
||||||
|
|
||||||
|
for batch_idx, (imgs, labels_list, names) in enumerate(train_loader):
|
||||||
|
imgs = imgs.to(device) # (B, 3, H, W) float32
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
try:
|
||||||
|
preds = pt_model(imgs)
|
||||||
|
except Exception as e:
|
||||||
|
# Try with labels
|
||||||
|
preds = pt_model(imgs, labels_list)
|
||||||
|
|
||||||
|
# Compute loss
|
||||||
|
# For Ultralytics models, the easiest approach is to construct a batch dict
|
||||||
|
# and call the model in training mode which returns preds + loss
|
||||||
|
batch_dict = {
|
||||||
|
"img": imgs, # Already on device
|
||||||
|
"batch_idx": (
|
||||||
|
torch.cat(
|
||||||
|
[
|
||||||
|
torch.full((len(lab),), i, dtype=torch.long)
|
||||||
|
for i, lab in enumerate(labels_list)
|
||||||
|
]
|
||||||
|
).to(device)
|
||||||
|
if any(len(lab) > 0 for lab in labels_list)
|
||||||
|
else torch.tensor([], dtype=torch.long, device=device)
|
||||||
|
),
|
||||||
|
"cls": (
|
||||||
|
torch.cat(
|
||||||
|
[
|
||||||
|
torch.from_numpy(lab[:, 0:1])
|
||||||
|
for lab in labels_list
|
||||||
|
if len(lab) > 0
|
||||||
|
]
|
||||||
|
).to(device)
|
||||||
|
if any(len(lab) > 0 for lab in labels_list)
|
||||||
|
else torch.tensor([], dtype=torch.float32, device=device)
|
||||||
|
),
|
||||||
|
"bboxes": (
|
||||||
|
torch.cat(
|
||||||
|
[
|
||||||
|
torch.from_numpy(lab[:, 1:5])
|
||||||
|
for lab in labels_list
|
||||||
|
if len(lab) > 0
|
||||||
|
]
|
||||||
|
).to(device)
|
||||||
|
if any(len(lab) > 0 for lab in labels_list)
|
||||||
|
else torch.tensor([], dtype=torch.float32, device=device)
|
||||||
|
),
|
||||||
|
"ori_shape": (imgs.shape[2], imgs.shape[3]), # H, W
|
||||||
|
"resized_shape": (imgs.shape[2], imgs.shape[3]),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add masks if segmentation labels exist
|
||||||
|
if any(len(lab) > 5 for lab in labels_list if len(lab) > 0):
|
||||||
|
masks = []
|
||||||
|
for lab in labels_list:
|
||||||
|
if len(lab) > 0 and lab.shape[1] > 5:
|
||||||
|
# Has segmentation points
|
||||||
|
masks.append(torch.from_numpy(lab[:, 5:]))
|
||||||
|
if masks:
|
||||||
|
batch_dict["masks"] = masks
|
||||||
|
|
||||||
|
# Call model loss (it will compute loss internally)
|
||||||
|
try:
|
||||||
|
loss_output = pt_model.loss(batch_dict, preds)
|
||||||
|
if isinstance(loss_output, (tuple, list)):
|
||||||
|
loss = loss_output[0]
|
||||||
|
else:
|
||||||
|
loss = loss_output
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Model loss computation failed: {e}")
|
||||||
|
# Last resort: maybe preds is already a dict with 'loss'
|
||||||
|
if isinstance(preds, dict) and "loss" in preds:
|
||||||
|
loss = preds["loss"]
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Cannot compute loss: {e}")
|
||||||
|
|
||||||
|
# Backward pass
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
running_loss += loss.item()
|
||||||
|
num_batches += 1
|
||||||
|
|
||||||
|
# Report progress via callback
|
||||||
|
if callbacks and "on_fit_epoch_end" in callbacks:
|
||||||
|
# Create a mock trainer object for callback
|
||||||
|
class MockTrainer:
|
||||||
|
def __init__(self, epoch):
|
||||||
|
self.epoch = epoch
|
||||||
|
self.loss_items = [loss.item()]
|
||||||
|
|
||||||
|
callbacks["on_fit_epoch_end"](MockTrainer(epoch))
|
||||||
|
|
||||||
|
epoch_loss = running_loss / max(1, num_batches)
|
||||||
|
epoch_time = time.time() - epoch_start
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Epoch {epoch+1}/{epochs} completed. Avg Loss: {epoch_loss:.4f}, Time: {epoch_time:.1f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save checkpoint
|
||||||
|
ckpt_path = weights_dir / f"epoch{epoch+1}.pt"
|
||||||
|
torch.save(
|
||||||
|
{
|
||||||
|
"epoch": epoch + 1,
|
||||||
|
"model_state_dict": pt_model.state_dict(),
|
||||||
|
"optimizer_state_dict": optimizer.state_dict(),
|
||||||
|
"loss": epoch_loss,
|
||||||
|
},
|
||||||
|
ckpt_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save as last.pt
|
||||||
|
last_path = weights_dir / "last.pt"
|
||||||
|
torch.save(pt_model.state_dict(), last_path)
|
||||||
|
|
||||||
|
# Check for best model
|
||||||
|
if epoch_loss < best_loss:
|
||||||
|
best_loss = epoch_loss
|
||||||
|
patience_counter = 0
|
||||||
|
best_path = weights_dir / "best.pt"
|
||||||
|
torch.save(pt_model.state_dict(), best_path)
|
||||||
|
logger.info(f"New best model saved: {best_path}")
|
||||||
|
else:
|
||||||
|
patience_counter += 1
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
if patience_counter >= patience:
|
||||||
|
logger.info(f"Early stopping triggered after {epoch+1} epochs")
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info("Training completed successfully")
|
||||||
|
|
||||||
|
# Format results
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"final_epoch": epoch + 1,
|
||||||
|
"metrics": {
|
||||||
|
"final_loss": epoch_loss,
|
||||||
|
"best_loss": best_loss,
|
||||||
|
},
|
||||||
|
"best_model_path": str(weights_dir / "best.pt"),
|
||||||
|
"last_model_path": str(weights_dir / "last.pt"),
|
||||||
|
"save_dir": str(save_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during Float32 training: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def _get_pytorch_model(ul_model: YOLO) -> Tuple[torch.nn.Module, Optional[callable]]:
|
||||||
|
"""
|
||||||
|
Extract PyTorch model and loss function from Ultralytics YOLO wrapper.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ul_model: Ultralytics YOLO model wrapper
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (pytorch_model, loss_function)
|
||||||
|
"""
|
||||||
|
# Try to get the underlying PyTorch model
|
||||||
|
candidates = []
|
||||||
|
|
||||||
|
# Direct model attribute
|
||||||
|
if hasattr(ul_model, "model"):
|
||||||
|
candidates.append(ul_model.model)
|
||||||
|
|
||||||
|
# Sometimes nested
|
||||||
|
if hasattr(ul_model, "model") and hasattr(ul_model.model, "model"):
|
||||||
|
candidates.append(ul_model.model.model)
|
||||||
|
|
||||||
|
# The wrapper itself
|
||||||
|
if isinstance(ul_model, torch.nn.Module):
|
||||||
|
candidates.append(ul_model)
|
||||||
|
|
||||||
|
# Find a valid model
|
||||||
|
pt_model = None
|
||||||
|
loss_fn = None
|
||||||
|
|
||||||
|
for candidate in candidates:
|
||||||
|
if candidate is None or not isinstance(candidate, torch.nn.Module):
|
||||||
|
continue
|
||||||
|
|
||||||
|
pt_model = candidate
|
||||||
|
|
||||||
|
# Try to find loss function
|
||||||
|
if hasattr(candidate, "loss") and callable(getattr(candidate, "loss")):
|
||||||
|
loss_fn = getattr(candidate, "loss")
|
||||||
|
elif hasattr(candidate, "compute_loss") and callable(
|
||||||
|
getattr(candidate, "compute_loss")
|
||||||
|
):
|
||||||
|
loss_fn = getattr(candidate, "compute_loss")
|
||||||
|
|
||||||
|
break
|
||||||
|
|
||||||
|
if pt_model is None:
|
||||||
|
raise RuntimeError("Could not extract PyTorch model from Ultralytics wrapper")
|
||||||
|
|
||||||
|
logger.info(f"Extracted PyTorch model: {type(pt_model)}")
|
||||||
|
logger.info(
|
||||||
|
f"Loss function: {type(loss_fn) if loss_fn else 'None (will attempt fallback)'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return pt_model, loss_fn
|
||||||
|
|
||||||
|
|
||||||
|
# Compatibility function (kept for backwards compatibility)
|
||||||
|
def train_float32(model: YOLO, data_yaml: str, **train_kwargs) -> Any:
|
||||||
|
"""
|
||||||
|
Train YOLO model with Float32YOLODataset (alternative API).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Initialized YOLO model instance
|
||||||
|
data_yaml: Path to dataset YAML
|
||||||
|
**train_kwargs: Training parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Training results dict
|
||||||
|
"""
|
||||||
|
return train_with_float32_loader(
|
||||||
|
model_path=(
|
||||||
|
model.model_path if hasattr(model, "model_path") else "yolov8s-seg.pt"
|
||||||
|
),
|
||||||
|
data_yaml=data_yaml,
|
||||||
|
**train_kwargs,
|
||||||
|
)
|
||||||
211
tests/test_float32_training_loader.py
Normal file
211
tests/test_float32_training_loader.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
"""
|
||||||
|
Test script for Float32 on-the-fly loading for 16-bit TIFFs.
|
||||||
|
|
||||||
|
This test verifies that:
|
||||||
|
1. Float32YOLODataset can load 16-bit TIFF files
|
||||||
|
2. Images are converted to float32 [0-1] in memory
|
||||||
|
3. Grayscale is replicated to 3 channels (RGB)
|
||||||
|
4. No disk caching is used
|
||||||
|
5. Full 16-bit precision is preserved
|
||||||
|
"""
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
import numpy as np
|
||||||
|
import tifffile
|
||||||
|
from pathlib import Path
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_dataset():
|
||||||
|
"""Create a minimal test dataset with 16-bit TIFF images."""
|
||||||
|
temp_dir = Path(tempfile.mkdtemp())
|
||||||
|
dataset_dir = temp_dir / "test_dataset"
|
||||||
|
|
||||||
|
# Create directory structure
|
||||||
|
train_images = dataset_dir / "train" / "images"
|
||||||
|
train_labels = dataset_dir / "train" / "labels"
|
||||||
|
train_images.mkdir(parents=True, exist_ok=True)
|
||||||
|
train_labels.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Create a 16-bit TIFF test image
|
||||||
|
img_16bit = np.random.randint(0, 65536, (100, 100), dtype=np.uint16)
|
||||||
|
img_path = train_images / "test_image.tif"
|
||||||
|
tifffile.imwrite(str(img_path), img_16bit)
|
||||||
|
|
||||||
|
# Create a dummy label file
|
||||||
|
label_path = train_labels / "test_image.txt"
|
||||||
|
with open(label_path, "w") as f:
|
||||||
|
f.write("0 0.5 0.5 0.2 0.2\n") # class_id x_center y_center width height
|
||||||
|
|
||||||
|
# Create data.yaml
|
||||||
|
data_yaml = {
|
||||||
|
"path": str(dataset_dir),
|
||||||
|
"train": "train/images",
|
||||||
|
"val": "train/images", # Use same for val in test
|
||||||
|
"names": {0: "object"},
|
||||||
|
"nc": 1,
|
||||||
|
}
|
||||||
|
yaml_path = dataset_dir / "data.yaml"
|
||||||
|
with open(yaml_path, "w") as f:
|
||||||
|
yaml.safe_dump(data_yaml, f)
|
||||||
|
|
||||||
|
print(f"✓ Created test dataset at: {dataset_dir}")
|
||||||
|
print(f" - Image: {img_path} (shape={img_16bit.shape}, dtype={img_16bit.dtype})")
|
||||||
|
print(f" - Min value: {img_16bit.min()}, Max value: {img_16bit.max()}")
|
||||||
|
print(f" - data.yaml: {yaml_path}")
|
||||||
|
|
||||||
|
return dataset_dir, img_path, img_16bit
|
||||||
|
|
||||||
|
|
||||||
|
def test_float32_dataset():
|
||||||
|
"""Test the Float32YOLODataset class directly."""
|
||||||
|
print("\n=== Testing Float32YOLODataset ===\n")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from src.utils.train_ultralytics_float import Float32YOLODataset
|
||||||
|
|
||||||
|
print("✓ Successfully imported Float32YOLODataset")
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"✗ Failed to import Float32YOLODataset: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Create test dataset
|
||||||
|
dataset_dir, img_path, original_img = create_test_dataset()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize the dataset
|
||||||
|
print("\nInitializing Float32YOLODataset...")
|
||||||
|
dataset = Float32YOLODataset(
|
||||||
|
images_dir=str(dataset_dir / "train" / "images"),
|
||||||
|
labels_dir=str(dataset_dir / "train" / "labels"),
|
||||||
|
img_size=640,
|
||||||
|
)
|
||||||
|
print(f"✓ Float32YOLODataset initialized with {len(dataset)} images")
|
||||||
|
|
||||||
|
# Get an item
|
||||||
|
if len(dataset) > 0:
|
||||||
|
print("\nGetting first item...")
|
||||||
|
img_tensor, labels, filename = dataset[0]
|
||||||
|
|
||||||
|
print(f"✓ Item retrieved successfully")
|
||||||
|
print(f" - Image tensor shape: {img_tensor.shape}")
|
||||||
|
print(f" - Image tensor dtype: {img_tensor.dtype}")
|
||||||
|
print(f" - Value range: [{img_tensor.min():.6f}, {img_tensor.max():.6f}]")
|
||||||
|
print(f" - Filename: {filename}")
|
||||||
|
print(f" - Labels: {len(labels)} annotations")
|
||||||
|
if labels:
|
||||||
|
print(
|
||||||
|
f" - First label shape: {labels[0].shape if len(labels) > 0 else 'N/A'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify it's float32
|
||||||
|
if img_tensor.dtype == torch.float32:
|
||||||
|
print("✓ Correct dtype: float32")
|
||||||
|
else:
|
||||||
|
print(f"✗ Wrong dtype: {img_tensor.dtype} (expected float32)")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Verify it's 3-channel in correct format (C, H, W)
|
||||||
|
if len(img_tensor.shape) == 3 and img_tensor.shape[0] == 3:
|
||||||
|
print(
|
||||||
|
f"✓ Correct format: (C, H, W) = {img_tensor.shape} with 3 channels"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(f"✗ Wrong shape: {img_tensor.shape} (expected (3, H, W))")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Verify it's in [0, 1] range
|
||||||
|
if 0.0 <= img_tensor.min() and img_tensor.max() <= 1.0:
|
||||||
|
print("✓ Values in correct range: [0, 1]")
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"✗ Values out of range: [{img_tensor.min()}, {img_tensor.max()}]"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Verify precision (should have many unique values)
|
||||||
|
unique_values = len(torch.unique(img_tensor))
|
||||||
|
print(f" - Unique values: {unique_values}")
|
||||||
|
if unique_values > 256:
|
||||||
|
print(f"✓ High precision maintained ({unique_values} > 256 levels)")
|
||||||
|
else:
|
||||||
|
print(f"⚠ Low precision: only {unique_values} unique values")
|
||||||
|
|
||||||
|
print("\n✓ All Float32YOLODataset tests passed!")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print("✗ No items in dataset")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ Error during testing: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def test_integration():
|
||||||
|
"""Test integration with train_with_float32_loader."""
|
||||||
|
print("\n=== Testing Integration with train_with_float32_loader ===\n")
|
||||||
|
|
||||||
|
# Create test dataset
|
||||||
|
dataset_dir, img_path, original_img = create_test_dataset()
|
||||||
|
data_yaml = dataset_dir / "data.yaml"
|
||||||
|
|
||||||
|
print(f"\nTest dataset ready at: {data_yaml}")
|
||||||
|
print("\nTo test full training, run:")
|
||||||
|
print(f" from src.utils.train_ultralytics_float import train_with_float32_loader")
|
||||||
|
print(f" results = train_with_float32_loader(")
|
||||||
|
print(f" model_path='yolov8n-seg.pt',")
|
||||||
|
print(f" data_yaml='{data_yaml}',")
|
||||||
|
print(f" epochs=1,")
|
||||||
|
print(f" batch=1,")
|
||||||
|
print(f" imgsz=640")
|
||||||
|
print(f" )")
|
||||||
|
print("\nThis will use custom training loop with Float32YOLODataset")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Run all tests."""
|
||||||
|
import torch # Import here to ensure torch is available
|
||||||
|
|
||||||
|
print("=" * 70)
|
||||||
|
print("Float32 Training Loader Test Suite")
|
||||||
|
print("=" * 70)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
# Test 1: Float32YOLODataset
|
||||||
|
results.append(("Float32YOLODataset", test_float32_dataset()))
|
||||||
|
|
||||||
|
# Test 2: Integration check
|
||||||
|
results.append(("Integration Check", test_integration()))
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
print("\n" + "=" * 70)
|
||||||
|
print("Test Summary")
|
||||||
|
print("=" * 70)
|
||||||
|
for test_name, passed in results:
|
||||||
|
status = "✓ PASSED" if passed else "✗ FAILED"
|
||||||
|
print(f"{status}: {test_name}")
|
||||||
|
|
||||||
|
all_passed = all(passed for _, passed in results)
|
||||||
|
print("=" * 70)
|
||||||
|
if all_passed:
|
||||||
|
print("✓ All tests passed!")
|
||||||
|
else:
|
||||||
|
print("✗ Some tests failed")
|
||||||
|
print("=" * 70)
|
||||||
|
|
||||||
|
return all_passed
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
import torch # Make torch available
|
||||||
|
|
||||||
|
success = main()
|
||||||
|
sys.exit(0 if success else 1)
|
||||||
@@ -18,8 +18,8 @@ from src.utils.image import Image
|
|||||||
|
|
||||||
|
|
||||||
def test_float32_3ch_conversion():
|
def test_float32_3ch_conversion():
|
||||||
"""Test conversion of 16-bit TIFF to float32 3-channel TIFF."""
|
"""Test conversion of 16-bit TIFF to 16-bit RGB PNG."""
|
||||||
print("\n=== Testing Float32 3-Channel Conversion ===")
|
print("\n=== Testing 16-bit RGB PNG Conversion ===")
|
||||||
|
|
||||||
# Create temporary directory structure
|
# Create temporary directory structure
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
@@ -42,39 +42,65 @@ def test_float32_3ch_conversion():
|
|||||||
print(f" Dtype: {test_data.dtype}")
|
print(f" Dtype: {test_data.dtype}")
|
||||||
print(f" Range: [{test_data.min()}, {test_data.max()}]")
|
print(f" Range: [{test_data.min()}, {test_data.max()}]")
|
||||||
|
|
||||||
# Simulate the conversion process
|
# Simulate the conversion process (matching training_tab.py)
|
||||||
print("\nConverting to float32 3-channel...")
|
print("\nConverting to 16-bit RGB PNG using PIL merge...")
|
||||||
img_obj = Image(test_file)
|
img_obj = Image(test_file)
|
||||||
|
from PIL import Image as PILImage
|
||||||
|
|
||||||
# Convert to float32 [0-1]
|
# Get uint16 data
|
||||||
float_data = img_obj.to_normalized_float32()
|
uint16_data = img_obj.data
|
||||||
|
|
||||||
# Replicate to 3 channels
|
# Use PIL's merge method with 'I;16' channels (proper way for 16-bit RGB)
|
||||||
if len(float_data.shape) == 2:
|
if len(uint16_data.shape) == 2:
|
||||||
float_3ch = np.stack([float_data] * 3, axis=-1)
|
# Grayscale - replicate to RGB
|
||||||
|
r_img = PILImage.fromarray(uint16_data, mode="I;16")
|
||||||
|
g_img = PILImage.fromarray(uint16_data, mode="I;16")
|
||||||
|
b_img = PILImage.fromarray(uint16_data, mode="I;16")
|
||||||
else:
|
else:
|
||||||
float_3ch = float_data
|
r_img = PILImage.fromarray(uint16_data[:, :, 0], mode="I;16")
|
||||||
|
g_img = PILImage.fromarray(
|
||||||
|
(
|
||||||
|
uint16_data[:, :, 1]
|
||||||
|
if uint16_data.shape[2] > 1
|
||||||
|
else uint16_data[:, :, 0]
|
||||||
|
),
|
||||||
|
mode="I;16",
|
||||||
|
)
|
||||||
|
b_img = PILImage.fromarray(
|
||||||
|
(
|
||||||
|
uint16_data[:, :, 2]
|
||||||
|
if uint16_data.shape[2] > 2
|
||||||
|
else uint16_data[:, :, 0]
|
||||||
|
),
|
||||||
|
mode="I;16",
|
||||||
|
)
|
||||||
|
|
||||||
# Save as float32 TIFF
|
# Merge channels into RGB
|
||||||
output_file = dst_dir / "test_float32_3ch.tif"
|
rgb_img = PILImage.merge("RGB", (r_img, g_img, b_img))
|
||||||
tifffile.imwrite(output_file, float_3ch.astype(np.float32))
|
|
||||||
print(f"Saved float32 3-channel TIFF: {output_file}")
|
|
||||||
|
|
||||||
# Verify the output
|
# Save as PNG
|
||||||
loaded = tifffile.imread(output_file)
|
output_file = dst_dir / "test_16bit_rgb.png"
|
||||||
print(f"\nVerifying output:")
|
rgb_img.save(output_file)
|
||||||
|
print(f"Saved 16-bit RGB PNG: {output_file}")
|
||||||
|
print(f" PIL mode after merge: {rgb_img.mode}")
|
||||||
|
|
||||||
|
# Verify the output - Load with OpenCV (as YOLO does)
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
loaded = cv2.imread(str(output_file), cv2.IMREAD_UNCHANGED)
|
||||||
|
print(f"\nVerifying output (loaded with OpenCV):")
|
||||||
print(f" Shape: {loaded.shape}")
|
print(f" Shape: {loaded.shape}")
|
||||||
print(f" Dtype: {loaded.dtype}")
|
print(f" Dtype: {loaded.dtype}")
|
||||||
print(f" Channels: {loaded.shape[2] if len(loaded.shape) == 3 else 1}")
|
print(f" Channels: {loaded.shape[2] if len(loaded.shape) == 3 else 1}")
|
||||||
print(f" Range: [{loaded.min():.6f}, {loaded.max():.6f}]")
|
print(f" Range: [{loaded.min()}, {loaded.max()}]")
|
||||||
print(f" Unique values: {len(np.unique(loaded[:,:,0]))}")
|
print(f" Unique values: {len(np.unique(loaded[:,:,0]))}")
|
||||||
|
|
||||||
# Assertions
|
# Assertions
|
||||||
assert loaded.dtype == np.float32, f"Expected float32, got {loaded.dtype}"
|
assert loaded.dtype == np.uint16, f"Expected uint16, got {loaded.dtype}"
|
||||||
assert loaded.shape[2] == 3, f"Expected 3 channels, got {loaded.shape[2]}"
|
assert loaded.shape[2] == 3, f"Expected 3 channels, got {loaded.shape[2]}"
|
||||||
assert (
|
assert (
|
||||||
0.0 <= loaded.min() <= loaded.max() <= 1.0
|
loaded.min() >= 0 and loaded.max() <= 65535
|
||||||
), f"Expected [0,1] range, got [{loaded.min()}, {loaded.max()}]"
|
), f"Expected [0,65535] range, got [{loaded.min()}, {loaded.max()}]"
|
||||||
|
|
||||||
# Verify all channels are identical (replicated grayscale)
|
# Verify all channels are identical (replicated grayscale)
|
||||||
assert np.array_equal(
|
assert np.array_equal(
|
||||||
@@ -84,21 +110,20 @@ def test_float32_3ch_conversion():
|
|||||||
loaded[:, :, 0], loaded[:, :, 2]
|
loaded[:, :, 0], loaded[:, :, 2]
|
||||||
), "Channel 0 and 2 should be identical"
|
), "Channel 0 and 2 should be identical"
|
||||||
|
|
||||||
# Verify float32 precision (not quantized to uint8 steps)
|
# Verify no data loss
|
||||||
unique_vals = len(np.unique(loaded[:, :, 0]))
|
unique_vals = len(np.unique(loaded[:, :, 0]))
|
||||||
print(f"\n Precision check:")
|
print(f"\n Precision check:")
|
||||||
print(f" Unique values in channel: {unique_vals}")
|
print(f" Unique values in channel: {unique_vals}")
|
||||||
print(f" Source unique values: {len(np.unique(test_data))}")
|
print(f" Source unique values: {len(np.unique(test_data))}")
|
||||||
|
|
||||||
# The final unique values should match source (no loss from conversion)
|
|
||||||
assert unique_vals == len(
|
assert unique_vals == len(
|
||||||
np.unique(test_data)
|
np.unique(test_data)
|
||||||
), f"Expected {len(np.unique(test_data))} unique values, got {unique_vals}"
|
), f"Expected {len(np.unique(test_data))} unique values, got {unique_vals}"
|
||||||
|
|
||||||
print("\n✓ All conversion tests passed!")
|
print("\n✓ All conversion tests passed!")
|
||||||
print(" - Float32 dtype preserved")
|
print(" - uint16 dtype preserved")
|
||||||
print(" - 3 channels created")
|
print(" - 3 channels created")
|
||||||
print(" - Range [0-1] maintained")
|
print(" - Range [0-65535] maintained")
|
||||||
print(" - No precision loss from conversion")
|
print(" - No precision loss from conversion")
|
||||||
print(" - Channels properly replicated")
|
print(" - Channels properly replicated")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user