diff --git a/scripts/train_float32_standalone.py b/scripts/train_float32_standalone.py index 8c3cf67..8714fe9 100755 --- a/scripts/train_float32_standalone.py +++ b/scripts/train_float32_standalone.py @@ -22,6 +22,7 @@ import sys import time from pathlib import Path +import cv2 import numpy as np import torch import torch.nn as nn @@ -88,12 +89,15 @@ class Float32YOLODataset(Dataset): elif img.ndim == 3 and img.shape[2] == 1: img = np.repeat(img, 3, axis=2) - return img # float32 (H,W,3) [0,1] + # Resize to model input size + img = cv2.resize(img, (self.img_size, self.img_size)) - def _parse_label(self, path: Path) -> np.ndarray: + return img # float32 (img_size, img_size, 3) [0,1] BGR + + def _parse_label(self, path: Path) -> list: """Parse YOLO label with variable-length rows.""" if not path.exists(): - return np.zeros((0, 5), dtype=np.float32) + return [] labels = [] with open(path, "r") as f: @@ -102,11 +106,7 @@ class Float32YOLODataset(Dataset): if len(vals) >= 5: labels.append([float(v) for v in vals]) - return ( - np.array(labels, dtype=np.float32) - if labels - else np.zeros((0, 5), dtype=np.float32) - ) + return labels def __getitem__(self, idx): img_path = self.paths[idx] @@ -144,8 +144,6 @@ def get_pytorch_model(ul_model): # Try common patterns if hasattr(ul_model, "model"): pt_model = ul_model.model - if pt_model and hasattr(pt_model, "model"): - pt_model = pt_model.model # Find loss if pt_model and hasattr(pt_model, "loss"): @@ -194,6 +192,8 @@ def train(args): pt_model.to(device) pt_model.train() + for param in pt_model.parameters(): + param.requires_grad = True # Create datasets train_ds = Float32YOLODataset(str(train_img), str(train_lbl), args.imgsz) @@ -254,11 +254,12 @@ def train(args): # This may fail - see logs try: loss_out = loss_fn(preds, labels_list) - loss = ( - loss_out[0] - if isinstance(loss_out, (tuple, list)) - else loss_out - ) + if isinstance(loss_out, dict): + loss = loss_out["loss"] + elif isinstance(loss_out, (tuple, list)): + loss = loss_out[0] + else: + loss = loss_out except Exception as e: logger.error(f"Loss computation failed: {e}") logger.error( @@ -273,6 +274,7 @@ def train(args): raise RuntimeError(f"Unexpected preds format: {type(preds)}") # Backward + loss = loss.mean() loss.backward() optimizer.step()