This commit is contained in:
2025-12-16 13:25:20 +02:00
parent aec0fbf83c
commit 2dbfa54256

View File

@@ -22,6 +22,7 @@ import sys
import time
from pathlib import Path
import cv2
import numpy as np
import torch
import torch.nn as nn
@@ -88,12 +89,15 @@ class Float32YOLODataset(Dataset):
elif img.ndim == 3 and img.shape[2] == 1:
img = np.repeat(img, 3, axis=2)
return img # float32 (H,W,3) [0,1]
# Resize to model input size
img = cv2.resize(img, (self.img_size, self.img_size))
def _parse_label(self, path: Path) -> np.ndarray:
return img # float32 (img_size, img_size, 3) [0,1] BGR
def _parse_label(self, path: Path) -> list:
"""Parse YOLO label with variable-length rows."""
if not path.exists():
return np.zeros((0, 5), dtype=np.float32)
return []
labels = []
with open(path, "r") as f:
@@ -102,11 +106,7 @@ class Float32YOLODataset(Dataset):
if len(vals) >= 5:
labels.append([float(v) for v in vals])
return (
np.array(labels, dtype=np.float32)
if labels
else np.zeros((0, 5), dtype=np.float32)
)
return labels
def __getitem__(self, idx):
img_path = self.paths[idx]
@@ -144,8 +144,6 @@ def get_pytorch_model(ul_model):
# Try common patterns
if hasattr(ul_model, "model"):
pt_model = ul_model.model
if pt_model and hasattr(pt_model, "model"):
pt_model = pt_model.model
# Find loss
if pt_model and hasattr(pt_model, "loss"):
@@ -194,6 +192,8 @@ def train(args):
pt_model.to(device)
pt_model.train()
for param in pt_model.parameters():
param.requires_grad = True
# Create datasets
train_ds = Float32YOLODataset(str(train_img), str(train_lbl), args.imgsz)
@@ -254,11 +254,12 @@ def train(args):
# This may fail - see logs
try:
loss_out = loss_fn(preds, labels_list)
loss = (
loss_out[0]
if isinstance(loss_out, (tuple, list))
else loss_out
)
if isinstance(loss_out, dict):
loss = loss_out["loss"]
elif isinstance(loss_out, (tuple, list)):
loss = loss_out[0]
else:
loss = loss_out
except Exception as e:
logger.error(f"Loss computation failed: {e}")
logger.error(
@@ -273,6 +274,7 @@ def train(args):
raise RuntimeError(f"Unexpected preds format: {type(preds)}")
# Backward
loss = loss.mean()
loss.backward()
optimizer.step()