Update
This commit is contained in:
@@ -22,6 +22,7 @@ import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -88,12 +89,15 @@ class Float32YOLODataset(Dataset):
|
||||
elif img.ndim == 3 and img.shape[2] == 1:
|
||||
img = np.repeat(img, 3, axis=2)
|
||||
|
||||
return img # float32 (H,W,3) [0,1]
|
||||
# Resize to model input size
|
||||
img = cv2.resize(img, (self.img_size, self.img_size))
|
||||
|
||||
def _parse_label(self, path: Path) -> np.ndarray:
|
||||
return img # float32 (img_size, img_size, 3) [0,1] BGR
|
||||
|
||||
def _parse_label(self, path: Path) -> list:
|
||||
"""Parse YOLO label with variable-length rows."""
|
||||
if not path.exists():
|
||||
return np.zeros((0, 5), dtype=np.float32)
|
||||
return []
|
||||
|
||||
labels = []
|
||||
with open(path, "r") as f:
|
||||
@@ -102,11 +106,7 @@ class Float32YOLODataset(Dataset):
|
||||
if len(vals) >= 5:
|
||||
labels.append([float(v) for v in vals])
|
||||
|
||||
return (
|
||||
np.array(labels, dtype=np.float32)
|
||||
if labels
|
||||
else np.zeros((0, 5), dtype=np.float32)
|
||||
)
|
||||
return labels
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img_path = self.paths[idx]
|
||||
@@ -144,8 +144,6 @@ def get_pytorch_model(ul_model):
|
||||
# Try common patterns
|
||||
if hasattr(ul_model, "model"):
|
||||
pt_model = ul_model.model
|
||||
if pt_model and hasattr(pt_model, "model"):
|
||||
pt_model = pt_model.model
|
||||
|
||||
# Find loss
|
||||
if pt_model and hasattr(pt_model, "loss"):
|
||||
@@ -194,6 +192,8 @@ def train(args):
|
||||
|
||||
pt_model.to(device)
|
||||
pt_model.train()
|
||||
for param in pt_model.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
# Create datasets
|
||||
train_ds = Float32YOLODataset(str(train_img), str(train_lbl), args.imgsz)
|
||||
@@ -254,11 +254,12 @@ def train(args):
|
||||
# This may fail - see logs
|
||||
try:
|
||||
loss_out = loss_fn(preds, labels_list)
|
||||
loss = (
|
||||
loss_out[0]
|
||||
if isinstance(loss_out, (tuple, list))
|
||||
else loss_out
|
||||
)
|
||||
if isinstance(loss_out, dict):
|
||||
loss = loss_out["loss"]
|
||||
elif isinstance(loss_out, (tuple, list)):
|
||||
loss = loss_out[0]
|
||||
else:
|
||||
loss = loss_out
|
||||
except Exception as e:
|
||||
logger.error(f"Loss computation failed: {e}")
|
||||
logger.error(
|
||||
@@ -273,6 +274,7 @@ def train(args):
|
||||
raise RuntimeError(f"Unexpected preds format: {type(preds)}")
|
||||
|
||||
# Backward
|
||||
loss = loss.mean()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user