correcting label writing and formatting code
This commit is contained in:
@@ -10,6 +10,7 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
import numpy as np
|
||||||
from PySide6.QtCore import Qt, QThread, Signal
|
from PySide6.QtCore import Qt, QThread, Signal
|
||||||
from PySide6.QtWidgets import (
|
from PySide6.QtWidgets import (
|
||||||
QWidget,
|
QWidget,
|
||||||
@@ -91,10 +92,7 @@ class TrainingWorker(QThread):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
computed_total = sum(
|
computed_total = sum(max(0, int((stage.get("params") or {}).get("epochs", 0))) for stage in self.stage_plan)
|
||||||
max(0, int((stage.get("params") or {}).get("epochs", 0)))
|
|
||||||
for stage in self.stage_plan
|
|
||||||
)
|
|
||||||
self.total_epochs = total_epochs if total_epochs else computed_total or epochs
|
self.total_epochs = total_epochs if total_epochs else computed_total or epochs
|
||||||
self._stop_requested = False
|
self._stop_requested = False
|
||||||
|
|
||||||
@@ -201,9 +199,7 @@ class TrainingWorker(QThread):
|
|||||||
class TrainingTab(QWidget):
|
class TrainingTab(QWidget):
|
||||||
"""Training tab for model training."""
|
"""Training tab for model training."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
|
||||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
|
||||||
):
|
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
self.db_manager = db_manager
|
self.db_manager = db_manager
|
||||||
self.config_manager = config_manager
|
self.config_manager = config_manager
|
||||||
@@ -337,18 +333,14 @@ class TrainingTab(QWidget):
|
|||||||
self.model_version_edit = QLineEdit("v1")
|
self.model_version_edit = QLineEdit("v1")
|
||||||
form_layout.addRow("Version:", self.model_version_edit)
|
form_layout.addRow("Version:", self.model_version_edit)
|
||||||
|
|
||||||
default_base_model = self.config_manager.get(
|
default_base_model = self.config_manager.get("models.default_base_model", "yolov8s-seg.pt")
|
||||||
"models.default_base_model", "yolov8s-seg.pt"
|
|
||||||
)
|
|
||||||
base_model_choices = self.config_manager.get("models.base_model_choices", [])
|
base_model_choices = self.config_manager.get("models.base_model_choices", [])
|
||||||
|
|
||||||
self.base_model_combo = QComboBox()
|
self.base_model_combo = QComboBox()
|
||||||
self.base_model_combo.addItem("Custom path…", "")
|
self.base_model_combo.addItem("Custom path…", "")
|
||||||
for choice in base_model_choices:
|
for choice in base_model_choices:
|
||||||
self.base_model_combo.addItem(choice, choice)
|
self.base_model_combo.addItem(choice, choice)
|
||||||
self.base_model_combo.currentIndexChanged.connect(
|
self.base_model_combo.currentIndexChanged.connect(self._on_base_model_preset_changed)
|
||||||
self._on_base_model_preset_changed
|
|
||||||
)
|
|
||||||
form_layout.addRow("Base Model Preset:", self.base_model_combo)
|
form_layout.addRow("Base Model Preset:", self.base_model_combo)
|
||||||
|
|
||||||
base_model_layout = QHBoxLayout()
|
base_model_layout = QHBoxLayout()
|
||||||
@@ -434,12 +426,8 @@ class TrainingTab(QWidget):
|
|||||||
group_layout = QVBoxLayout()
|
group_layout = QVBoxLayout()
|
||||||
|
|
||||||
self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
|
self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
|
||||||
two_stage_defaults = (
|
two_stage_defaults = training_defaults.get("two_stage", {}) if training_defaults else {}
|
||||||
training_defaults.get("two_stage", {}) if training_defaults else {}
|
self.two_stage_checkbox.setChecked(bool(two_stage_defaults.get("enabled", False)))
|
||||||
)
|
|
||||||
self.two_stage_checkbox.setChecked(
|
|
||||||
bool(two_stage_defaults.get("enabled", False))
|
|
||||||
)
|
|
||||||
self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
|
self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
|
||||||
group_layout.addWidget(self.two_stage_checkbox)
|
group_layout.addWidget(self.two_stage_checkbox)
|
||||||
|
|
||||||
@@ -501,9 +489,7 @@ class TrainingTab(QWidget):
|
|||||||
stage2_group.setLayout(stage2_form)
|
stage2_group.setLayout(stage2_form)
|
||||||
controls_layout.addWidget(stage2_group)
|
controls_layout.addWidget(stage2_group)
|
||||||
|
|
||||||
helper_label = QLabel(
|
helper_label = QLabel("When enabled, staged hyperparameters override the global epochs/patience/lr.")
|
||||||
"When enabled, staged hyperparameters override the global epochs/patience/lr."
|
|
||||||
)
|
|
||||||
helper_label.setWordWrap(True)
|
helper_label.setWordWrap(True)
|
||||||
controls_layout.addWidget(helper_label)
|
controls_layout.addWidget(helper_label)
|
||||||
|
|
||||||
@@ -548,9 +534,7 @@ class TrainingTab(QWidget):
|
|||||||
if normalized == preset_value:
|
if normalized == preset_value:
|
||||||
target_index = idx
|
target_index = idx
|
||||||
break
|
break
|
||||||
if normalized.endswith(f"/{preset_value}") or normalized.endswith(
|
if normalized.endswith(f"/{preset_value}") or normalized.endswith(f"\\{preset_value}"):
|
||||||
f"\\{preset_value}"
|
|
||||||
):
|
|
||||||
target_index = idx
|
target_index = idx
|
||||||
break
|
break
|
||||||
self.base_model_combo.blockSignals(True)
|
self.base_model_combo.blockSignals(True)
|
||||||
@@ -638,9 +622,7 @@ class TrainingTab(QWidget):
|
|||||||
|
|
||||||
def _browse_dataset(self):
|
def _browse_dataset(self):
|
||||||
"""Open a file dialog to manually select data.yaml."""
|
"""Open a file dialog to manually select data.yaml."""
|
||||||
start_dir = self.config_manager.get(
|
start_dir = self.config_manager.get("training.last_dataset_dir", "data/datasets")
|
||||||
"training.last_dataset_dir", "data/datasets"
|
|
||||||
)
|
|
||||||
start_path = Path(start_dir).expanduser()
|
start_path = Path(start_dir).expanduser()
|
||||||
if not start_path.exists():
|
if not start_path.exists():
|
||||||
start_path = Path.cwd()
|
start_path = Path.cwd()
|
||||||
@@ -676,9 +658,7 @@ class TrainingTab(QWidget):
|
|||||||
return
|
return
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception("Unexpected error while generating data.yaml")
|
logger.exception("Unexpected error while generating data.yaml")
|
||||||
self._display_dataset_error(
|
self._display_dataset_error("Unexpected error while generating data.yaml. Check logs for details.")
|
||||||
"Unexpected error while generating data.yaml. Check logs for details."
|
|
||||||
)
|
|
||||||
QMessageBox.critical(
|
QMessageBox.critical(
|
||||||
self,
|
self,
|
||||||
"data.yaml Generation Failed",
|
"data.yaml Generation Failed",
|
||||||
@@ -755,13 +735,9 @@ class TrainingTab(QWidget):
|
|||||||
self.selected_dataset = info
|
self.selected_dataset = info
|
||||||
|
|
||||||
self.dataset_root_label.setText(info["root"]) # type: ignore[arg-type]
|
self.dataset_root_label.setText(info["root"]) # type: ignore[arg-type]
|
||||||
self.train_count_label.setText(
|
self.train_count_label.setText(self._format_split_info(info["splits"].get("train")))
|
||||||
self._format_split_info(info["splits"].get("train"))
|
|
||||||
)
|
|
||||||
self.val_count_label.setText(self._format_split_info(info["splits"].get("val")))
|
self.val_count_label.setText(self._format_split_info(info["splits"].get("val")))
|
||||||
self.test_count_label.setText(
|
self.test_count_label.setText(self._format_split_info(info["splits"].get("test")))
|
||||||
self._format_split_info(info["splits"].get("test"))
|
|
||||||
)
|
|
||||||
self.num_classes_label.setText(str(info["num_classes"]))
|
self.num_classes_label.setText(str(info["num_classes"]))
|
||||||
class_names = ", ".join(info["class_names"]) or "–"
|
class_names = ", ".join(info["class_names"]) or "–"
|
||||||
self.class_names_label.setText(class_names)
|
self.class_names_label.setText(class_names)
|
||||||
@@ -815,18 +791,12 @@ class TrainingTab(QWidget):
|
|||||||
if split_path.exists():
|
if split_path.exists():
|
||||||
split_info["count"] = self._count_images(split_path)
|
split_info["count"] = self._count_images(split_path)
|
||||||
if split_info["count"] == 0:
|
if split_info["count"] == 0:
|
||||||
warnings.append(
|
warnings.append(f"No images found for {split_name} split at {split_path}")
|
||||||
f"No images found for {split_name} split at {split_path}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
warnings.append(
|
warnings.append(f"{split_name.capitalize()} path does not exist: {split_path}")
|
||||||
f"{split_name.capitalize()} path does not exist: {split_path}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if split_name in ("train", "val"):
|
if split_name in ("train", "val"):
|
||||||
warnings.append(
|
warnings.append(f"{split_name.capitalize()} split missing in data.yaml")
|
||||||
f"{split_name.capitalize()} split missing in data.yaml"
|
|
||||||
)
|
|
||||||
splits[split_name] = split_info
|
splits[split_name] = split_info
|
||||||
|
|
||||||
names_list = self._normalize_class_names(data.get("names"))
|
names_list = self._normalize_class_names(data.get("names"))
|
||||||
@@ -844,9 +814,7 @@ class TrainingTab(QWidget):
|
|||||||
if not names_list and nc_value:
|
if not names_list and nc_value:
|
||||||
names_list = [f"class_{idx}" for idx in range(int(nc_value))]
|
names_list = [f"class_{idx}" for idx in range(int(nc_value))]
|
||||||
elif nc_value and len(names_list) not in (0, int(nc_value)):
|
elif nc_value and len(names_list) not in (0, int(nc_value)):
|
||||||
warnings.append(
|
warnings.append(f"Number of class names ({len(names_list)}) does not match nc={nc_value}")
|
||||||
f"Number of class names ({len(names_list)}) does not match nc={nc_value}"
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset_name = data.get("name") or base_path.name
|
dataset_name = data.get("name") or base_path.name
|
||||||
|
|
||||||
@@ -898,16 +866,12 @@ class TrainingTab(QWidget):
|
|||||||
|
|
||||||
class_index_map = self._build_class_index_map(dataset_info)
|
class_index_map = self._build_class_index_map(dataset_info)
|
||||||
if not class_index_map:
|
if not class_index_map:
|
||||||
self._append_training_log(
|
self._append_training_log("Skipping label export: dataset classes do not match database entries.")
|
||||||
"Skipping label export: dataset classes do not match database entries."
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
dataset_root_str = dataset_info.get("root")
|
dataset_root_str = dataset_info.get("root")
|
||||||
dataset_yaml_path = dataset_info.get("yaml_path")
|
dataset_yaml_path = dataset_info.get("yaml_path")
|
||||||
dataset_yaml = (
|
dataset_yaml = Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
|
||||||
Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
|
|
||||||
)
|
|
||||||
dataset_root: Optional[Path]
|
dataset_root: Optional[Path]
|
||||||
if dataset_root_str:
|
if dataset_root_str:
|
||||||
dataset_root = Path(dataset_root_str).resolve()
|
dataset_root = Path(dataset_root_str).resolve()
|
||||||
@@ -941,7 +905,9 @@ class TrainingTab(QWidget):
|
|||||||
if stats["registered_images"]:
|
if stats["registered_images"]:
|
||||||
message += f" {stats['registered_images']} image(s) had database-backed annotations."
|
message += f" {stats['registered_images']} image(s) had database-backed annotations."
|
||||||
if stats["missing_records"]:
|
if stats["missing_records"]:
|
||||||
message += f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
|
message += (
|
||||||
|
f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
|
||||||
|
)
|
||||||
split_messages.append(message)
|
split_messages.append(message)
|
||||||
|
|
||||||
for msg in split_messages:
|
for msg in split_messages:
|
||||||
@@ -973,9 +939,7 @@ class TrainingTab(QWidget):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
processed_images += 1
|
processed_images += 1
|
||||||
label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(
|
label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(".txt")
|
||||||
".txt"
|
|
||||||
)
|
|
||||||
label_path.parent.mkdir(parents=True, exist_ok=True)
|
label_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
found, annotation_entries = self._fetch_annotations_for_image(
|
found, annotation_entries = self._fetch_annotations_for_image(
|
||||||
@@ -991,25 +955,23 @@ class TrainingTab(QWidget):
|
|||||||
for entry in annotation_entries:
|
for entry in annotation_entries:
|
||||||
polygon = entry.get("polygon") or []
|
polygon = entry.get("polygon") or []
|
||||||
if polygon:
|
if polygon:
|
||||||
|
print(image_file, polygon[:4], polygon[-2:], entry.get("bbox"))
|
||||||
|
# coords = " ".join(f"{value:.6f}" for value in entry.get("bbox"))
|
||||||
|
# coords += " "
|
||||||
coords = " ".join(f"{value:.6f}" for value in polygon)
|
coords = " ".join(f"{value:.6f}" for value in polygon)
|
||||||
handle.write(f"{entry['class_idx']} {coords}\n")
|
handle.write(f"{entry['class_idx']} {coords}\n")
|
||||||
annotations_written += 1
|
annotations_written += 1
|
||||||
elif entry.get("bbox"):
|
elif entry.get("bbox"):
|
||||||
x_center, y_center, width, height = entry["bbox"]
|
x_center, y_center, width, height = entry["bbox"]
|
||||||
handle.write(
|
handle.write(f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")
|
||||||
f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
|
|
||||||
)
|
|
||||||
annotations_written += 1
|
annotations_written += 1
|
||||||
|
|
||||||
total_annotations += annotations_written
|
total_annotations += annotations_written
|
||||||
|
|
||||||
cache_reset_root = labels_dir.parent
|
cache_reset_root = labels_dir.parent
|
||||||
self._invalidate_split_cache(cache_reset_root)
|
self._invalidate_split_cache(cache_reset_root)
|
||||||
|
|
||||||
if processed_images == 0:
|
if processed_images == 0:
|
||||||
self._append_training_log(
|
self._append_training_log(f"[{split_name}] No images found to export labels for.")
|
||||||
f"[{split_name}] No images found to export labels for."
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -1135,6 +1097,10 @@ class TrainingTab(QWidget):
|
|||||||
xs.append(x_val)
|
xs.append(x_val)
|
||||||
ys.append(y_val)
|
ys.append(y_val)
|
||||||
|
|
||||||
|
if any(np.abs(np.array(coords[:2]) - np.array(coords[-2:])) < 1e-5):
|
||||||
|
print("Closing polygon")
|
||||||
|
coords.extend(coords[:2])
|
||||||
|
|
||||||
if len(coords) < 6:
|
if len(coords) < 6:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -1147,6 +1113,11 @@ class TrainingTab(QWidget):
|
|||||||
+ abs((min(ys) if ys else 0.0) - y_min)
|
+ abs((min(ys) if ys else 0.0) - y_min)
|
||||||
+ abs((max(ys) if ys else 0.0) - y_max)
|
+ abs((max(ys) if ys else 0.0) - y_max)
|
||||||
)
|
)
|
||||||
|
width = max(0.0, x_max - x_min)
|
||||||
|
height = max(0.0, y_max - y_min)
|
||||||
|
x_center = x_min + width / 2.0
|
||||||
|
y_center = y_min + height / 2.0
|
||||||
|
score = (x_center, y_center, width, height)
|
||||||
|
|
||||||
candidates.append((score, coords))
|
candidates.append((score, coords))
|
||||||
|
|
||||||
@@ -1164,13 +1135,10 @@ class TrainingTab(QWidget):
|
|||||||
return 1.0
|
return 1.0
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def _prepare_dataset_for_training(
|
def _prepare_dataset_for_training(self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None) -> Path:
|
||||||
self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
|
|
||||||
) -> Path:
|
|
||||||
dataset_info = dataset_info or (
|
dataset_info = dataset_info or (
|
||||||
self.selected_dataset
|
self.selected_dataset
|
||||||
if self.selected_dataset
|
if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
|
||||||
and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
|
|
||||||
else self._parse_dataset_yaml(dataset_yaml)
|
else self._parse_dataset_yaml(dataset_yaml)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1189,14 +1157,10 @@ class TrainingTab(QWidget):
|
|||||||
cache_root = self._get_rgb_cache_root(dataset_yaml)
|
cache_root = self._get_rgb_cache_root(dataset_yaml)
|
||||||
rgb_yaml = cache_root / "data.yaml"
|
rgb_yaml = cache_root / "data.yaml"
|
||||||
if rgb_yaml.exists():
|
if rgb_yaml.exists():
|
||||||
self._append_training_log(
|
self._append_training_log(f"Detected grayscale dataset; reusing RGB cache at {cache_root}")
|
||||||
f"Detected grayscale dataset; reusing RGB cache at {cache_root}"
|
|
||||||
)
|
|
||||||
return rgb_yaml
|
return rgb_yaml
|
||||||
|
|
||||||
self._append_training_log(
|
self._append_training_log(f"Detected grayscale dataset; creating RGB cache at {cache_root}")
|
||||||
f"Detected grayscale dataset; creating RGB cache at {cache_root}"
|
|
||||||
)
|
|
||||||
self._build_rgb_dataset(cache_root, dataset_info)
|
self._build_rgb_dataset(cache_root, dataset_info)
|
||||||
return rgb_yaml
|
return rgb_yaml
|
||||||
|
|
||||||
@@ -1463,15 +1427,12 @@ class TrainingTab(QWidget):
|
|||||||
|
|
||||||
dataset_path = Path(dataset_yaml).expanduser()
|
dataset_path = Path(dataset_yaml).expanduser()
|
||||||
if not dataset_path.exists():
|
if not dataset_path.exists():
|
||||||
QMessageBox.warning(
|
QMessageBox.warning(self, "Invalid Dataset", "Selected data.yaml file does not exist.")
|
||||||
self, "Invalid Dataset", "Selected data.yaml file does not exist."
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
dataset_info = (
|
dataset_info = (
|
||||||
self.selected_dataset
|
self.selected_dataset
|
||||||
if self.selected_dataset
|
if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_path)
|
||||||
and self.selected_dataset.get("yaml_path") == str(dataset_path)
|
|
||||||
else self._parse_dataset_yaml(dataset_path)
|
else self._parse_dataset_yaml(dataset_path)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1480,16 +1441,12 @@ class TrainingTab(QWidget):
|
|||||||
|
|
||||||
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
|
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
|
||||||
if dataset_to_use != dataset_path:
|
if dataset_to_use != dataset_path:
|
||||||
self._append_training_log(
|
self._append_training_log(f"Using RGB-converted dataset at {dataset_to_use.parent}")
|
||||||
f"Using RGB-converted dataset at {dataset_to_use.parent}"
|
|
||||||
)
|
|
||||||
|
|
||||||
params = self._collect_training_params()
|
params = self._collect_training_params()
|
||||||
stage_plan = self._compose_stage_plan(params)
|
stage_plan = self._compose_stage_plan(params)
|
||||||
params["stage_plan"] = stage_plan
|
params["stage_plan"] = stage_plan
|
||||||
total_planned_epochs = (
|
total_planned_epochs = self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
|
||||||
self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
|
|
||||||
)
|
|
||||||
params["total_planned_epochs"] = total_planned_epochs
|
params["total_planned_epochs"] = total_planned_epochs
|
||||||
self._active_training_params = params
|
self._active_training_params = params
|
||||||
self._training_cancelled = False
|
self._training_cancelled = False
|
||||||
@@ -1498,9 +1455,7 @@ class TrainingTab(QWidget):
|
|||||||
self._append_training_log("Two-stage fine-tuning schedule:")
|
self._append_training_log("Two-stage fine-tuning schedule:")
|
||||||
self._log_stage_plan(stage_plan)
|
self._log_stage_plan(stage_plan)
|
||||||
|
|
||||||
self._append_training_log(
|
self._append_training_log(f"Starting training run '{params['run_name']}' using {params['base_model']}")
|
||||||
f"Starting training run '{params['run_name']}' using {params['base_model']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.training_progress_bar.setVisible(True)
|
self.training_progress_bar.setVisible(True)
|
||||||
self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
|
self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
|
||||||
@@ -1528,9 +1483,7 @@ class TrainingTab(QWidget):
|
|||||||
def _stop_training(self):
|
def _stop_training(self):
|
||||||
if self.training_worker and self.training_worker.isRunning():
|
if self.training_worker and self.training_worker.isRunning():
|
||||||
self._training_cancelled = True
|
self._training_cancelled = True
|
||||||
self._append_training_log(
|
self._append_training_log("Stop requested. Waiting for the current epoch to finish...")
|
||||||
"Stop requested. Waiting for the current epoch to finish..."
|
|
||||||
)
|
|
||||||
self.training_worker.stop()
|
self.training_worker.stop()
|
||||||
self.stop_training_button.setEnabled(False)
|
self.stop_training_button.setEnabled(False)
|
||||||
|
|
||||||
@@ -1566,9 +1519,7 @@ class TrainingTab(QWidget):
|
|||||||
|
|
||||||
if worker.isRunning():
|
if worker.isRunning():
|
||||||
if not worker.wait(wait_timeout_ms):
|
if not worker.wait(wait_timeout_ms):
|
||||||
logger.warning(
|
logger.warning("Training worker did not finish within %sms", wait_timeout_ms)
|
||||||
"Training worker did not finish within %sms", wait_timeout_ms
|
|
||||||
)
|
|
||||||
|
|
||||||
worker.deleteLater()
|
worker.deleteLater()
|
||||||
|
|
||||||
@@ -1585,16 +1536,12 @@ class TrainingTab(QWidget):
|
|||||||
self._set_training_state(False)
|
self._set_training_state(False)
|
||||||
self.training_progress_bar.setVisible(False)
|
self.training_progress_bar.setVisible(False)
|
||||||
|
|
||||||
def _on_training_progress(
|
def _on_training_progress(self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]):
|
||||||
self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]
|
|
||||||
):
|
|
||||||
self.training_progress_bar.setMaximum(total_epochs)
|
self.training_progress_bar.setMaximum(total_epochs)
|
||||||
self.training_progress_bar.setValue(current_epoch)
|
self.training_progress_bar.setValue(current_epoch)
|
||||||
parts = [f"Epoch {current_epoch}/{total_epochs}"]
|
parts = [f"Epoch {current_epoch}/{total_epochs}"]
|
||||||
if metrics:
|
if metrics:
|
||||||
metric_text = ", ".join(
|
metric_text = ", ".join(f"{key}: {value:.4f}" for key, value in metrics.items())
|
||||||
f"{key}: {value:.4f}" for key, value in metrics.items()
|
|
||||||
)
|
|
||||||
parts.append(metric_text)
|
parts.append(metric_text)
|
||||||
self._append_training_log(" | ".join(parts))
|
self._append_training_log(" | ".join(parts))
|
||||||
|
|
||||||
@@ -1621,9 +1568,7 @@ class TrainingTab(QWidget):
|
|||||||
f"Model trained but not registered: {exc}",
|
f"Model trained but not registered: {exc}",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
QMessageBox.information(
|
QMessageBox.information(self, "Training Complete", "Training finished successfully.")
|
||||||
self, "Training Complete", "Training finished successfully."
|
|
||||||
)
|
|
||||||
|
|
||||||
def _on_training_error(self, message: str):
|
def _on_training_error(self, message: str):
|
||||||
self._cleanup_training_worker()
|
self._cleanup_training_worker()
|
||||||
@@ -1669,9 +1614,7 @@ class TrainingTab(QWidget):
|
|||||||
metrics=results.get("metrics"),
|
metrics=results.get("metrics"),
|
||||||
)
|
)
|
||||||
|
|
||||||
self._append_training_log(
|
self._append_training_log(f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}")
|
||||||
f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}"
|
|
||||||
)
|
|
||||||
self._active_training_params = None
|
self._active_training_params = None
|
||||||
|
|
||||||
def _set_training_state(self, is_training: bool):
|
def _set_training_state(self, is_training: bool):
|
||||||
@@ -1714,9 +1657,7 @@ class TrainingTab(QWidget):
|
|||||||
|
|
||||||
def _browse_save_dir(self):
|
def _browse_save_dir(self):
|
||||||
start_path = self.save_dir_edit.text().strip() or "data/models"
|
start_path = self.save_dir_edit.text().strip() or "data/models"
|
||||||
directory = QFileDialog.getExistingDirectory(
|
directory = QFileDialog.getExistingDirectory(self, "Select Save Directory", start_path)
|
||||||
self, "Select Save Directory", start_path
|
|
||||||
)
|
|
||||||
if directory:
|
if directory:
|
||||||
self.save_dir_edit.setText(directory)
|
self.save_dir_edit.setText(directory)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user