From 6b6d6fad03748c2f25b925717c00cb96b2845709 Mon Sep 17 00:00:00 2001
From: Martin Laasmaa
Date: Thu, 11 Dec 2025 12:50:34 +0200
Subject: [PATCH] 2Stage training fix

---
Review notes (not part of the commit message):
* This copy was rebuilt from a version whose line breaks and indentation had
  been collapsed. Line content is unchanged apart from the fix described
  below; leading whitespace and the position of blank context lines were
  inferred from Python nesting and the hunk line counts. Check the context
  lines against the target tree before applying.
* TrainingWorker.run now adds the epochs the trainer actually reported to
  total_epochs_completed instead of each stage's planned count, so a stage
  ended early by patience or a stop request is no longer over-reported.
  Hunk offsets and the diffstat are adjusted for the 6 extra lines; the
  post-image blob id on the index line predates this change.

 config/app_config.yaml       |   2 +-
 src/gui/tabs/training_tab.py | 193 +++++++++++++++++++++++++++++------
 2 files changed, 165 insertions(+), 30 deletions(-)

diff --git a/config/app_config.yaml b/config/app_config.yaml
index bc3d8bd..4fa6003 100644
--- a/config/app_config.yaml
+++ b/config/app_config.yaml
@@ -18,7 +18,7 @@ models:
 training:
   default_epochs: 100
   default_batch_size: 16
-  default_imgsz: 640
+  default_imgsz: 1024
   default_patience: 50
   default_lr0: 0.01
   two_stage:
diff --git a/src/gui/tabs/training_tab.py b/src/gui/tabs/training_tab.py
index 7c3bbf3..3c2ca9e 100644
--- a/src/gui/tabs/training_tab.py
+++ b/src/gui/tabs/training_tab.py
@@ -68,6 +68,8 @@ class TrainingWorker(QThread):
         save_dir: str,
         run_name: str,
         parent: Optional[QThread] = None,
+        stage_plan: Optional[List[Dict[str, Any]]] = None,
+        total_epochs: Optional[int] = None,
     ):
         super().__init__(parent)
         self.data_yaml = data_yaml
@@ -79,6 +81,27 @@ class TrainingWorker(QThread):
         self.lr0 = lr0
         self.save_dir = save_dir
         self.run_name = run_name
+        self.stage_plan = stage_plan or [
+            {
+                "label": "Single Stage",
+                "model_path": base_model,
+                "use_previous_best": False,
+                "params": {
+                    "epochs": epochs,
+                    "batch": batch,
+                    "imgsz": imgsz,
+                    "patience": patience,
+                    "lr0": lr0,
+                    "freeze": 0,
+                    "name": run_name,
+                },
+            }
+        ]
+        computed_total = sum(
+            max(0, int((stage.get("params") or {}).get("epochs", 0)))
+            for stage in self.stage_plan
+        )
+        self.total_epochs = total_epochs if total_epochs else computed_total or epochs
         self._stop_requested = False

     def stop(self):
@@ -87,36 +110,104 @@ class TrainingWorker(QThread):
         self.requestInterruption()

     def run(self):
-        """Execute YOLO training and emit progress/finished signals."""
-        wrapper = YOLOWrapper(self.base_model)
+        """Execute YOLO training over one or more stages and emit progress/finished signals."""

-        def on_epoch_end(trainer):
-            current_epoch = getattr(trainer, "epoch", 0) + 1
-            metrics: Dict[str, float] = {}
-            loss_items = getattr(trainer, "loss_items", None)
-            if loss_items:
-                metrics["loss"] = float(loss_items[-1])
-            self.progress.emit(current_epoch, self.epochs, metrics)
-            if self.isInterruptionRequested() or self._stop_requested:
-                setattr(trainer, "stop_training", True)
+        completed_epochs = 0
+        stage_history: List[Dict[str, Any]] = []
+        last_stage_results: Optional[Dict[str, Any]] = None

-        callbacks = {"on_fit_epoch_end": on_epoch_end}
+        for stage_index, stage in enumerate(self.stage_plan, start=1):
+            if self._stop_requested or self.isInterruptionRequested():
+                break

-        try:
-            results = wrapper.train(
-                data_yaml=self.data_yaml,
-                epochs=self.epochs,
-                imgsz=self.imgsz,
-                batch=self.batch,
-                patience=self.patience,
-                save_dir=self.save_dir,
-                name=self.run_name,
-                lr0=self.lr0,
-                callbacks=callbacks,
+            stage_label = stage.get("label") or f"Stage {stage_index}"
+            stage_params = dict(stage.get("params") or {})
+            stage_epochs = int(stage_params.get("epochs", self.epochs))
+            if stage_epochs <= 0:
+                stage_epochs = 1
+            batch = int(stage_params.get("batch", self.batch))
+            imgsz = int(stage_params.get("imgsz", self.imgsz))
+            patience = int(stage_params.get("patience", self.patience))
+            lr0 = float(stage_params.get("lr0", self.lr0))
+            freeze = int(stage_params.get("freeze", 0))
+            run_name = stage_params.get("name") or f"{self.run_name}_stage{stage_index}"
+
+            weights_path = stage.get("model_path") or self.base_model
+            if stage.get("use_previous_best") and last_stage_results:
+                weights_path = (
+                    last_stage_results.get("best_model_path")
+                    or last_stage_results.get("last_model_path")
+                    or weights_path
+                )
+
+            wrapper = YOLOWrapper(weights_path)
+            stage_offset = completed_epochs
+            # Last epoch the trainer reported for this stage; stays 0 when the
+            # callback never fires.
+            epochs_seen = [0]
+
+            def on_epoch_end(trainer, offset=stage_offset, seen=epochs_seen):
+                current_epoch = getattr(trainer, "epoch", 0) + 1
+                seen[0] = current_epoch
+                metrics: Dict[str, float] = {}
+                loss_items = getattr(trainer, "loss_items", None)
+                if loss_items:
+                    metrics["loss"] = float(loss_items[-1])
+                absolute_epoch = min(
+                    max(1, offset + current_epoch),
+                    max(1, self.total_epochs),
+                )
+                self.progress.emit(absolute_epoch, self.total_epochs, metrics)
+                if self.isInterruptionRequested() or self._stop_requested:
+                    setattr(trainer, "stop_training", True)
+
+            callbacks = {"on_fit_epoch_end": on_epoch_end}
+
+            try:
+                stage_result = wrapper.train(
+                    data_yaml=self.data_yaml,
+                    epochs=stage_epochs,
+                    imgsz=imgsz,
+                    batch=batch,
+                    patience=patience,
+                    save_dir=self.save_dir,
+                    name=run_name,
+                    lr0=lr0,
+                    callbacks=callbacks,
+                    freeze=freeze,
+                )
+            except Exception as exc:
+                self.error.emit(str(exc))
+                return
+
+            stage_history.append(
+                {
+                    "label": stage_label,
+                    "params": stage_params,
+                    "weights_used": weights_path,
+                    "results": stage_result,
+                }
             )
-            self.finished.emit(results)
-        except Exception as exc:
-            self.error.emit(str(exc))
+            last_stage_results = stage_result
+            # Early stopping or a stop request can end a stage short of its
+            # plan, so count the epochs actually reported, not the planned ones.
+            completed_epochs += min(stage_epochs, epochs_seen[0]) or stage_epochs
+
+        final_payload: Dict[str, Any]
+        if last_stage_results:
+            final_payload = dict(last_stage_results)
+        else:
+            final_payload = {
+                "success": False,
+                "message": "Training stopped before any stage completed.",
+            }
+
+        final_payload["stage_results"] = stage_history
+        final_payload["total_epochs_completed"] = completed_epochs
+        final_payload["total_epochs_planned"] = self.total_epochs
+        final_payload["stages_completed"] = len(stage_history)
+
+        self.finished.emit(final_payload)


 class TrainingTab(QWidget):
@@ -1162,7 +1253,7 @@ class TrainingTab(QWidget):
                     "imgsz": params["imgsz"],
                     "patience": stage2.get("patience", params["patience"]),
                     "lr0": stage2.get("lr0", params["lr0"]),
-                    "freeze": 0,
+                    "freeze": stage2.get("freeze", 0),
                     "name": f"{params['run_name']}_full_ft",
                 },
             }
@@ -1170,6 +1261,30 @@ class TrainingTab(QWidget):

         return stage_plan

+    def _calculate_total_stage_epochs(self, stage_plan: List[Dict[str, Any]]) -> int:
+        total = 0
+        for stage in stage_plan:
+            params = stage.get("params") or {}
+            try:
+                stage_epochs = int(params.get("epochs", 0))
+            except (TypeError, ValueError):
+                stage_epochs = 0
+            if stage_epochs > 0:
+                total += stage_epochs
+        return total
+
+    def _log_stage_plan(self, stage_plan: List[Dict[str, Any]]):
+        for index, stage in enumerate(stage_plan, start=1):
+            stage_label = stage.get("label") or f"Stage {index}"
+            params = stage.get("params") or {}
+            epochs = params.get("epochs", "?")
+            lr0 = params.get("lr0", "?")
+            patience = params.get("patience", "?")
+            freeze = params.get("freeze", 0)
+            self._append_training_log(
+                f" • {stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
+            )
+
     def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
         cache_base = Path("data/datasets/_rgb_cache")
         cache_base.mkdir(parents=True, exist_ok=True)
@@ -1359,15 +1474,25 @@ class TrainingTab(QWidget):
         )

         params = self._collect_training_params()
+        stage_plan = self._compose_stage_plan(params)
+        params["stage_plan"] = stage_plan
+        total_planned_epochs = (
+            self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
+        )
+        params["total_planned_epochs"] = total_planned_epochs
         self._active_training_params = params
         self._training_cancelled = False

+        if len(stage_plan) > 1:
+            self._append_training_log("Two-stage fine-tuning schedule:")
+            self._log_stage_plan(stage_plan)
+
         self._append_training_log(
             f"Starting training run '{params['run_name']}' using {params['base_model']}"
         )

         self.training_progress_bar.setVisible(True)
-        self.training_progress_bar.setMaximum(params["epochs"])
+        self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
         self.training_progress_bar.setValue(0)
         self._set_training_state(True)

@@ -1381,6 +1506,8 @@ class TrainingTab(QWidget):
             lr0=params["lr0"],
             save_dir=params["save_dir"],
             run_name=params["run_name"],
+            stage_plan=stage_plan,
+            total_epochs=total_planned_epochs,
         )
         self.training_worker.progress.connect(self._on_training_progress)
         self.training_worker.finished.connect(self._on_training_finished)
@@ -1505,14 +1632,22 @@ class TrainingTab(QWidget):
         if not model_path:
             raise ValueError("Training results did not include a model path.")

+        effective_epochs = params.get("total_planned_epochs", params["epochs"])
         training_params = {
-            "epochs": params["epochs"],
+            "epochs": effective_epochs,
             "batch": params["batch"],
             "imgsz": params["imgsz"],
             "patience": params["patience"],
             "lr0": params["lr0"],
             "run_name": params["run_name"],
+            "two_stage": params.get("two_stage"),
         }
+        if params.get("stage_plan"):
+            training_params["stage_plan"] = params["stage_plan"]
+        if results.get("stage_results"):
+            training_params["stage_results"] = results["stage_results"]
+        if results.get("total_epochs_completed") is not None:
+            training_params["epochs_completed"] = results["total_epochs_completed"]

         model_id = self.db_manager.add_model(
             model_name=params["model_name"],