This commit is contained in:
2025-12-16 13:25:20 +02:00
parent aec0fbf83c
commit 2dbfa54256

View File

@@ -22,6 +22,7 @@ import sys
import time import time
from pathlib import Path from pathlib import Path
import cv2
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -88,12 +89,15 @@ class Float32YOLODataset(Dataset):
elif img.ndim == 3 and img.shape[2] == 1: elif img.ndim == 3 and img.shape[2] == 1:
img = np.repeat(img, 3, axis=2) img = np.repeat(img, 3, axis=2)
return img # float32 (H,W,3) [0,1] # Resize to model input size
img = cv2.resize(img, (self.img_size, self.img_size))
def _parse_label(self, path: Path) -> np.ndarray: return img # float32 (img_size, img_size, 3) [0,1] BGR
def _parse_label(self, path: Path) -> list:
"""Parse YOLO label with variable-length rows.""" """Parse YOLO label with variable-length rows."""
if not path.exists(): if not path.exists():
return np.zeros((0, 5), dtype=np.float32) return []
labels = [] labels = []
with open(path, "r") as f: with open(path, "r") as f:
@@ -102,11 +106,7 @@ class Float32YOLODataset(Dataset):
if len(vals) >= 5: if len(vals) >= 5:
labels.append([float(v) for v in vals]) labels.append([float(v) for v in vals])
return ( return labels
np.array(labels, dtype=np.float32)
if labels
else np.zeros((0, 5), dtype=np.float32)
)
def __getitem__(self, idx): def __getitem__(self, idx):
img_path = self.paths[idx] img_path = self.paths[idx]
@@ -144,8 +144,6 @@ def get_pytorch_model(ul_model):
# Try common patterns # Try common patterns
if hasattr(ul_model, "model"): if hasattr(ul_model, "model"):
pt_model = ul_model.model pt_model = ul_model.model
if pt_model and hasattr(pt_model, "model"):
pt_model = pt_model.model
# Find loss # Find loss
if pt_model and hasattr(pt_model, "loss"): if pt_model and hasattr(pt_model, "loss"):
@@ -194,6 +192,8 @@ def train(args):
pt_model.to(device) pt_model.to(device)
pt_model.train() pt_model.train()
for param in pt_model.parameters():
param.requires_grad = True
# Create datasets # Create datasets
train_ds = Float32YOLODataset(str(train_img), str(train_lbl), args.imgsz) train_ds = Float32YOLODataset(str(train_img), str(train_lbl), args.imgsz)
@@ -254,11 +254,12 @@ def train(args):
# This may fail - see logs # This may fail - see logs
try: try:
loss_out = loss_fn(preds, labels_list) loss_out = loss_fn(preds, labels_list)
loss = ( if isinstance(loss_out, dict):
loss_out[0] loss = loss_out["loss"]
if isinstance(loss_out, (tuple, list)) elif isinstance(loss_out, (tuple, list)):
else loss_out loss = loss_out[0]
) else:
loss = loss_out
except Exception as e: except Exception as e:
logger.error(f"Loss computation failed: {e}") logger.error(f"Loss computation failed: {e}")
logger.error( logger.error(
@@ -273,6 +274,7 @@ def train(args):
raise RuntimeError(f"Unexpected preds format: {type(preds)}") raise RuntimeError(f"Unexpected preds format: {type(preds)}")
# Backward # Backward
loss = loss.mean()
loss.backward() loss.backward()
optimizer.step() optimizer.step()