Files
object-segmentation/scripts/train_float32_standalone.py
2025-12-16 13:25:20 +02:00

352 lines
11 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Standalone training script for YOLO with 16-bit TIFF float32 support.
This script trains YOLO models on 16-bit grayscale TIFF datasets without data loss.
Converts images to float32 [0-1] on the fly; TIFFs are decoded with tifffile (not PIL/cv2), and cv2 is used for resizing.
Usage:
python scripts/train_float32_standalone.py \\
--data path/to/data.yaml \\
--weights yolov8s-seg.pt \\
--epochs 100 \\
--batch 16 \\
--imgsz 640
Based on the custom dataset approach to avoid Ultralytics' channel conversion issues.
"""
import argparse
import os
import sys
import time
from pathlib import Path
import cv2
import numpy as np
import torch
import torch.nn as nn
import tifffile
import yaml
from torch.utils.data import Dataset, DataLoader
from ultralytics import YOLO
# Make the project root (the parent of this script's directory) importable by
# putting it first on sys.path, so the ``src.utils.logger`` import below resolves.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
# This import must stay after the sys.path insertion above.
from src.utils.logger import get_logger
# Module-wide logger used by the dataset, the training loop and the CLI entry point.
logger = get_logger(__name__)
# ===================== Dataset =====================
class Float32YOLODataset(Dataset):
    """PyTorch dataset yielding float32 [0, 1] 3-channel images and raw YOLO label rows.

    Images are discovered recursively under *images_dir*. TIFF files are decoded
    with ``tifffile``; the other supported formats are decoded with
    ``cv2.imread(..., cv2.IMREAD_UNCHANGED)`` (keeping 16-bit PNG depth) and
    converted from BGR to RGB. Each image's label file is looked up at the same
    relative path under *labels_dir*, with a ``.txt`` suffix.

    Each item is ``(image_tensor, label_rows, file_name)`` where
    ``image_tensor`` is float32 ``(3, img_size, img_size)``.
    """

    _IMAGE_EXTENSIONS = frozenset({".tif", ".tiff", ".png", ".jpg", ".jpeg", ".bmp"})
    _TIFF_EXTENSIONS = frozenset({".tif", ".tiff"})

    def __init__(self, images_dir, labels_dir, img_size=640):
        """Index all supported images below *images_dir*.

        Raises:
            ValueError: if no supported image file is found.
        """
        self.images_dir = Path(images_dir)
        self.labels_dir = Path(labels_dir)
        self.img_size = img_size
        self.paths = sorted(
            p
            for p in self.images_dir.rglob("*")
            if p.is_file() and p.suffix.lower() in self._IMAGE_EXTENSIONS
        )
        if not self.paths:
            raise ValueError(f"No images found in {images_dir}")
        logger.info(f"Dataset: {len(self.paths)} images from {images_dir}")

    def __len__(self):
        return len(self.paths)

    @staticmethod
    def _to_unit_rgb(raw: np.ndarray) -> np.ndarray:
        """Convert a decoded image to a contiguous float32 (H, W, 3) array in [0, 1].

        uint8 and uint16 inputs are scaled by their full dtype range. Other
        dtypes keep the original heuristic: divide by 65535 when the maximum
        exceeds 1.5. Grayscale is replicated to three channels and an alpha
        channel (4th channel) is dropped.
        """
        if raw.dtype == np.uint8:
            img = raw.astype(np.float32) / 255.0
        elif raw.dtype == np.uint16:
            img = raw.astype(np.float32) / 65535.0
        else:
            img = raw.astype(np.float32)
            if img.size and img.max() > 1.5:
                img = img / 65535.0
        img = np.clip(img, 0.0, 1.0)
        if img.ndim == 2:
            img = np.repeat(img[..., None], 3, axis=2)
        elif img.ndim == 3 and img.shape[2] == 1:
            img = np.repeat(img, 3, axis=2)
        elif img.ndim == 3 and img.shape[2] == 4:
            img = img[..., :3]
        # Slicing leaves a strided view; cv2 requires contiguous input.
        return np.ascontiguousarray(img, dtype=np.float32)

    def _read_image(self, path: Path) -> np.ndarray:
        """Load *path* as float32 (img_size, img_size, 3) in [0, 1].

        TIFF channels are kept in the stored order (RGB for RGB TIFFs);
        cv2-decoded colour images are converted from BGR(A) to RGB(A).

        Raises:
            ValueError: if cv2 cannot decode a non-TIFF file.
        """
        if path.suffix.lower() in self._TIFF_EXTENSIONS:
            raw = tifffile.imread(str(path))
        else:
            # IMREAD_UNCHANGED keeps bit depth and channel count (e.g. 16-bit PNG).
            raw = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
            if raw is None:
                raise ValueError(f"Could not read image {path}")
            if raw.ndim == 3 and raw.shape[2] == 3:
                raw = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
            elif raw.ndim == 3 and raw.shape[2] == 4:
                raw = cv2.cvtColor(raw, cv2.COLOR_BGRA2RGBA)
        img = self._to_unit_rgb(raw)
        # Plain stretch to a square; normalised YOLO labels remain valid.
        return cv2.resize(img, (self.img_size, self.img_size))

    def _label_path(self, img_path: Path) -> Path:
        """Return the label file mirroring *img_path*'s location under labels_dir."""
        try:
            relative = img_path.relative_to(self.images_dir)
        except ValueError:
            relative = Path(img_path.name)
        return (self.labels_dir / relative).with_suffix(".txt")

    def _parse_label(self, path: Path) -> list:
        """Parse a YOLO label file into rows of floats.

        Rows may have variable length (boxes or polygons). Rows with fewer than
        5 values, including blank lines, are skipped. A missing file yields [].
        """
        if not path.exists():
            return []
        labels = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                vals = line.strip().split()
                if len(vals) >= 5:
                    labels.append([float(v) for v in vals])
        return labels

    def __getitem__(self, idx):
        img_path = self.paths[idx]
        img = self._read_image(img_path)
        # HWC -> CHW for PyTorch.
        img_t = torch.from_numpy(img).permute(2, 0, 1).contiguous()
        labels = self._parse_label(self._label_path(img_path))
        return img_t, labels, img_path.name
# ===================== Collate =====================
def collate_fn(batch):
    """Merge dataset samples into one DataLoader batch.

    Element 0 of every sample is stacked along a new leading batch dimension.
    Elements 1 (label rows, variable length) and 2 (file name) are collected
    unchanged into plain Python lists, in sample order.
    """
    images, targets, filenames = [], [], []
    for sample in batch:
        images.append(sample[0])
        targets.append(sample[1])
        filenames.append(sample[2])
    return torch.stack(images, dim=0), targets, filenames
# ===================== Training =====================
def get_pytorch_model(ul_model):
    """Extract the underlying PyTorch model and its loss callable from a wrapper.

    Args:
        ul_model: an object exposing the network as ``.model`` (e.g. an
            Ultralytics ``YOLO`` instance).

    Returns:
        ``(pt_model, loss_fn)``. ``loss_fn`` is ``pt_model.loss`` if present and
        not None, else ``pt_model.compute_loss`` if present, else ``None``.

    Raises:
        RuntimeError: if ``ul_model`` has no ``model`` attribute or it is None.
    """
    pt_model = getattr(ul_model, "model", None)
    if pt_model is None:
        raise RuntimeError("Could not extract PyTorch model")
    # Use explicit ``is None`` checks, not truthiness: container modules define
    # __len__, and an empty one is falsy even though it is a valid model.
    loss_fn = getattr(pt_model, "loss", None)
    if loss_fn is None:
        loss_fn = getattr(pt_model, "compute_loss", None)
    return pt_model, loss_fn
def _labels_dir_for(images_dir: Path) -> Path:
    """Return the label directory paired with *images_dir*.

    Replaces the last path component named ``images`` with ``labels``
    (``train/images`` -> ``train/labels``, ``images/train`` -> ``labels/train``).
    Falls back to ``images_dir.parent / "labels"`` when no component is ``images``.
    """
    parts = images_dir.parts
    if "images" not in parts:
        return images_dir.parent / "labels"
    idx = len(parts) - 1 - parts[::-1].index("images")
    return Path(*parts[:idx], "labels", *parts[idx + 1 :])


def _configure_model_args(pt_model) -> None:
    """Ensure ``pt_model.args`` is a namespace carrying the loss settings."""
    from types import SimpleNamespace

    if getattr(pt_model, "args", None) is None:
        pt_model.args = SimpleNamespace()
    elif isinstance(pt_model.args, dict):
        pt_model.args = SimpleNamespace(**pt_model.args)
    cfg = pt_model.args
    # Segmentation loss settings (checkpoint values win when present).
    cfg.overlap_mask = getattr(cfg, "overlap_mask", True)
    cfg.mask_ratio = getattr(cfg, "mask_ratio", 4)
    # Loss gains presumably read by the Ultralytics criterion (hyp.box/cls/dfl);
    # fallbacks are Ultralytics' default.yaml values. Verify for your version.
    for key, default in (("box", 7.5), ("cls", 0.5), ("dfl", 1.5)):
        if not hasattr(cfg, key):
            setattr(cfg, key, default)
    cfg.task = "segment"


def _rasterize_polygon(points: np.ndarray, mask_h: int, mask_w: int) -> np.ndarray:
    """Fill normalised ``(x, y)`` polygon *points* into a (mask_h, mask_w) uint8 mask."""
    mask = np.zeros((mask_h, mask_w), dtype=np.uint8)
    scale = np.array([mask_w, mask_h], dtype=np.float32)
    pixels = np.round(points * scale).astype(np.int32)
    cv2.fillPoly(mask, [pixels], color=1)
    return mask


def _labels_to_batch(imgs, labels_list, mask_ratio, overlap):
    """Build an Ultralytics-style loss batch from raw YOLO label rows.

    Rows longer than 6 values are polygons ``cls x1 y1 x2 y2 ...`` (box = polygon
    extent). Shorter rows are boxes ``cls xc yc w h`` (mask = filled box).
    Coordinates are expected to be normalised to [0, 1].

    Args:
        imgs: image batch; only its device and (B, C, H, W) shape are used.
        labels_list: one list of label rows per image.
        mask_ratio: masks are rasterised at ``H // mask_ratio`` x ``W // mask_ratio``.
        overlap: if True, encode each image's instances into one (h, w) map
            with values 1..n (largest first, smaller ones painted on top, and
            that image's labels reordered to match). This is intended to mirror
            Ultralytics' overlap-mask format; verify for your version. If False,
            stack one binary mask per instance, in label order.

    Returns:
        dict with ``img``, ``batch_idx`` (N,), ``cls`` (N, 1), ``bboxes``
        (N, 4, normalised xywh) and ``masks`` ((B, h, w) or (N, h, w)).
    """
    _, _, img_h, img_w = imgs.shape
    mask_h, mask_w = img_h // mask_ratio, img_w // mask_ratio
    batch_idx, classes, boxes, masks = [], [], [], []
    for image_index, rows in enumerate(labels_list):
        img_cls, img_boxes, img_masks = [], [], []
        for row in rows:
            if len(row) > 6:
                coords = np.asarray(row[1:], dtype=np.float32)
                coords = coords[: coords.size // 2 * 2]  # drop a dangling value
                points = np.clip(coords.reshape(-1, 2), 0.0, 1.0)
                x_min, y_min = points.min(axis=0)
                x_max, y_max = points.max(axis=0)
                box = [(x_min + x_max) / 2, (y_min + y_max) / 2, x_max - x_min, y_max - y_min]
            else:
                xc, yc, w, h = row[1:5]
                box = [xc, yc, w, h]
                corners = [
                    [xc - w / 2, yc - h / 2],
                    [xc + w / 2, yc - h / 2],
                    [xc + w / 2, yc + h / 2],
                    [xc - w / 2, yc + h / 2],
                ]
                points = np.clip(np.array(corners, dtype=np.float32), 0.0, 1.0)
            img_cls.append(row[0])
            img_boxes.append([float(v) for v in box])
            img_masks.append(_rasterize_polygon(points, mask_h, mask_w))
        if overlap:
            areas = np.array([int(m.sum()) for m in img_masks], dtype=np.int64)
            order = np.argsort(-areas, kind="stable")
            encoded = np.zeros((mask_h, mask_w), dtype=np.int32)
            for rank, j in enumerate(order, start=1):
                encoded[img_masks[j] > 0] = rank
            masks.append(encoded)
            img_cls = [img_cls[j] for j in order]
            img_boxes = [img_boxes[j] for j in order]
        else:
            masks.extend(img_masks)
        batch_idx.extend([float(image_index)] * len(img_cls))
        classes.extend(img_cls)
        boxes.extend(img_boxes)

    device = imgs.device
    if masks:
        mask_array = np.stack(masks)
    else:
        mask_array = np.zeros((0, mask_h, mask_w), dtype=np.uint8)
    return {
        "img": imgs,
        "batch_idx": torch.tensor(batch_idx, dtype=torch.float32, device=device),
        "cls": torch.tensor(classes, dtype=torch.float32, device=device).view(-1, 1),
        "bboxes": torch.tensor(boxes, dtype=torch.float32, device=device).view(-1, 4),
        "masks": torch.from_numpy(mask_array).to(device),
    }


def _extract_loss(preds, loss_fn, batch):
    """Return a scalar loss tensor for *preds*.

    A loss reported by the model itself (a dict with ``"loss"``, optionally as
    the last element of a tuple/list) is used directly. Otherwise *loss_fn* is
    called as ``loss_fn(batch, preds)``, the Ultralytics ``BaseModel.loss``
    signature.

    Raises:
        RuntimeError: if *preds* has an unexpected type, or no loss is available.
    """
    if isinstance(preds, dict):
        if "loss" in preds:
            return preds["loss"].mean()
        raise RuntimeError(f"Unexpected preds format: {type(preds)}")
    if not isinstance(preds, (tuple, list)):
        raise RuntimeError(f"Unexpected preds format: {type(preds)}")
    if len(preds) >= 2 and isinstance(preds[-1], dict) and "loss" in preds[-1]:
        return preds[-1]["loss"].mean()
    # A trailing tensor is deliberately NOT treated as a loss. In training mode,
    # Ultralytics Detect/Segment heads return feature maps / mask prototypes there.
    if loss_fn is None:
        raise RuntimeError("Cannot determine loss from model output")
    try:
        loss_out = loss_fn(batch, preds)
    except Exception as e:
        logger.error(f"Loss computation failed: {e}")
        logger.error(
            "Consider using Ultralytics .train() or check model/loss compatibility"
        )
        raise
    if isinstance(loss_out, dict):
        loss = loss_out["loss"]
    elif isinstance(loss_out, (tuple, list)):
        loss = loss_out[0]
    else:
        loss = loss_out
    # No-op for a scalar loss. Combines components if a per-component vector is
    # returned (some Ultralytics versions do; their trainer sums it).
    return loss.sum()


def _evaluate(pt_model, loss_fn, loader, device, mask_ratio, overlap):
    """Return the mean loss over *loader* in eval mode, or None if unavailable.

    Failures are logged rather than raised, so a model whose eval-mode output
    has no computable loss still trains. The model is always put back in
    train mode.
    """
    total, batches = 0.0, 0
    pt_model.eval()
    try:
        with torch.no_grad():
            for imgs, labels_list, _names in loader:
                imgs = imgs.to(device)
                batch = _labels_to_batch(imgs, labels_list, mask_ratio, overlap)
                total += _extract_loss(pt_model(imgs), loss_fn, batch).item()
                batches += 1
    except Exception as e:
        logger.error(
            f"Validation loss computation failed ({e}); "
            "ranking checkpoints by training loss"
        )
        return None
    finally:
        pt_model.train()
    return total / batches if batches else None


def train(args):
    """Train the model described by *args* on the splits listed in data.yaml.

    Reads ``path``/``train``/``val`` from ``args.data`` and loads ``args.weights``
    through Ultralytics. All parameters are optimised with AdamW for
    ``args.epochs`` epochs. Each epoch writes ``epoch{N}.pt`` (model, optimizer,
    train and val loss) to ``args.save_dir``. ``best.pt`` (model state_dict only)
    is rewritten whenever the validation loss improves; the training loss is
    used when no validation loss can be computed.

    Args:
        args: namespace with ``data``, ``weights``, ``epochs``, ``batch``,
            ``imgsz``, ``lr``, ``save_dir`` and ``device``. An optional
            ``workers`` (default 4) sets the training DataLoader workers;
            validation uses at most 2.
    """
    device = args.device
    logger.info(f"Device: {device}")

    # Resolve dataset splits from data.yaml (empty/null fields fall back).
    with open(args.data, "r", encoding="utf-8") as f:
        data_config = yaml.safe_load(f) or {}
    dataset_root = Path(data_config.get("path") or Path(args.data).parent)
    train_img = dataset_root / data_config.get("train", "train/images")
    val_img = dataset_root / data_config.get("val", "val/images")
    train_lbl = _labels_dir_for(train_img)
    val_lbl = _labels_dir_for(val_img)

    # Load model
    logger.info(f"Loading {args.weights}")
    ul_model = YOLO(args.weights)
    pt_model, loss_fn = get_pytorch_model(ul_model)
    _configure_model_args(pt_model)
    mask_ratio = int(pt_model.args.mask_ratio)
    overlap = bool(pt_model.args.overlap_mask)

    pt_model.to(device)
    pt_model.train()
    for param in pt_model.parameters():
        param.requires_grad = True

    # Create datasets
    train_ds = Float32YOLODataset(str(train_img), str(train_lbl), args.imgsz)
    val_ds = Float32YOLODataset(str(val_img), str(val_lbl), args.imgsz)
    workers = getattr(args, "workers", 4)
    # startswith also covers indexed devices such as "cuda:0".
    pin_memory = str(device).startswith("cuda")
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch,
        shuffle=True,
        num_workers=workers,
        pin_memory=pin_memory,
        collate_fn=collate_fn,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch,
        shuffle=False,
        num_workers=min(2, workers),
        pin_memory=pin_memory,
        collate_fn=collate_fn,
    )

    optimizer = torch.optim.AdamW(pt_model.parameters(), lr=args.lr)

    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    best_loss = float("inf")
    for epoch in range(args.epochs):
        t0 = time.time()
        running_loss = 0.0
        num_batches = 0
        for imgs, labels_list, _names in train_loader:
            imgs = imgs.to(device)
            batch = _labels_to_batch(imgs, labels_list, mask_ratio, overlap)
            optimizer.zero_grad()
            num_batches += 1
            preds = pt_model(imgs)
            loss = _extract_loss(preds, loss_fn, batch)
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            running_loss += batch_loss
            if num_batches % 10 == 0:
                logger.info(
                    f"Epoch {epoch+1} Batch {num_batches} Loss: {batch_loss:.4f}"
                )

        epoch_loss = running_loss / max(1, num_batches)
        val_loss = _evaluate(pt_model, loss_fn, val_loader, device, mask_ratio, overlap)
        epoch_time = time.time() - t0
        logger.info(
            f"Epoch {epoch+1}/{args.epochs} - Loss: {epoch_loss:.4f}, Time: {epoch_time:.1f}s"
        )
        if val_loss is not None:
            logger.info(f"Epoch {epoch+1}/{args.epochs} - Val loss: {val_loss:.4f}")

        # Save checkpoint
        ckpt = save_dir / f"epoch{epoch+1}.pt"
        torch.save(
            {
                "epoch": epoch + 1,
                "model_state_dict": pt_model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": epoch_loss,
                "val_loss": val_loss,
            },
            ckpt,
        )

        # NOTE(review): best.pt is a bare state_dict, so it is presumably not
        # loadable via YOLO("best.pt"); load it into the architecture instead.
        score = epoch_loss if val_loss is None else val_loss
        if score < best_loss:
            best_loss = score
            best_ckpt = save_dir / "best.pt"
            torch.save(pt_model.state_dict(), best_ckpt)
            logger.info(f"New best: {best_ckpt}")

    logger.info("Training complete")
# ===================== Main =====================
def parse_args():
    """Build the command-line interface and return the parsed options.

    Only ``--data`` is required. ``--device`` defaults to ``"cuda"`` when
    torch reports an available GPU, otherwise ``"cpu"``.
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser = argparse.ArgumentParser(
        description="Train YOLO on 16-bit TIFF with float32"
    )
    add = parser.add_argument
    add("--data", type=str, required=True, help="Path to data.yaml")
    add("--weights", type=str, default="yolov8s-seg.pt", help="Pretrained weights")
    add("--epochs", type=int, default=100, help="Number of epochs")
    add("--batch", type=int, default=16, help="Batch size")
    add("--imgsz", type=int, default=640, help="Image size")
    add("--lr", type=float, default=1e-4, help="Learning rate")
    add("--save-dir", type=str, default="runs/train", help="Save directory")
    add("--device", type=str, default=default_device)
    return parser.parse_args()
if __name__ == "__main__":
    # Parse the CLI, log a run banner summarising the configuration, then train.
    args = parse_args()
    _rule = "=" * 70
    _banner = [
        _rule,
        "Float32 16-bit TIFF Training - Standalone Script",
        _rule,
        f"Data: {args.data}",
        f"Weights: {args.weights}",
        f"Epochs: {args.epochs}, Batch: {args.batch}, ImgSz: {args.imgsz}",
        f"LR: {args.lr}, Device: {args.device}",
        _rule,
    ]
    for _line in _banner:
        logger.info(_line)
    train(args)