#!/usr/bin/env python3
"""
Standalone training script for YOLO with 16-bit TIFF float32 support.

This script trains YOLO models on 16-bit grayscale TIFF datasets without data loss.
Converts images to float32 [0-1] on-the-fly using tifffile (no PIL/cv2).

Usage:
    python scripts/train_float32_standalone.py \\
        --data path/to/data.yaml \\
        --weights yolov8s-seg.pt \\
        --epochs 100 \\
        --batch 16 \\
        --imgsz 640

Based on the custom dataset approach to avoid Ultralytics' channel conversion issues.
"""

import argparse
import os
import sys
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import tifffile
import yaml
from torch.utils.data import Dataset, DataLoader
from ultralytics import YOLO

# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from src.utils.logger import get_logger

logger = get_logger(__name__)


# ===================== Dataset =====================


class Float32YOLODataset(Dataset):
    """PyTorch dataset for 16-bit TIFF images with float32 conversion.

    Each item is a tuple ``(image, labels, name)``:

    * ``image``  -- float32 tensor of shape (3, H, W) with values in [0, 1].
      When ``img_size`` is truthy, H == W == ``img_size``.
    * ``labels`` -- float32 array with one row per label line (see
      ``_parse_label``).
    * ``name``   -- file name of the image (no directory part).

    Args:
        images_dir: Directory scanned recursively for image files.
        labels_dir: Directory holding one ``<image stem>.txt`` per image.
        img_size: Side length every image is resized to. Pass ``None`` or
            ``0`` to keep the native resolution (all images must then share
            one size, otherwise batching in ``collate_fn`` fails).

    Raises:
        ValueError: If no image file is found under ``images_dir``.
    """

    def __init__(self, images_dir, labels_dir, img_size=640):
        self.images_dir = Path(images_dir)
        self.labels_dir = Path(labels_dir)
        self.img_size = img_size

        # Find images
        # NOTE(review): every file is read with tifffile in _read_image;
        # confirm it can actually decode the non-TIFF extensions listed here.
        extensions = {".tif", ".tiff", ".png", ".jpg", ".jpeg", ".bmp"}
        self.paths = sorted(
            p
            for p in self.images_dir.rglob("*")
            if p.is_file() and p.suffix.lower() in extensions
        )

        if not self.paths:
            raise ValueError(f"No images found in {images_dir}")

        logger.info(f"Dataset: {len(self.paths)} images from {images_dir}")

    def __len__(self):
        """Return the number of images found at construction time."""
        return len(self.paths)

    def _read_image(self, path: Path) -> np.ndarray:
        """Load image as float32 [0-1] RGB.

        Integer images are scaled by the full range of their dtype, so 8-bit
        data is divided by 255 and 16-bit data by 65535. Float images are
        assumed to be already normalized unless their maximum exceeds 1.5,
        in which case they are treated as 16-bit values stored as floats.

        Returns:
            float32 array of shape (H, W, 3) for 2-D or single-channel
            input; other layouts are returned with their shape unchanged.
        """
        raw = tifffile.imread(str(path))
        img = raw.astype(np.float32)

        if np.issubdtype(raw.dtype, np.integer):
            # Scale by the dtype's range rather than a fixed 65535 so that
            # 8-bit inputs are not squashed to almost-black.
            img = img / float(np.iinfo(raw.dtype).max)
        elif img.size and img.max() > 1.5:
            img = img / 65535.0
        img = np.clip(img, 0.0, 1.0)

        # Grayscale→RGB
        if img.ndim == 2:
            img = np.repeat(img[..., None], 3, axis=2)
        elif img.ndim == 3 and img.shape[2] == 1:
            img = np.repeat(img, 3, axis=2)

        return img  # float32 (H,W,3) [0,1]

    def _parse_label(self, path: Path) -> np.ndarray:
        """Parse YOLO label with variable-length rows.

        Lines with fewer than five whitespace-separated fields are skipped.
        Rows may differ in length (e.g. polygons with different point
        counts); shorter rows are right-padded with NaN so the result is
        always a rectangular float32 array.

        Returns:
            Array of shape (num_rows, longest_row), or (0, 5) when the file
            is missing or contains no usable line.

        Raises:
            ValueError: If a field of a kept line is not a valid float.
        """
        empty = np.zeros((0, 5), dtype=np.float32)
        if not path.exists():
            return empty

        rows = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                fields = line.split()
                if len(fields) >= 5:
                    rows.append([float(v) for v in fields])

        if not rows:
            return empty

        # np.array() cannot build a float array from ragged rows, so fill a
        # NaN-initialized rectangle row by row instead.
        width = max(len(row) for row in rows)
        labels = np.full((len(rows), width), np.nan, dtype=np.float32)
        for i, row in enumerate(rows):
            labels[i, : len(row)] = row
        return labels

    def __getitem__(self, idx):
        """Return ``(image_tensor, labels, file_name)`` for item ``idx``."""
        img_path = self.paths[idx]
        label_path = self.labels_dir / f"{img_path.stem}.txt"

        # Load & convert to tensor (C,H,W)
        img = self._read_image(img_path)
        img_t = torch.from_numpy(img).permute(2, 0, 1).contiguous()

        # Resize to a common square size so that collate_fn can stack the
        # batch. The image is stretched (aspect ratio is not preserved),
        # which keeps normalized YOLO coordinates valid without touching
        # the labels.
        if self.img_size:
            side = int(self.img_size)
            if tuple(img_t.shape[-2:]) != (side, side):
                img_t = nn.functional.interpolate(
                    img_t.unsqueeze(0),
                    size=(side, side),
                    mode="bilinear",
                    align_corners=False,
                ).squeeze(0)

        # Load labels
        labels = self._parse_label(label_path)

        return img_t, labels, str(img_path.name)


# ===================== Collate =====================


def collate_fn(batch):
    """Stack images, keep labels as list.

    Args:
        batch: Sequence of ``(image_tensor, labels, name)`` items.

    Returns:
        ``(images, labels, names)`` where ``images`` is the image tensors
        stacked along a new first dimension and the other two are lists.
    """
    imgs = torch.stack([item[0] for item in batch], dim=0)
    labels = [item[1] for item in batch]
    names = [item[2] for item in batch]
    return imgs, labels, names


# ===================== Training =====================


def get_pytorch_model(ul_model):
    """Extract PyTorch model and loss from Ultralytics wrapper.

    The outermost ``nn.Module`` found under ``ul_model.model`` is kept. A
    second ``.model`` hop is taken only when the first one is not itself a
    module.

    NOTE(review): in Ultralytics the task model exposed as ``YOLO(...).model``
    appears to own ``loss``/``args`` and the layer routing, while its inner
    ``.model`` is a plain layer container -- which is why the module is no
    longer unwrapped twice. Confirm against the installed version.

    Returns:
        ``(pt_model, loss_fn)``; ``loss_fn`` is the model's ``loss`` or
        ``compute_loss`` attribute, or ``None`` when it has neither.

    Raises:
        RuntimeError: If no model can be found on ``ul_model``.
    """
    pt_model = getattr(ul_model, "model", None)

    # Explicit None checks: truthiness of an nn.Module can depend on
    # __len__ (an empty nn.Sequential is falsy).
    if (
        pt_model is not None
        and not isinstance(pt_model, nn.Module)
        and hasattr(pt_model, "model")
    ):
        pt_model = pt_model.model

    if pt_model is None:
        raise RuntimeError("Could not extract PyTorch model")

    # Find loss
    loss_fn = getattr(pt_model, "loss", None)
    if loss_fn is None:
        loss_fn = getattr(pt_model, "compute_loss", None)

    return pt_model, loss_fn


def _labels_dir_for(images_dir: Path) -> Path:
    """Return the label directory that mirrors ``images_dir``.

    The last path component named ``images`` is replaced by ``labels``
    (``train/images`` -> ``train/labels``, ``images/train`` ->
    ``labels/train``). Without such a component the result is a sibling
    directory called ``labels``.
    """
    parts = list(images_dir.parts)
    for i in range(len(parts) - 1, -1, -1):
        if parts[i] == "images":
            parts[i] = "labels"
            return Path(*parts)
    return images_dir.parent / "labels"


def _resolve_loss(preds, labels_list, loss_fn):
    """Pick the loss out of a forward result, or compute it via ``loss_fn``.

    Accepted shapes of ``preds``:

    * a dict with a ``"loss"`` key;
    * a tuple/list of length >= 2 whose last element is either a dict with
      a ``"loss"`` key or a single-element tensor;
    * any other tuple/list, in which case ``loss_fn(preds, labels_list)`` is
      called and the first element of a tuple/list result is used.

    Raises:
        RuntimeError: If ``preds`` has an unsupported type, or no loss is
            embedded in it and ``loss_fn`` is ``None``.
        Exception: Whatever ``loss_fn`` raises, after logging it.
    """
    if isinstance(preds, dict):
        if "loss" in preds:
            return preds["loss"]
        raise RuntimeError(f"Unexpected preds format: {type(preds)}")
    if not isinstance(preds, (tuple, list)):
        raise RuntimeError(f"Unexpected preds format: {type(preds)}")

    if len(preds) >= 2:
        last = preds[-1]
        if isinstance(last, dict) and "loss" in last:
            return last["loss"]
        # Only a single-element tensor can be a loss that .backward()
        # accepts; a larger trailing tensor is a prediction, not a loss.
        if isinstance(last, torch.Tensor) and last.numel() == 1:
            return last

    if loss_fn is None:
        raise RuntimeError("Cannot determine loss from model output")

    # NOTE(review): Ultralytics task models appear to expect
    # loss(batch_dict, preds) with a collated batch dict, not raw
    # per-image label arrays -- this call is likely to fail there. Confirm
    # and build a proper batch if this path is needed.
    try:
        loss_out = loss_fn(preds, labels_list)
    except Exception as e:
        logger.error(f"Loss computation failed: {e}")
        logger.error(
            "Consider using Ultralytics .train() or check model/loss compatibility"
        )
        raise
    return loss_out[0] if isinstance(loss_out, (tuple, list)) else loss_out


def train(args):
    """Main training function.

    Reads the dataset layout from ``args.data``, loads ``args.weights``,
    trains for ``args.epochs`` epochs with AdamW and writes one checkpoint
    per epoch plus ``best.pt`` (lowest mean training loss) to
    ``args.save_dir``.

    Args:
        args: Namespace with ``data``, ``weights``, ``epochs``, ``batch``,
            ``imgsz``, ``lr``, ``save_dir`` and ``device`` attributes, as
            produced by ``parse_args``.
    """
    from types import SimpleNamespace

    device = args.device
    logger.info(f"Device: {device}")
    # "cuda:0" and friends are CUDA devices too, not just the bare "cuda".
    use_pinned = str(device).startswith("cuda")

    # Parse data.yaml (an empty file loads as None)
    with open(args.data, "r") as f:
        data_config = yaml.safe_load(f) or {}

    dataset_root = Path(data_config.get("path") or Path(args.data).parent)
    train_img = dataset_root / data_config.get("train", "train/images")
    val_img = dataset_root / data_config.get("val", "val/images")
    train_lbl = _labels_dir_for(train_img)
    val_lbl = _labels_dir_for(val_img)

    # Load model
    logger.info(f"Loading {args.weights}")
    ul_model = YOLO(args.weights)
    pt_model, loss_fn = get_pytorch_model(ul_model)

    # Configure model args: make sure they are attribute-accessible
    if not hasattr(pt_model, "args"):
        pt_model.args = SimpleNamespace()
    if isinstance(pt_model.args, dict):
        pt_model.args = SimpleNamespace(**pt_model.args)

    # Set segmentation loss args (existing values win over the defaults)
    pt_model.args.overlap_mask = getattr(pt_model.args, "overlap_mask", True)
    pt_model.args.mask_ratio = getattr(pt_model.args, "mask_ratio", 4)
    pt_model.args.task = "segment"

    pt_model.to(device)
    pt_model.train()

    # Create datasets
    train_ds = Float32YOLODataset(str(train_img), str(train_lbl), args.imgsz)
    val_ds = Float32YOLODataset(str(val_img), str(val_lbl), args.imgsz)

    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch,
        shuffle=True,
        num_workers=4,
        pin_memory=use_pinned,
        collate_fn=collate_fn,
    )
    # NOTE: val_loader is built but not consumed below; "best" is selected
    # on the training loss.
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch,
        shuffle=False,
        num_workers=2,
        pin_memory=use_pinned,
        collate_fn=collate_fn,
    )

    # Optimizer
    optimizer = torch.optim.AdamW(pt_model.parameters(), lr=args.lr)

    # Training loop
    os.makedirs(args.save_dir, exist_ok=True)
    save_dir = Path(args.save_dir)
    best_loss = float("inf")

    for epoch in range(args.epochs):
        t0 = time.time()
        running_loss = 0.0
        num_batches = 0

        for imgs, labels_list, _names in train_loader:
            imgs = imgs.to(device)
            optimizer.zero_grad()
            num_batches += 1

            preds = pt_model(imgs)
            loss = _resolve_loss(preds, labels_list, loss_fn)

            # Backward
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            running_loss += loss_value

            if (num_batches % 10) == 0:
                logger.info(
                    f"Epoch {epoch+1} Batch {num_batches} Loss: {loss_value:.4f}"
                )

        epoch_loss = running_loss / max(1, num_batches)
        epoch_time = time.time() - t0
        logger.info(
            f"Epoch {epoch+1}/{args.epochs} - Loss: {epoch_loss:.4f}, Time: {epoch_time:.1f}s"
        )

        # Save checkpoint
        ckpt = save_dir / f"epoch{epoch+1}.pt"
        torch.save(
            {
                "epoch": epoch + 1,
                "model_state_dict": pt_model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": epoch_loss,
            },
            ckpt,
        )

        # Save best
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            best_ckpt = save_dir / "best.pt"
            torch.save(pt_model.state_dict(), best_ckpt)
            logger.info(f"New best: {best_ckpt}")

    logger.info("Training complete")


# ===================== Main =====================


def parse_args(argv=None):
    """Parse command-line options.

    Args:
        argv: Argument list to parse; ``None`` (the default) reads
            ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``data``, ``weights``, ``epochs``,
        ``batch``, ``imgsz``, ``lr``, ``save_dir`` and ``device``.
    """
    parser = argparse.ArgumentParser(
        description="Train YOLO on 16-bit TIFF with float32"
    )
    parser.add_argument("--data", type=str, required=True, help="Path to data.yaml")
    parser.add_argument(
        "--weights", type=str, default="yolov8s-seg.pt", help="Pretrained weights"
    )
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument("--batch", type=int, default=16, help="Batch size")
    parser.add_argument("--imgsz", type=int, default=640, help="Image size")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument(
        "--save-dir", type=str, default="runs/train", help="Save directory"
    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()

    logger.info("=" * 70)
    logger.info("Float32 16-bit TIFF Training - Standalone Script")
    logger.info("=" * 70)
    logger.info(f"Data: {args.data}")
    logger.info(f"Weights: {args.weights}")
    logger.info(f"Epochs: {args.epochs}, Batch: {args.batch}, ImgSz: {args.imgsz}")
    logger.info(f"LR: {args.lr}, Device: {args.device}")
    logger.info("=" * 70)

    train(args)