Correct label writing and reformat code

This commit is contained in:
2026-01-16 10:24:19 +02:00
parent ca52312925
commit fcbd5fb16d

View File

@@ -10,6 +10,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import yaml
import numpy as np
from PySide6.QtCore import Qt, QThread, Signal
from PySide6.QtWidgets import (
QWidget,
@@ -91,10 +92,7 @@ class TrainingWorker(QThread):
},
}
]
computed_total = sum(
max(0, int((stage.get("params") or {}).get("epochs", 0)))
for stage in self.stage_plan
)
computed_total = sum(max(0, int((stage.get("params") or {}).get("epochs", 0))) for stage in self.stage_plan)
self.total_epochs = total_epochs if total_epochs else computed_total or epochs
self._stop_requested = False
@@ -201,9 +199,7 @@ class TrainingWorker(QThread):
class TrainingTab(QWidget):
"""Training tab for model training."""
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
super().__init__(parent)
self.db_manager = db_manager
self.config_manager = config_manager
@@ -337,18 +333,14 @@ class TrainingTab(QWidget):
self.model_version_edit = QLineEdit("v1")
form_layout.addRow("Version:", self.model_version_edit)
default_base_model = self.config_manager.get(
"models.default_base_model", "yolov8s-seg.pt"
)
default_base_model = self.config_manager.get("models.default_base_model", "yolov8s-seg.pt")
base_model_choices = self.config_manager.get("models.base_model_choices", [])
self.base_model_combo = QComboBox()
self.base_model_combo.addItem("Custom path…", "")
for choice in base_model_choices:
self.base_model_combo.addItem(choice, choice)
self.base_model_combo.currentIndexChanged.connect(
self._on_base_model_preset_changed
)
self.base_model_combo.currentIndexChanged.connect(self._on_base_model_preset_changed)
form_layout.addRow("Base Model Preset:", self.base_model_combo)
base_model_layout = QHBoxLayout()
@@ -434,12 +426,8 @@ class TrainingTab(QWidget):
group_layout = QVBoxLayout()
self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
two_stage_defaults = (
training_defaults.get("two_stage", {}) if training_defaults else {}
)
self.two_stage_checkbox.setChecked(
bool(two_stage_defaults.get("enabled", False))
)
two_stage_defaults = training_defaults.get("two_stage", {}) if training_defaults else {}
self.two_stage_checkbox.setChecked(bool(two_stage_defaults.get("enabled", False)))
self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
group_layout.addWidget(self.two_stage_checkbox)
@@ -501,9 +489,7 @@ class TrainingTab(QWidget):
stage2_group.setLayout(stage2_form)
controls_layout.addWidget(stage2_group)
helper_label = QLabel(
"When enabled, staged hyperparameters override the global epochs/patience/lr."
)
helper_label = QLabel("When enabled, staged hyperparameters override the global epochs/patience/lr.")
helper_label.setWordWrap(True)
controls_layout.addWidget(helper_label)
@@ -548,9 +534,7 @@ class TrainingTab(QWidget):
if normalized == preset_value:
target_index = idx
break
if normalized.endswith(f"/{preset_value}") or normalized.endswith(
f"\\{preset_value}"
):
if normalized.endswith(f"/{preset_value}") or normalized.endswith(f"\\{preset_value}"):
target_index = idx
break
self.base_model_combo.blockSignals(True)
@@ -638,9 +622,7 @@ class TrainingTab(QWidget):
def _browse_dataset(self):
"""Open a file dialog to manually select data.yaml."""
start_dir = self.config_manager.get(
"training.last_dataset_dir", "data/datasets"
)
start_dir = self.config_manager.get("training.last_dataset_dir", "data/datasets")
start_path = Path(start_dir).expanduser()
if not start_path.exists():
start_path = Path.cwd()
@@ -676,9 +658,7 @@ class TrainingTab(QWidget):
return
except Exception as exc:
logger.exception("Unexpected error while generating data.yaml")
self._display_dataset_error(
"Unexpected error while generating data.yaml. Check logs for details."
)
self._display_dataset_error("Unexpected error while generating data.yaml. Check logs for details.")
QMessageBox.critical(
self,
"data.yaml Generation Failed",
@@ -755,13 +735,9 @@ class TrainingTab(QWidget):
self.selected_dataset = info
self.dataset_root_label.setText(info["root"]) # type: ignore[arg-type]
self.train_count_label.setText(
self._format_split_info(info["splits"].get("train"))
)
self.train_count_label.setText(self._format_split_info(info["splits"].get("train")))
self.val_count_label.setText(self._format_split_info(info["splits"].get("val")))
self.test_count_label.setText(
self._format_split_info(info["splits"].get("test"))
)
self.test_count_label.setText(self._format_split_info(info["splits"].get("test")))
self.num_classes_label.setText(str(info["num_classes"]))
class_names = ", ".join(info["class_names"]) or ""
self.class_names_label.setText(class_names)
@@ -815,18 +791,12 @@ class TrainingTab(QWidget):
if split_path.exists():
split_info["count"] = self._count_images(split_path)
if split_info["count"] == 0:
warnings.append(
f"No images found for {split_name} split at {split_path}"
)
warnings.append(f"No images found for {split_name} split at {split_path}")
else:
warnings.append(
f"{split_name.capitalize()} path does not exist: {split_path}"
)
warnings.append(f"{split_name.capitalize()} path does not exist: {split_path}")
else:
if split_name in ("train", "val"):
warnings.append(
f"{split_name.capitalize()} split missing in data.yaml"
)
warnings.append(f"{split_name.capitalize()} split missing in data.yaml")
splits[split_name] = split_info
names_list = self._normalize_class_names(data.get("names"))
@@ -844,9 +814,7 @@ class TrainingTab(QWidget):
if not names_list and nc_value:
names_list = [f"class_{idx}" for idx in range(int(nc_value))]
elif nc_value and len(names_list) not in (0, int(nc_value)):
warnings.append(
f"Number of class names ({len(names_list)}) does not match nc={nc_value}"
)
warnings.append(f"Number of class names ({len(names_list)}) does not match nc={nc_value}")
dataset_name = data.get("name") or base_path.name
@@ -898,16 +866,12 @@ class TrainingTab(QWidget):
class_index_map = self._build_class_index_map(dataset_info)
if not class_index_map:
self._append_training_log(
"Skipping label export: dataset classes do not match database entries."
)
self._append_training_log("Skipping label export: dataset classes do not match database entries.")
return
dataset_root_str = dataset_info.get("root")
dataset_yaml_path = dataset_info.get("yaml_path")
dataset_yaml = (
Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
)
dataset_yaml = Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
dataset_root: Optional[Path]
if dataset_root_str:
dataset_root = Path(dataset_root_str).resolve()
@@ -941,7 +905,9 @@ class TrainingTab(QWidget):
if stats["registered_images"]:
message += f" {stats['registered_images']} image(s) had database-backed annotations."
if stats["missing_records"]:
message += f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
message += (
f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
)
split_messages.append(message)
for msg in split_messages:
@@ -973,9 +939,7 @@ class TrainingTab(QWidget):
continue
processed_images += 1
label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(
".txt"
)
label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(".txt")
label_path.parent.mkdir(parents=True, exist_ok=True)
found, annotation_entries = self._fetch_annotations_for_image(
@@ -991,25 +955,23 @@ class TrainingTab(QWidget):
for entry in annotation_entries:
polygon = entry.get("polygon") or []
if polygon:
print(image_file, polygon[:4], polygon[-2:], entry.get("bbox"))
# coords = " ".join(f"{value:.6f}" for value in entry.get("bbox"))
# coords += " "
coords = " ".join(f"{value:.6f}" for value in polygon)
handle.write(f"{entry['class_idx']} {coords}\n")
annotations_written += 1
elif entry.get("bbox"):
x_center, y_center, width, height = entry["bbox"]
handle.write(
f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
)
handle.write(f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")
annotations_written += 1
total_annotations += annotations_written
cache_reset_root = labels_dir.parent
self._invalidate_split_cache(cache_reset_root)
if processed_images == 0:
self._append_training_log(
f"[{split_name}] No images found to export labels for."
)
self._append_training_log(f"[{split_name}] No images found to export labels for.")
return None
return {
@@ -1135,6 +1097,10 @@ class TrainingTab(QWidget):
xs.append(x_val)
ys.append(y_val)
if any(np.abs(np.array(coords[:2]) - np.array(coords[-2:])) < 1e-5):
print("Closing polygon")
coords.extend(coords[:2])
if len(coords) < 6:
continue
@@ -1147,6 +1113,11 @@ class TrainingTab(QWidget):
+ abs((min(ys) if ys else 0.0) - y_min)
+ abs((max(ys) if ys else 0.0) - y_max)
)
width = max(0.0, x_max - x_min)
height = max(0.0, y_max - y_min)
x_center = x_min + width / 2.0
y_center = y_min + height / 2.0
score = (x_center, y_center, width, height)
candidates.append((score, coords))
@@ -1164,13 +1135,10 @@ class TrainingTab(QWidget):
return 1.0
return value
def _prepare_dataset_for_training(
self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
) -> Path:
def _prepare_dataset_for_training(self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None) -> Path:
dataset_info = dataset_info or (
self.selected_dataset
if self.selected_dataset
and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
else self._parse_dataset_yaml(dataset_yaml)
)
@@ -1189,14 +1157,10 @@ class TrainingTab(QWidget):
cache_root = self._get_rgb_cache_root(dataset_yaml)
rgb_yaml = cache_root / "data.yaml"
if rgb_yaml.exists():
self._append_training_log(
f"Detected grayscale dataset; reusing RGB cache at {cache_root}"
)
self._append_training_log(f"Detected grayscale dataset; reusing RGB cache at {cache_root}")
return rgb_yaml
self._append_training_log(
f"Detected grayscale dataset; creating RGB cache at {cache_root}"
)
self._append_training_log(f"Detected grayscale dataset; creating RGB cache at {cache_root}")
self._build_rgb_dataset(cache_root, dataset_info)
return rgb_yaml
@@ -1463,15 +1427,12 @@ class TrainingTab(QWidget):
dataset_path = Path(dataset_yaml).expanduser()
if not dataset_path.exists():
QMessageBox.warning(
self, "Invalid Dataset", "Selected data.yaml file does not exist."
)
QMessageBox.warning(self, "Invalid Dataset", "Selected data.yaml file does not exist.")
return
dataset_info = (
self.selected_dataset
if self.selected_dataset
and self.selected_dataset.get("yaml_path") == str(dataset_path)
if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_path)
else self._parse_dataset_yaml(dataset_path)
)
@@ -1480,16 +1441,12 @@ class TrainingTab(QWidget):
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
if dataset_to_use != dataset_path:
self._append_training_log(
f"Using RGB-converted dataset at {dataset_to_use.parent}"
)
self._append_training_log(f"Using RGB-converted dataset at {dataset_to_use.parent}")
params = self._collect_training_params()
stage_plan = self._compose_stage_plan(params)
params["stage_plan"] = stage_plan
total_planned_epochs = (
self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
)
total_planned_epochs = self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
params["total_planned_epochs"] = total_planned_epochs
self._active_training_params = params
self._training_cancelled = False
@@ -1498,9 +1455,7 @@ class TrainingTab(QWidget):
self._append_training_log("Two-stage fine-tuning schedule:")
self._log_stage_plan(stage_plan)
self._append_training_log(
f"Starting training run '{params['run_name']}' using {params['base_model']}"
)
self._append_training_log(f"Starting training run '{params['run_name']}' using {params['base_model']}")
self.training_progress_bar.setVisible(True)
self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
@@ -1528,9 +1483,7 @@ class TrainingTab(QWidget):
def _stop_training(self):
if self.training_worker and self.training_worker.isRunning():
self._training_cancelled = True
self._append_training_log(
"Stop requested. Waiting for the current epoch to finish..."
)
self._append_training_log("Stop requested. Waiting for the current epoch to finish...")
self.training_worker.stop()
self.stop_training_button.setEnabled(False)
@@ -1566,9 +1519,7 @@ class TrainingTab(QWidget):
if worker.isRunning():
if not worker.wait(wait_timeout_ms):
logger.warning(
"Training worker did not finish within %sms", wait_timeout_ms
)
logger.warning("Training worker did not finish within %sms", wait_timeout_ms)
worker.deleteLater()
@@ -1585,16 +1536,12 @@ class TrainingTab(QWidget):
self._set_training_state(False)
self.training_progress_bar.setVisible(False)
def _on_training_progress(
self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]
):
def _on_training_progress(self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]):
self.training_progress_bar.setMaximum(total_epochs)
self.training_progress_bar.setValue(current_epoch)
parts = [f"Epoch {current_epoch}/{total_epochs}"]
if metrics:
metric_text = ", ".join(
f"{key}: {value:.4f}" for key, value in metrics.items()
)
metric_text = ", ".join(f"{key}: {value:.4f}" for key, value in metrics.items())
parts.append(metric_text)
self._append_training_log(" | ".join(parts))
@@ -1621,9 +1568,7 @@ class TrainingTab(QWidget):
f"Model trained but not registered: {exc}",
)
else:
QMessageBox.information(
self, "Training Complete", "Training finished successfully."
)
QMessageBox.information(self, "Training Complete", "Training finished successfully.")
def _on_training_error(self, message: str):
self._cleanup_training_worker()
@@ -1669,9 +1614,7 @@ class TrainingTab(QWidget):
metrics=results.get("metrics"),
)
self._append_training_log(
f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}"
)
self._append_training_log(f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}")
self._active_training_params = None
def _set_training_state(self, is_training: bool):
@@ -1714,9 +1657,7 @@ class TrainingTab(QWidget):
def _browse_save_dir(self):
start_path = self.save_dir_edit.text().strip() or "data/models"
directory = QFileDialog.getExistingDirectory(
self, "Select Save Directory", start_path
)
directory = QFileDialog.getExistingDirectory(self, "Select Save Directory", start_path)
if directory:
self.save_dir_edit.setText(directory)