Two-stage training fix

2025-12-11 12:50:34 +02:00
parent c0684a9c14
commit 6b6d6fad03
2 changed files with 159 additions and 30 deletions


@@ -18,7 +18,7 @@ models:
 training:
   default_epochs: 100
   default_batch_size: 16
-  default_imgsz: 640
+  default_imgsz: 1024
   default_patience: 50
   default_lr0: 0.01
   two_stage:
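
Note: the only functional change in the config file is the default training image size, raised from 640 to 1024. A minimal sketch of how such a default could be read, assuming the file is plain YAML loaded with PyYAML (the file name here is hypothetical):

import yaml  # PyYAML

with open("config.yaml", "r", encoding="utf-8") as fh:  # hypothetical path
    cfg = yaml.safe_load(fh)

# falls back to the new default when the key is missing; was 640 before this commit
imgsz = int(cfg.get("training", {}).get("default_imgsz", 1024))
print(f"training at imgsz={imgsz}")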


@@ -68,6 +68,8 @@ class TrainingWorker(QThread):
         save_dir: str,
         run_name: str,
         parent: Optional[QThread] = None,
+        stage_plan: Optional[List[Dict[str, Any]]] = None,
+        total_epochs: Optional[int] = None,
     ):
         super().__init__(parent)
         self.data_yaml = data_yaml
@@ -79,6 +81,27 @@ class TrainingWorker(QThread):
         self.lr0 = lr0
         self.save_dir = save_dir
         self.run_name = run_name
+        self.stage_plan = stage_plan or [
+            {
+                "label": "Single Stage",
+                "model_path": base_model,
+                "use_previous_best": False,
+                "params": {
+                    "epochs": epochs,
+                    "batch": batch,
+                    "imgsz": imgsz,
+                    "patience": patience,
+                    "lr0": lr0,
+                    "freeze": 0,
+                    "name": run_name,
+                },
+            }
+        ]
+        computed_total = sum(
+            max(0, int((stage.get("params") or {}).get("epochs", 0)))
+            for stage in self.stage_plan
+        )
+        self.total_epochs = total_epochs if total_epochs else computed_total or epochs
         self._stop_requested = False
 
     def stop(self):
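
When no stage_plan is supplied, the worker now wraps the legacy single-run parameters into a one-element plan, and total_epochs defaults to the sum of per-stage epochs. A sketch of the plan shape the worker consumes; all values below are illustrative, only the keys come from the code above:

stage_plan = [
    {
        "label": "Stage 1 (frozen backbone)",   # illustrative label
        "model_path": "yolov8n.pt",             # hypothetical base weights
        "use_previous_best": False,
        "params": {"epochs": 20, "batch": 16, "imgsz": 1024,
                   "patience": 50, "lr0": 0.01, "freeze": 10,
                   "name": "demo_stage1"},
    },
    {
        "label": "Stage 2 (full fine-tune)",
        "model_path": None,                     # falls back to base_model
        "use_previous_best": True,              # picks up stage 1 best/last weights
        "params": {"epochs": 80, "batch": 16, "imgsz": 1024,
                   "patience": 50, "lr0": 0.001, "freeze": 0,
                   "name": "demo_full_ft"},
    },
]
# computed_total = 20 + 80 = 100, used as total_epochs when none is passed explicitly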
@@ -87,36 +110,98 @@ class TrainingWorker(QThread):
         self.requestInterruption()
 
     def run(self):
-        """Execute YOLO training and emit progress/finished signals."""
-        wrapper = YOLOWrapper(self.base_model)
+        """Execute YOLO training over one or more stages and emit progress/finished signals."""
-        def on_epoch_end(trainer):
+        completed_epochs = 0
+        stage_history: List[Dict[str, Any]] = []
+        last_stage_results: Optional[Dict[str, Any]] = None
+
+        for stage_index, stage in enumerate(self.stage_plan, start=1):
+            if self._stop_requested or self.isInterruptionRequested():
+                break
+            stage_label = stage.get("label") or f"Stage {stage_index}"
+            stage_params = dict(stage.get("params") or {})
+            stage_epochs = int(stage_params.get("epochs", self.epochs))
+            if stage_epochs <= 0:
+                stage_epochs = 1
+            batch = int(stage_params.get("batch", self.batch))
+            imgsz = int(stage_params.get("imgsz", self.imgsz))
+            patience = int(stage_params.get("patience", self.patience))
+            lr0 = float(stage_params.get("lr0", self.lr0))
+            freeze = int(stage_params.get("freeze", 0))
+            run_name = stage_params.get("name") or f"{self.run_name}_stage{stage_index}"
+            weights_path = stage.get("model_path") or self.base_model
+            if stage.get("use_previous_best") and last_stage_results:
+                weights_path = (
+                    last_stage_results.get("best_model_path")
+                    or last_stage_results.get("last_model_path")
+                    or weights_path
+                )
+            wrapper = YOLOWrapper(weights_path)
+            stage_offset = completed_epochs
+
+            def on_epoch_end(trainer, offset=stage_offset):
                 current_epoch = getattr(trainer, "epoch", 0) + 1
                 metrics: Dict[str, float] = {}
                 loss_items = getattr(trainer, "loss_items", None)
                 if loss_items:
                     metrics["loss"] = float(loss_items[-1])
-            self.progress.emit(current_epoch, self.epochs, metrics)
+                absolute_epoch = min(
+                    max(1, offset + current_epoch),
+                    max(1, self.total_epochs),
+                )
+                self.progress.emit(absolute_epoch, self.total_epochs, metrics)
                 if self.isInterruptionRequested() or self._stop_requested:
                     setattr(trainer, "stop_training", True)
 
             callbacks = {"on_fit_epoch_end": on_epoch_end}
 
             try:
-            results = wrapper.train(
+                stage_result = wrapper.train(
                     data_yaml=self.data_yaml,
-                epochs=self.epochs,
-                imgsz=self.imgsz,
-                batch=self.batch,
-                patience=self.patience,
+                    epochs=stage_epochs,
+                    imgsz=imgsz,
+                    batch=batch,
+                    patience=patience,
                     save_dir=self.save_dir,
-                name=self.run_name,
-                lr0=self.lr0,
+                    name=run_name,
+                    lr0=lr0,
                     callbacks=callbacks,
+                    freeze=freeze,
                 )
-            self.finished.emit(results)
             except Exception as exc:
                 self.error.emit(str(exc))
+                return
+
+            stage_history.append(
+                {
+                    "label": stage_label,
+                    "params": stage_params,
+                    "weights_used": weights_path,
+                    "results": stage_result,
+                }
+            )
+            last_stage_results = stage_result
+            completed_epochs += stage_epochs
+
+        final_payload: Dict[str, Any]
+        if last_stage_results:
+            final_payload = dict(last_stage_results)
+        else:
+            final_payload = {
+                "success": False,
+                "message": "Training stopped before any stage completed.",
+            }
+        final_payload["stage_results"] = stage_history
+        final_payload["total_epochs_completed"] = completed_epochs
+        final_payload["total_epochs_planned"] = self.total_epochs
+        final_payload["stages_completed"] = len(stage_history)
+        self.finished.emit(final_payload)
 
 class TrainingTab(QWidget):
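
The nested on_epoch_end callback now reports progress against the whole plan rather than the current stage: the stage's completed-epoch offset is added to the trainer's 1-based epoch and clamped to the planned total. A worked example with made-up numbers:

# offset = epochs finished in earlier stages, total_epochs = planned total across stages
offset, current_epoch, total_epochs = 20, 5, 100
absolute_epoch = min(max(1, offset + current_epoch), max(1, total_epochs))
assert absolute_epoch == 25  # emitted as progress(25, 100, metrics)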
@@ -1162,7 +1247,7 @@ class TrainingTab(QWidget):
                     "imgsz": params["imgsz"],
                     "patience": stage2.get("patience", params["patience"]),
                     "lr0": stage2.get("lr0", params["lr0"]),
-                    "freeze": 0,
+                    "freeze": stage2.get("freeze", 0),
                     "name": f"{params['run_name']}_full_ft",
                 },
             }
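
Stage 2's freeze setting now comes from the two_stage configuration instead of being hardcoded to 0, so a partially frozen second stage is possible. Illustrative values only; the stage2 dict contents are assumptions:

stage2 = {"epochs": 80, "lr0": 0.001, "freeze": 10}  # hypothetical two_stage config
freeze = stage2.get("freeze", 0)                     # -> 10; an omitted key still yields 0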
@@ -1170,6 +1255,30 @@ class TrainingTab(QWidget):
         return stage_plan
 
+    def _calculate_total_stage_epochs(self, stage_plan: List[Dict[str, Any]]) -> int:
+        total = 0
+        for stage in stage_plan:
+            params = stage.get("params") or {}
+            try:
+                stage_epochs = int(params.get("epochs", 0))
+            except (TypeError, ValueError):
+                stage_epochs = 0
+            if stage_epochs > 0:
+                total += stage_epochs
+        return total
+
+    def _log_stage_plan(self, stage_plan: List[Dict[str, Any]]):
+        for index, stage in enumerate(stage_plan, start=1):
+            stage_label = stage.get("label") or f"Stage {index}"
+            params = stage.get("params") or {}
+            epochs = params.get("epochs", "?")
+            lr0 = params.get("lr0", "?")
+            patience = params.get("patience", "?")
+            freeze = params.get("freeze", 0)
+            self._append_training_log(
+                f"{stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
+            )
+
     def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
         cache_base = Path("data/datasets/_rgb_cache")
         cache_base.mkdir(parents=True, exist_ok=True)
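
_calculate_total_stage_epochs sums only usable epoch counts: values that fail int() conversion or are non-positive contribute nothing. A quick illustration with assumed inputs:

plan = [
    {"params": {"epochs": 20}},
    {"params": {"epochs": "80"}},   # int("80") -> 80
    {"params": {"epochs": None}},   # TypeError -> treated as 0
    {"params": {"epochs": -5}},     # non-positive -> skipped
    {},                             # missing params -> 0
]
# expected result: 20 + 80 = 100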
@@ -1359,15 +1468,25 @@ class TrainingTab(QWidget):
         )
 
         params = self._collect_training_params()
         stage_plan = self._compose_stage_plan(params)
+        params["stage_plan"] = stage_plan
+        total_planned_epochs = (
+            self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
+        )
+        params["total_planned_epochs"] = total_planned_epochs
         self._active_training_params = params
         self._training_cancelled = False
 
+        if len(stage_plan) > 1:
+            self._append_training_log("Two-stage fine-tuning schedule:")
+            self._log_stage_plan(stage_plan)
         self._append_training_log(
             f"Starting training run '{params['run_name']}' using {params['base_model']}"
         )
 
         self.training_progress_bar.setVisible(True)
-        self.training_progress_bar.setMaximum(params["epochs"])
+        self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
         self.training_progress_bar.setValue(0)
         self._set_training_state(True)
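
The progress bar maximum now reflects the planned total across all stages, so a two-stage run fills one continuous bar instead of resetting between stages, and max(1, ...) guards against a plan that sums to zero. Continuing the illustrative 20 + 80 plan from above:

total_planned_epochs = 20 + 80
maximum = max(1, total_planned_epochs)  # progress bar maximum -> 100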
@@ -1381,6 +1500,8 @@ class TrainingTab(QWidget):
             lr0=params["lr0"],
             save_dir=params["save_dir"],
             run_name=params["run_name"],
+            stage_plan=stage_plan,
+            total_epochs=total_planned_epochs,
         )
         self.training_worker.progress.connect(self._on_training_progress)
         self.training_worker.finished.connect(self._on_training_finished)
@@ -1505,14 +1626,22 @@ class TrainingTab(QWidget):
         if not model_path:
             raise ValueError("Training results did not include a model path.")
 
+        effective_epochs = params.get("total_planned_epochs", params["epochs"])
         training_params = {
-            "epochs": params["epochs"],
+            "epochs": effective_epochs,
             "batch": params["batch"],
             "imgsz": params["imgsz"],
             "patience": params["patience"],
             "lr0": params["lr0"],
             "run_name": params["run_name"],
             "two_stage": params.get("two_stage"),
         }
+        if params.get("stage_plan"):
+            training_params["stage_plan"] = params["stage_plan"]
+        if results.get("stage_results"):
+            training_params["stage_results"] = results["stage_results"]
+        if results.get("total_epochs_completed") is not None:
+            training_params["epochs_completed"] = results["total_epochs_completed"]
+
         model_id = self.db_manager.add_model(
             model_name=params["model_name"],