2Stage training fix
This commit is contained in:
@@ -18,7 +18,7 @@ models:
|
|||||||
training:
|
training:
|
||||||
default_epochs: 100
|
default_epochs: 100
|
||||||
default_batch_size: 16
|
default_batch_size: 16
|
||||||
default_imgsz: 640
|
default_imgsz: 1024
|
||||||
default_patience: 50
|
default_patience: 50
|
||||||
default_lr0: 0.01
|
default_lr0: 0.01
|
||||||
two_stage:
|
two_stage:
|
||||||
|
|||||||
@@ -68,6 +68,8 @@ class TrainingWorker(QThread):
|
|||||||
save_dir: str,
|
save_dir: str,
|
||||||
run_name: str,
|
run_name: str,
|
||||||
parent: Optional[QThread] = None,
|
parent: Optional[QThread] = None,
|
||||||
|
stage_plan: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
total_epochs: Optional[int] = None,
|
||||||
):
|
):
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
self.data_yaml = data_yaml
|
self.data_yaml = data_yaml
|
||||||
@@ -79,6 +81,27 @@ class TrainingWorker(QThread):
|
|||||||
self.lr0 = lr0
|
self.lr0 = lr0
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.run_name = run_name
|
self.run_name = run_name
|
||||||
|
self.stage_plan = stage_plan or [
|
||||||
|
{
|
||||||
|
"label": "Single Stage",
|
||||||
|
"model_path": base_model,
|
||||||
|
"use_previous_best": False,
|
||||||
|
"params": {
|
||||||
|
"epochs": epochs,
|
||||||
|
"batch": batch,
|
||||||
|
"imgsz": imgsz,
|
||||||
|
"patience": patience,
|
||||||
|
"lr0": lr0,
|
||||||
|
"freeze": 0,
|
||||||
|
"name": run_name,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
computed_total = sum(
|
||||||
|
max(0, int((stage.get("params") or {}).get("epochs", 0)))
|
||||||
|
for stage in self.stage_plan
|
||||||
|
)
|
||||||
|
self.total_epochs = total_epochs if total_epochs else computed_total or epochs
|
||||||
self._stop_requested = False
|
self._stop_requested = False
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
@@ -87,36 +110,98 @@ class TrainingWorker(QThread):
|
|||||||
self.requestInterruption()
|
self.requestInterruption()
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""Execute YOLO training and emit progress/finished signals."""
|
"""Execute YOLO training over one or more stages and emit progress/finished signals."""
|
||||||
wrapper = YOLOWrapper(self.base_model)
|
|
||||||
|
|
||||||
def on_epoch_end(trainer):
|
completed_epochs = 0
|
||||||
current_epoch = getattr(trainer, "epoch", 0) + 1
|
stage_history: List[Dict[str, Any]] = []
|
||||||
metrics: Dict[str, float] = {}
|
last_stage_results: Optional[Dict[str, Any]] = None
|
||||||
loss_items = getattr(trainer, "loss_items", None)
|
|
||||||
if loss_items:
|
|
||||||
metrics["loss"] = float(loss_items[-1])
|
|
||||||
self.progress.emit(current_epoch, self.epochs, metrics)
|
|
||||||
if self.isInterruptionRequested() or self._stop_requested:
|
|
||||||
setattr(trainer, "stop_training", True)
|
|
||||||
|
|
||||||
callbacks = {"on_fit_epoch_end": on_epoch_end}
|
for stage_index, stage in enumerate(self.stage_plan, start=1):
|
||||||
|
if self._stop_requested or self.isInterruptionRequested():
|
||||||
|
break
|
||||||
|
|
||||||
try:
|
stage_label = stage.get("label") or f"Stage {stage_index}"
|
||||||
results = wrapper.train(
|
stage_params = dict(stage.get("params") or {})
|
||||||
data_yaml=self.data_yaml,
|
stage_epochs = int(stage_params.get("epochs", self.epochs))
|
||||||
epochs=self.epochs,
|
if stage_epochs <= 0:
|
||||||
imgsz=self.imgsz,
|
stage_epochs = 1
|
||||||
batch=self.batch,
|
batch = int(stage_params.get("batch", self.batch))
|
||||||
patience=self.patience,
|
imgsz = int(stage_params.get("imgsz", self.imgsz))
|
||||||
save_dir=self.save_dir,
|
patience = int(stage_params.get("patience", self.patience))
|
||||||
name=self.run_name,
|
lr0 = float(stage_params.get("lr0", self.lr0))
|
||||||
lr0=self.lr0,
|
freeze = int(stage_params.get("freeze", 0))
|
||||||
callbacks=callbacks,
|
run_name = stage_params.get("name") or f"{self.run_name}_stage{stage_index}"
|
||||||
|
|
||||||
|
weights_path = stage.get("model_path") or self.base_model
|
||||||
|
if stage.get("use_previous_best") and last_stage_results:
|
||||||
|
weights_path = (
|
||||||
|
last_stage_results.get("best_model_path")
|
||||||
|
or last_stage_results.get("last_model_path")
|
||||||
|
or weights_path
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper = YOLOWrapper(weights_path)
|
||||||
|
stage_offset = completed_epochs
|
||||||
|
|
||||||
|
def on_epoch_end(trainer, offset=stage_offset):
|
||||||
|
current_epoch = getattr(trainer, "epoch", 0) + 1
|
||||||
|
metrics: Dict[str, float] = {}
|
||||||
|
loss_items = getattr(trainer, "loss_items", None)
|
||||||
|
if loss_items:
|
||||||
|
metrics["loss"] = float(loss_items[-1])
|
||||||
|
absolute_epoch = min(
|
||||||
|
max(1, offset + current_epoch),
|
||||||
|
max(1, self.total_epochs),
|
||||||
|
)
|
||||||
|
self.progress.emit(absolute_epoch, self.total_epochs, metrics)
|
||||||
|
if self.isInterruptionRequested() or self._stop_requested:
|
||||||
|
setattr(trainer, "stop_training", True)
|
||||||
|
|
||||||
|
callbacks = {"on_fit_epoch_end": on_epoch_end}
|
||||||
|
|
||||||
|
try:
|
||||||
|
stage_result = wrapper.train(
|
||||||
|
data_yaml=self.data_yaml,
|
||||||
|
epochs=stage_epochs,
|
||||||
|
imgsz=imgsz,
|
||||||
|
batch=batch,
|
||||||
|
patience=patience,
|
||||||
|
save_dir=self.save_dir,
|
||||||
|
name=run_name,
|
||||||
|
lr0=lr0,
|
||||||
|
callbacks=callbacks,
|
||||||
|
freeze=freeze,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
self.error.emit(str(exc))
|
||||||
|
return
|
||||||
|
|
||||||
|
stage_history.append(
|
||||||
|
{
|
||||||
|
"label": stage_label,
|
||||||
|
"params": stage_params,
|
||||||
|
"weights_used": weights_path,
|
||||||
|
"results": stage_result,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
self.finished.emit(results)
|
last_stage_results = stage_result
|
||||||
except Exception as exc:
|
completed_epochs += stage_epochs
|
||||||
self.error.emit(str(exc))
|
|
||||||
|
final_payload: Dict[str, Any]
|
||||||
|
if last_stage_results:
|
||||||
|
final_payload = dict(last_stage_results)
|
||||||
|
else:
|
||||||
|
final_payload = {
|
||||||
|
"success": False,
|
||||||
|
"message": "Training stopped before any stage completed.",
|
||||||
|
}
|
||||||
|
|
||||||
|
final_payload["stage_results"] = stage_history
|
||||||
|
final_payload["total_epochs_completed"] = completed_epochs
|
||||||
|
final_payload["total_epochs_planned"] = self.total_epochs
|
||||||
|
final_payload["stages_completed"] = len(stage_history)
|
||||||
|
|
||||||
|
self.finished.emit(final_payload)
|
||||||
|
|
||||||
|
|
||||||
class TrainingTab(QWidget):
|
class TrainingTab(QWidget):
|
||||||
@@ -1162,7 +1247,7 @@ class TrainingTab(QWidget):
|
|||||||
"imgsz": params["imgsz"],
|
"imgsz": params["imgsz"],
|
||||||
"patience": stage2.get("patience", params["patience"]),
|
"patience": stage2.get("patience", params["patience"]),
|
||||||
"lr0": stage2.get("lr0", params["lr0"]),
|
"lr0": stage2.get("lr0", params["lr0"]),
|
||||||
"freeze": 0,
|
"freeze": stage2.get("freeze", 0),
|
||||||
"name": f"{params['run_name']}_full_ft",
|
"name": f"{params['run_name']}_full_ft",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1170,6 +1255,30 @@ class TrainingTab(QWidget):
|
|||||||
|
|
||||||
return stage_plan
|
return stage_plan
|
||||||
|
|
||||||
|
def _calculate_total_stage_epochs(self, stage_plan: List[Dict[str, Any]]) -> int:
|
||||||
|
total = 0
|
||||||
|
for stage in stage_plan:
|
||||||
|
params = stage.get("params") or {}
|
||||||
|
try:
|
||||||
|
stage_epochs = int(params.get("epochs", 0))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
stage_epochs = 0
|
||||||
|
if stage_epochs > 0:
|
||||||
|
total += stage_epochs
|
||||||
|
return total
|
||||||
|
|
||||||
|
def _log_stage_plan(self, stage_plan: List[Dict[str, Any]]):
|
||||||
|
for index, stage in enumerate(stage_plan, start=1):
|
||||||
|
stage_label = stage.get("label") or f"Stage {index}"
|
||||||
|
params = stage.get("params") or {}
|
||||||
|
epochs = params.get("epochs", "?")
|
||||||
|
lr0 = params.get("lr0", "?")
|
||||||
|
patience = params.get("patience", "?")
|
||||||
|
freeze = params.get("freeze", 0)
|
||||||
|
self._append_training_log(
|
||||||
|
f" • {stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
|
||||||
|
)
|
||||||
|
|
||||||
def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
|
def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
|
||||||
cache_base = Path("data/datasets/_rgb_cache")
|
cache_base = Path("data/datasets/_rgb_cache")
|
||||||
cache_base.mkdir(parents=True, exist_ok=True)
|
cache_base.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -1359,15 +1468,25 @@ class TrainingTab(QWidget):
|
|||||||
)
|
)
|
||||||
|
|
||||||
params = self._collect_training_params()
|
params = self._collect_training_params()
|
||||||
|
stage_plan = self._compose_stage_plan(params)
|
||||||
|
params["stage_plan"] = stage_plan
|
||||||
|
total_planned_epochs = (
|
||||||
|
self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
|
||||||
|
)
|
||||||
|
params["total_planned_epochs"] = total_planned_epochs
|
||||||
self._active_training_params = params
|
self._active_training_params = params
|
||||||
self._training_cancelled = False
|
self._training_cancelled = False
|
||||||
|
|
||||||
|
if len(stage_plan) > 1:
|
||||||
|
self._append_training_log("Two-stage fine-tuning schedule:")
|
||||||
|
self._log_stage_plan(stage_plan)
|
||||||
|
|
||||||
self._append_training_log(
|
self._append_training_log(
|
||||||
f"Starting training run '{params['run_name']}' using {params['base_model']}"
|
f"Starting training run '{params['run_name']}' using {params['base_model']}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.training_progress_bar.setVisible(True)
|
self.training_progress_bar.setVisible(True)
|
||||||
self.training_progress_bar.setMaximum(params["epochs"])
|
self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
|
||||||
self.training_progress_bar.setValue(0)
|
self.training_progress_bar.setValue(0)
|
||||||
self._set_training_state(True)
|
self._set_training_state(True)
|
||||||
|
|
||||||
@@ -1381,6 +1500,8 @@ class TrainingTab(QWidget):
|
|||||||
lr0=params["lr0"],
|
lr0=params["lr0"],
|
||||||
save_dir=params["save_dir"],
|
save_dir=params["save_dir"],
|
||||||
run_name=params["run_name"],
|
run_name=params["run_name"],
|
||||||
|
stage_plan=stage_plan,
|
||||||
|
total_epochs=total_planned_epochs,
|
||||||
)
|
)
|
||||||
self.training_worker.progress.connect(self._on_training_progress)
|
self.training_worker.progress.connect(self._on_training_progress)
|
||||||
self.training_worker.finished.connect(self._on_training_finished)
|
self.training_worker.finished.connect(self._on_training_finished)
|
||||||
@@ -1505,14 +1626,22 @@ class TrainingTab(QWidget):
|
|||||||
if not model_path:
|
if not model_path:
|
||||||
raise ValueError("Training results did not include a model path.")
|
raise ValueError("Training results did not include a model path.")
|
||||||
|
|
||||||
|
effective_epochs = params.get("total_planned_epochs", params["epochs"])
|
||||||
training_params = {
|
training_params = {
|
||||||
"epochs": params["epochs"],
|
"epochs": effective_epochs,
|
||||||
"batch": params["batch"],
|
"batch": params["batch"],
|
||||||
"imgsz": params["imgsz"],
|
"imgsz": params["imgsz"],
|
||||||
"patience": params["patience"],
|
"patience": params["patience"],
|
||||||
"lr0": params["lr0"],
|
"lr0": params["lr0"],
|
||||||
"run_name": params["run_name"],
|
"run_name": params["run_name"],
|
||||||
|
"two_stage": params.get("two_stage"),
|
||||||
}
|
}
|
||||||
|
if params.get("stage_plan"):
|
||||||
|
training_params["stage_plan"] = params["stage_plan"]
|
||||||
|
if results.get("stage_results"):
|
||||||
|
training_params["stage_results"] = results["stage_results"]
|
||||||
|
if results.get("total_epochs_completed") is not None:
|
||||||
|
training_params["epochs_completed"] = results["total_epochs_completed"]
|
||||||
|
|
||||||
model_id = self.db_manager.add_model(
|
model_id = self.db_manager.add_model(
|
||||||
model_name=params["model_name"],
|
model_name=params["model_name"],
|
||||||
|
|||||||
Reference in New Issue
Block a user