2Stage training fix
This commit is contained in:
@@ -18,7 +18,7 @@ models:
|
||||
training:
|
||||
default_epochs: 100
|
||||
default_batch_size: 16
|
||||
default_imgsz: 640
|
||||
default_imgsz: 1024
|
||||
default_patience: 50
|
||||
default_lr0: 0.01
|
||||
two_stage:
|
||||
|
||||
@@ -68,6 +68,8 @@ class TrainingWorker(QThread):
|
||||
save_dir: str,
|
||||
run_name: str,
|
||||
parent: Optional[QThread] = None,
|
||||
stage_plan: Optional[List[Dict[str, Any]]] = None,
|
||||
total_epochs: Optional[int] = None,
|
||||
):
|
||||
super().__init__(parent)
|
||||
self.data_yaml = data_yaml
|
||||
@@ -79,6 +81,27 @@ class TrainingWorker(QThread):
|
||||
self.lr0 = lr0
|
||||
self.save_dir = save_dir
|
||||
self.run_name = run_name
|
||||
self.stage_plan = stage_plan or [
|
||||
{
|
||||
"label": "Single Stage",
|
||||
"model_path": base_model,
|
||||
"use_previous_best": False,
|
||||
"params": {
|
||||
"epochs": epochs,
|
||||
"batch": batch,
|
||||
"imgsz": imgsz,
|
||||
"patience": patience,
|
||||
"lr0": lr0,
|
||||
"freeze": 0,
|
||||
"name": run_name,
|
||||
},
|
||||
}
|
||||
]
|
||||
computed_total = sum(
|
||||
max(0, int((stage.get("params") or {}).get("epochs", 0)))
|
||||
for stage in self.stage_plan
|
||||
)
|
||||
self.total_epochs = total_epochs if total_epochs else computed_total or epochs
|
||||
self._stop_requested = False
|
||||
|
||||
def stop(self):
|
||||
@@ -87,36 +110,98 @@ class TrainingWorker(QThread):
|
||||
self.requestInterruption()
|
||||
|
||||
def run(self):
|
||||
"""Execute YOLO training and emit progress/finished signals."""
|
||||
wrapper = YOLOWrapper(self.base_model)
|
||||
"""Execute YOLO training over one or more stages and emit progress/finished signals."""
|
||||
|
||||
def on_epoch_end(trainer):
|
||||
completed_epochs = 0
|
||||
stage_history: List[Dict[str, Any]] = []
|
||||
last_stage_results: Optional[Dict[str, Any]] = None
|
||||
|
||||
for stage_index, stage in enumerate(self.stage_plan, start=1):
|
||||
if self._stop_requested or self.isInterruptionRequested():
|
||||
break
|
||||
|
||||
stage_label = stage.get("label") or f"Stage {stage_index}"
|
||||
stage_params = dict(stage.get("params") or {})
|
||||
stage_epochs = int(stage_params.get("epochs", self.epochs))
|
||||
if stage_epochs <= 0:
|
||||
stage_epochs = 1
|
||||
batch = int(stage_params.get("batch", self.batch))
|
||||
imgsz = int(stage_params.get("imgsz", self.imgsz))
|
||||
patience = int(stage_params.get("patience", self.patience))
|
||||
lr0 = float(stage_params.get("lr0", self.lr0))
|
||||
freeze = int(stage_params.get("freeze", 0))
|
||||
run_name = stage_params.get("name") or f"{self.run_name}_stage{stage_index}"
|
||||
|
||||
weights_path = stage.get("model_path") or self.base_model
|
||||
if stage.get("use_previous_best") and last_stage_results:
|
||||
weights_path = (
|
||||
last_stage_results.get("best_model_path")
|
||||
or last_stage_results.get("last_model_path")
|
||||
or weights_path
|
||||
)
|
||||
|
||||
wrapper = YOLOWrapper(weights_path)
|
||||
stage_offset = completed_epochs
|
||||
|
||||
def on_epoch_end(trainer, offset=stage_offset):
|
||||
current_epoch = getattr(trainer, "epoch", 0) + 1
|
||||
metrics: Dict[str, float] = {}
|
||||
loss_items = getattr(trainer, "loss_items", None)
|
||||
if loss_items:
|
||||
metrics["loss"] = float(loss_items[-1])
|
||||
self.progress.emit(current_epoch, self.epochs, metrics)
|
||||
absolute_epoch = min(
|
||||
max(1, offset + current_epoch),
|
||||
max(1, self.total_epochs),
|
||||
)
|
||||
self.progress.emit(absolute_epoch, self.total_epochs, metrics)
|
||||
if self.isInterruptionRequested() or self._stop_requested:
|
||||
setattr(trainer, "stop_training", True)
|
||||
|
||||
callbacks = {"on_fit_epoch_end": on_epoch_end}
|
||||
|
||||
try:
|
||||
results = wrapper.train(
|
||||
stage_result = wrapper.train(
|
||||
data_yaml=self.data_yaml,
|
||||
epochs=self.epochs,
|
||||
imgsz=self.imgsz,
|
||||
batch=self.batch,
|
||||
patience=self.patience,
|
||||
epochs=stage_epochs,
|
||||
imgsz=imgsz,
|
||||
batch=batch,
|
||||
patience=patience,
|
||||
save_dir=self.save_dir,
|
||||
name=self.run_name,
|
||||
lr0=self.lr0,
|
||||
name=run_name,
|
||||
lr0=lr0,
|
||||
callbacks=callbacks,
|
||||
freeze=freeze,
|
||||
)
|
||||
self.finished.emit(results)
|
||||
except Exception as exc:
|
||||
self.error.emit(str(exc))
|
||||
return
|
||||
|
||||
stage_history.append(
|
||||
{
|
||||
"label": stage_label,
|
||||
"params": stage_params,
|
||||
"weights_used": weights_path,
|
||||
"results": stage_result,
|
||||
}
|
||||
)
|
||||
last_stage_results = stage_result
|
||||
completed_epochs += stage_epochs
|
||||
|
||||
final_payload: Dict[str, Any]
|
||||
if last_stage_results:
|
||||
final_payload = dict(last_stage_results)
|
||||
else:
|
||||
final_payload = {
|
||||
"success": False,
|
||||
"message": "Training stopped before any stage completed.",
|
||||
}
|
||||
|
||||
final_payload["stage_results"] = stage_history
|
||||
final_payload["total_epochs_completed"] = completed_epochs
|
||||
final_payload["total_epochs_planned"] = self.total_epochs
|
||||
final_payload["stages_completed"] = len(stage_history)
|
||||
|
||||
self.finished.emit(final_payload)
|
||||
|
||||
|
||||
class TrainingTab(QWidget):
|
||||
@@ -1162,7 +1247,7 @@ class TrainingTab(QWidget):
|
||||
"imgsz": params["imgsz"],
|
||||
"patience": stage2.get("patience", params["patience"]),
|
||||
"lr0": stage2.get("lr0", params["lr0"]),
|
||||
"freeze": 0,
|
||||
"freeze": stage2.get("freeze", 0),
|
||||
"name": f"{params['run_name']}_full_ft",
|
||||
},
|
||||
}
|
||||
@@ -1170,6 +1255,30 @@ class TrainingTab(QWidget):
|
||||
|
||||
return stage_plan
|
||||
|
||||
def _calculate_total_stage_epochs(self, stage_plan: List[Dict[str, Any]]) -> int:
|
||||
total = 0
|
||||
for stage in stage_plan:
|
||||
params = stage.get("params") or {}
|
||||
try:
|
||||
stage_epochs = int(params.get("epochs", 0))
|
||||
except (TypeError, ValueError):
|
||||
stage_epochs = 0
|
||||
if stage_epochs > 0:
|
||||
total += stage_epochs
|
||||
return total
|
||||
|
||||
def _log_stage_plan(self, stage_plan: List[Dict[str, Any]]):
|
||||
for index, stage in enumerate(stage_plan, start=1):
|
||||
stage_label = stage.get("label") or f"Stage {index}"
|
||||
params = stage.get("params") or {}
|
||||
epochs = params.get("epochs", "?")
|
||||
lr0 = params.get("lr0", "?")
|
||||
patience = params.get("patience", "?")
|
||||
freeze = params.get("freeze", 0)
|
||||
self._append_training_log(
|
||||
f" • {stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
|
||||
)
|
||||
|
||||
def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
|
||||
cache_base = Path("data/datasets/_rgb_cache")
|
||||
cache_base.mkdir(parents=True, exist_ok=True)
|
||||
@@ -1359,15 +1468,25 @@ class TrainingTab(QWidget):
|
||||
)
|
||||
|
||||
params = self._collect_training_params()
|
||||
stage_plan = self._compose_stage_plan(params)
|
||||
params["stage_plan"] = stage_plan
|
||||
total_planned_epochs = (
|
||||
self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
|
||||
)
|
||||
params["total_planned_epochs"] = total_planned_epochs
|
||||
self._active_training_params = params
|
||||
self._training_cancelled = False
|
||||
|
||||
if len(stage_plan) > 1:
|
||||
self._append_training_log("Two-stage fine-tuning schedule:")
|
||||
self._log_stage_plan(stage_plan)
|
||||
|
||||
self._append_training_log(
|
||||
f"Starting training run '{params['run_name']}' using {params['base_model']}"
|
||||
)
|
||||
|
||||
self.training_progress_bar.setVisible(True)
|
||||
self.training_progress_bar.setMaximum(params["epochs"])
|
||||
self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
|
||||
self.training_progress_bar.setValue(0)
|
||||
self._set_training_state(True)
|
||||
|
||||
@@ -1381,6 +1500,8 @@ class TrainingTab(QWidget):
|
||||
lr0=params["lr0"],
|
||||
save_dir=params["save_dir"],
|
||||
run_name=params["run_name"],
|
||||
stage_plan=stage_plan,
|
||||
total_epochs=total_planned_epochs,
|
||||
)
|
||||
self.training_worker.progress.connect(self._on_training_progress)
|
||||
self.training_worker.finished.connect(self._on_training_finished)
|
||||
@@ -1505,14 +1626,22 @@ class TrainingTab(QWidget):
|
||||
if not model_path:
|
||||
raise ValueError("Training results did not include a model path.")
|
||||
|
||||
effective_epochs = params.get("total_planned_epochs", params["epochs"])
|
||||
training_params = {
|
||||
"epochs": params["epochs"],
|
||||
"epochs": effective_epochs,
|
||||
"batch": params["batch"],
|
||||
"imgsz": params["imgsz"],
|
||||
"patience": params["patience"],
|
||||
"lr0": params["lr0"],
|
||||
"run_name": params["run_name"],
|
||||
"two_stage": params.get("two_stage"),
|
||||
}
|
||||
if params.get("stage_plan"):
|
||||
training_params["stage_plan"] = params["stage_plan"]
|
||||
if results.get("stage_results"):
|
||||
training_params["stage_results"] = results["stage_results"]
|
||||
if results.get("total_epochs_completed") is not None:
|
||||
training_params["epochs_completed"] = results["total_epochs_completed"]
|
||||
|
||||
model_id = self.db_manager.add_model(
|
||||
model_name=params["model_name"],
|
||||
|
||||
Reference in New Issue
Block a user