Update
This commit is contained in:
@@ -22,6 +22,7 @@ import sys
|
|||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -88,12 +89,15 @@ class Float32YOLODataset(Dataset):
|
|||||||
elif img.ndim == 3 and img.shape[2] == 1:
|
elif img.ndim == 3 and img.shape[2] == 1:
|
||||||
img = np.repeat(img, 3, axis=2)
|
img = np.repeat(img, 3, axis=2)
|
||||||
|
|
||||||
return img # float32 (H,W,3) [0,1]
|
# Resize to model input size
|
||||||
|
img = cv2.resize(img, (self.img_size, self.img_size))
|
||||||
|
|
||||||
def _parse_label(self, path: Path) -> np.ndarray:
|
return img # float32 (img_size, img_size, 3) [0,1] BGR
|
||||||
|
|
||||||
|
def _parse_label(self, path: Path) -> list:
|
||||||
"""Parse YOLO label with variable-length rows."""
|
"""Parse YOLO label with variable-length rows."""
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
return np.zeros((0, 5), dtype=np.float32)
|
return []
|
||||||
|
|
||||||
labels = []
|
labels = []
|
||||||
with open(path, "r") as f:
|
with open(path, "r") as f:
|
||||||
@@ -102,11 +106,7 @@ class Float32YOLODataset(Dataset):
|
|||||||
if len(vals) >= 5:
|
if len(vals) >= 5:
|
||||||
labels.append([float(v) for v in vals])
|
labels.append([float(v) for v in vals])
|
||||||
|
|
||||||
return (
|
return labels
|
||||||
np.array(labels, dtype=np.float32)
|
|
||||||
if labels
|
|
||||||
else np.zeros((0, 5), dtype=np.float32)
|
|
||||||
)
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
img_path = self.paths[idx]
|
img_path = self.paths[idx]
|
||||||
@@ -144,8 +144,6 @@ def get_pytorch_model(ul_model):
|
|||||||
# Try common patterns
|
# Try common patterns
|
||||||
if hasattr(ul_model, "model"):
|
if hasattr(ul_model, "model"):
|
||||||
pt_model = ul_model.model
|
pt_model = ul_model.model
|
||||||
if pt_model and hasattr(pt_model, "model"):
|
|
||||||
pt_model = pt_model.model
|
|
||||||
|
|
||||||
# Find loss
|
# Find loss
|
||||||
if pt_model and hasattr(pt_model, "loss"):
|
if pt_model and hasattr(pt_model, "loss"):
|
||||||
@@ -194,6 +192,8 @@ def train(args):
|
|||||||
|
|
||||||
pt_model.to(device)
|
pt_model.to(device)
|
||||||
pt_model.train()
|
pt_model.train()
|
||||||
|
for param in pt_model.parameters():
|
||||||
|
param.requires_grad = True
|
||||||
|
|
||||||
# Create datasets
|
# Create datasets
|
||||||
train_ds = Float32YOLODataset(str(train_img), str(train_lbl), args.imgsz)
|
train_ds = Float32YOLODataset(str(train_img), str(train_lbl), args.imgsz)
|
||||||
@@ -254,11 +254,12 @@ def train(args):
|
|||||||
# This may fail - see logs
|
# This may fail - see logs
|
||||||
try:
|
try:
|
||||||
loss_out = loss_fn(preds, labels_list)
|
loss_out = loss_fn(preds, labels_list)
|
||||||
loss = (
|
if isinstance(loss_out, dict):
|
||||||
loss_out[0]
|
loss = loss_out["loss"]
|
||||||
if isinstance(loss_out, (tuple, list))
|
elif isinstance(loss_out, (tuple, list)):
|
||||||
else loss_out
|
loss = loss_out[0]
|
||||||
)
|
else:
|
||||||
|
loss = loss_out
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Loss computation failed: {e}")
|
logger.error(f"Loss computation failed: {e}")
|
||||||
logger.error(
|
logger.error(
|
||||||
@@ -273,6 +274,7 @@ def train(args):
|
|||||||
raise RuntimeError(f"Unexpected preds format: {type(preds)}")
|
raise RuntimeError(f"Unexpected preds format: {type(preds)}")
|
||||||
|
|
||||||
# Backward
|
# Backward
|
||||||
|
loss = loss.mean()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user