From fcbd5fb16d45527545e0260598ea2c97cfafffd7 Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Fri, 16 Jan 2026 10:24:19 +0200 Subject: [PATCH] correcting label writing and formatting code --- src/gui/tabs/training_tab.py | 165 +++++++++++------------------------ 1 file changed, 53 insertions(+), 112 deletions(-) diff --git a/src/gui/tabs/training_tab.py b/src/gui/tabs/training_tab.py index 058b5ae..d6bdbd0 100644 --- a/src/gui/tabs/training_tab.py +++ b/src/gui/tabs/training_tab.py @@ -10,6 +10,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import yaml +import numpy as np from PySide6.QtCore import Qt, QThread, Signal from PySide6.QtWidgets import ( QWidget, @@ -91,10 +92,7 @@ class TrainingWorker(QThread): }, } ] - computed_total = sum( - max(0, int((stage.get("params") or {}).get("epochs", 0))) - for stage in self.stage_plan - ) + computed_total = sum(max(0, int((stage.get("params") or {}).get("epochs", 0))) for stage in self.stage_plan) self.total_epochs = total_epochs if total_epochs else computed_total or epochs self._stop_requested = False @@ -201,9 +199,7 @@ class TrainingWorker(QThread): class TrainingTab(QWidget): """Training tab for model training.""" - def __init__( - self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None - ): + def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None): super().__init__(parent) self.db_manager = db_manager self.config_manager = config_manager @@ -337,18 +333,14 @@ class TrainingTab(QWidget): self.model_version_edit = QLineEdit("v1") form_layout.addRow("Version:", self.model_version_edit) - default_base_model = self.config_manager.get( - "models.default_base_model", "yolov8s-seg.pt" - ) + default_base_model = self.config_manager.get("models.default_base_model", "yolov8s-seg.pt") base_model_choices = self.config_manager.get("models.base_model_choices", []) self.base_model_combo = QComboBox() self.base_model_combo.addItem("Custom path…", "") for choice in base_model_choices: self.base_model_combo.addItem(choice, choice) - self.base_model_combo.currentIndexChanged.connect( - self._on_base_model_preset_changed - ) + self.base_model_combo.currentIndexChanged.connect(self._on_base_model_preset_changed) form_layout.addRow("Base Model Preset:", self.base_model_combo) base_model_layout = QHBoxLayout() @@ -434,12 +426,8 @@ class TrainingTab(QWidget): group_layout = QVBoxLayout() self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune") - two_stage_defaults = ( - training_defaults.get("two_stage", {}) if training_defaults else {} - ) - self.two_stage_checkbox.setChecked( - bool(two_stage_defaults.get("enabled", False)) - ) + two_stage_defaults = training_defaults.get("two_stage", {}) if training_defaults else {} + self.two_stage_checkbox.setChecked(bool(two_stage_defaults.get("enabled", False))) self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled) group_layout.addWidget(self.two_stage_checkbox) @@ -501,9 +489,7 @@ class TrainingTab(QWidget): stage2_group.setLayout(stage2_form) controls_layout.addWidget(stage2_group) - helper_label = QLabel( - "When enabled, staged hyperparameters override the global epochs/patience/lr." - ) + helper_label = QLabel("When enabled, staged hyperparameters override the global epochs/patience/lr.") helper_label.setWordWrap(True) controls_layout.addWidget(helper_label) @@ -548,9 +534,7 @@ class TrainingTab(QWidget): if normalized == preset_value: target_index = idx break - if normalized.endswith(f"/{preset_value}") or normalized.endswith( - f"\\{preset_value}" - ): + if normalized.endswith(f"/{preset_value}") or normalized.endswith(f"\\{preset_value}"): target_index = idx break self.base_model_combo.blockSignals(True) @@ -638,9 +622,7 @@ class TrainingTab(QWidget): def _browse_dataset(self): """Open a file dialog to manually select data.yaml.""" - start_dir = self.config_manager.get( - "training.last_dataset_dir", "data/datasets" - ) + start_dir = self.config_manager.get("training.last_dataset_dir", "data/datasets") start_path = Path(start_dir).expanduser() if not start_path.exists(): start_path = Path.cwd() @@ -676,9 +658,7 @@ class TrainingTab(QWidget): return except Exception as exc: logger.exception("Unexpected error while generating data.yaml") - self._display_dataset_error( - "Unexpected error while generating data.yaml. Check logs for details." - ) + self._display_dataset_error("Unexpected error while generating data.yaml. Check logs for details.") QMessageBox.critical( self, "data.yaml Generation Failed", @@ -755,13 +735,9 @@ class TrainingTab(QWidget): self.selected_dataset = info self.dataset_root_label.setText(info["root"]) # type: ignore[arg-type] - self.train_count_label.setText( - self._format_split_info(info["splits"].get("train")) - ) + self.train_count_label.setText(self._format_split_info(info["splits"].get("train"))) self.val_count_label.setText(self._format_split_info(info["splits"].get("val"))) - self.test_count_label.setText( - self._format_split_info(info["splits"].get("test")) - ) + self.test_count_label.setText(self._format_split_info(info["splits"].get("test"))) self.num_classes_label.setText(str(info["num_classes"])) class_names = ", ".join(info["class_names"]) or "–" self.class_names_label.setText(class_names) @@ -815,18 +791,12 @@ class TrainingTab(QWidget): if split_path.exists(): split_info["count"] = self._count_images(split_path) if split_info["count"] == 0: - warnings.append( - f"No images found for {split_name} split at {split_path}" - ) + warnings.append(f"No images found for {split_name} split at {split_path}") else: - warnings.append( - f"{split_name.capitalize()} path does not exist: {split_path}" - ) + warnings.append(f"{split_name.capitalize()} path does not exist: {split_path}") else: if split_name in ("train", "val"): - warnings.append( - f"{split_name.capitalize()} split missing in data.yaml" - ) + warnings.append(f"{split_name.capitalize()} split missing in data.yaml") splits[split_name] = split_info names_list = self._normalize_class_names(data.get("names")) @@ -844,9 +814,7 @@ class TrainingTab(QWidget): if not names_list and nc_value: names_list = [f"class_{idx}" for idx in range(int(nc_value))] elif nc_value and len(names_list) not in (0, int(nc_value)): - warnings.append( - f"Number of class names ({len(names_list)}) does not match nc={nc_value}" - ) + warnings.append(f"Number of class names ({len(names_list)}) does not match nc={nc_value}") dataset_name = data.get("name") or base_path.name @@ -898,16 +866,12 @@ class TrainingTab(QWidget): class_index_map = self._build_class_index_map(dataset_info) if not class_index_map: - self._append_training_log( - "Skipping label export: dataset classes do not match database entries." - ) + self._append_training_log("Skipping label export: dataset classes do not match database entries.") return dataset_root_str = dataset_info.get("root") dataset_yaml_path = dataset_info.get("yaml_path") - dataset_yaml = ( - Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None - ) + dataset_yaml = Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None dataset_root: Optional[Path] if dataset_root_str: dataset_root = Path(dataset_root_str).resolve() @@ -941,7 +905,9 @@ class TrainingTab(QWidget): if stats["registered_images"]: message += f" {stats['registered_images']} image(s) had database-backed annotations." if stats["missing_records"]: - message += f" {stats['missing_records']} image(s) had no database entry; empty label files were written." + message += ( + f" {stats['missing_records']} image(s) had no database entry; empty label files were written." + ) split_messages.append(message) for msg in split_messages: @@ -973,9 +939,7 @@ class TrainingTab(QWidget): continue processed_images += 1 - label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix( - ".txt" - ) + label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(".txt") label_path.parent.mkdir(parents=True, exist_ok=True) found, annotation_entries = self._fetch_annotations_for_image( @@ -991,25 +955,23 @@ class TrainingTab(QWidget): for entry in annotation_entries: polygon = entry.get("polygon") or [] if polygon: + print(image_file, polygon[:4], polygon[-2:], entry.get("bbox")) + # coords = " ".join(f"{value:.6f}" for value in entry.get("bbox")) + # coords += " " coords = " ".join(f"{value:.6f}" for value in polygon) handle.write(f"{entry['class_idx']} {coords}\n") annotations_written += 1 elif entry.get("bbox"): x_center, y_center, width, height = entry["bbox"] - handle.write( - f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n" - ) + handle.write(f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n") annotations_written += 1 - total_annotations += annotations_written cache_reset_root = labels_dir.parent self._invalidate_split_cache(cache_reset_root) if processed_images == 0: - self._append_training_log( - f"[{split_name}] No images found to export labels for." - ) + self._append_training_log(f"[{split_name}] No images found to export labels for.") return None return { @@ -1135,6 +1097,10 @@ class TrainingTab(QWidget): xs.append(x_val) ys.append(y_val) + if any(np.abs(np.array(coords[:2]) - np.array(coords[-2:])) < 1e-5): + print("Closing polygon") + coords.extend(coords[:2]) + if len(coords) < 6: continue @@ -1147,6 +1113,11 @@ class TrainingTab(QWidget): + abs((min(ys) if ys else 0.0) - y_min) + abs((max(ys) if ys else 0.0) - y_max) ) + width = max(0.0, x_max - x_min) + height = max(0.0, y_max - y_min) + x_center = x_min + width / 2.0 + y_center = y_min + height / 2.0 + score = (x_center, y_center, width, height) candidates.append((score, coords)) @@ -1164,13 +1135,10 @@ class TrainingTab(QWidget): return 1.0 return value - def _prepare_dataset_for_training( - self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None - ) -> Path: + def _prepare_dataset_for_training(self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None) -> Path: dataset_info = dataset_info or ( self.selected_dataset - if self.selected_dataset - and self.selected_dataset.get("yaml_path") == str(dataset_yaml) + if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_yaml) else self._parse_dataset_yaml(dataset_yaml) ) @@ -1189,14 +1157,10 @@ class TrainingTab(QWidget): cache_root = self._get_rgb_cache_root(dataset_yaml) rgb_yaml = cache_root / "data.yaml" if rgb_yaml.exists(): - self._append_training_log( - f"Detected grayscale dataset; reusing RGB cache at {cache_root}" - ) + self._append_training_log(f"Detected grayscale dataset; reusing RGB cache at {cache_root}") return rgb_yaml - self._append_training_log( - f"Detected grayscale dataset; creating RGB cache at {cache_root}" - ) + self._append_training_log(f"Detected grayscale dataset; creating RGB cache at {cache_root}") self._build_rgb_dataset(cache_root, dataset_info) return rgb_yaml @@ -1463,15 +1427,12 @@ class TrainingTab(QWidget): dataset_path = Path(dataset_yaml).expanduser() if not dataset_path.exists(): - QMessageBox.warning( - self, "Invalid Dataset", "Selected data.yaml file does not exist." - ) + QMessageBox.warning(self, "Invalid Dataset", "Selected data.yaml file does not exist.") return dataset_info = ( self.selected_dataset - if self.selected_dataset - and self.selected_dataset.get("yaml_path") == str(dataset_path) + if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_path) else self._parse_dataset_yaml(dataset_path) ) @@ -1480,16 +1441,12 @@ class TrainingTab(QWidget): dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info) if dataset_to_use != dataset_path: - self._append_training_log( - f"Using RGB-converted dataset at {dataset_to_use.parent}" - ) + self._append_training_log(f"Using RGB-converted dataset at {dataset_to_use.parent}") params = self._collect_training_params() stage_plan = self._compose_stage_plan(params) params["stage_plan"] = stage_plan - total_planned_epochs = ( - self._calculate_total_stage_epochs(stage_plan) or params["epochs"] - ) + total_planned_epochs = self._calculate_total_stage_epochs(stage_plan) or params["epochs"] params["total_planned_epochs"] = total_planned_epochs self._active_training_params = params self._training_cancelled = False @@ -1498,9 +1455,7 @@ class TrainingTab(QWidget): self._append_training_log("Two-stage fine-tuning schedule:") self._log_stage_plan(stage_plan) - self._append_training_log( - f"Starting training run '{params['run_name']}' using {params['base_model']}" - ) + self._append_training_log(f"Starting training run '{params['run_name']}' using {params['base_model']}") self.training_progress_bar.setVisible(True) self.training_progress_bar.setMaximum(max(1, total_planned_epochs)) @@ -1528,9 +1483,7 @@ class TrainingTab(QWidget): def _stop_training(self): if self.training_worker and self.training_worker.isRunning(): self._training_cancelled = True - self._append_training_log( - "Stop requested. Waiting for the current epoch to finish..." - ) + self._append_training_log("Stop requested. Waiting for the current epoch to finish...") self.training_worker.stop() self.stop_training_button.setEnabled(False) @@ -1566,9 +1519,7 @@ class TrainingTab(QWidget): if worker.isRunning(): if not worker.wait(wait_timeout_ms): - logger.warning( - "Training worker did not finish within %sms", wait_timeout_ms - ) + logger.warning("Training worker did not finish within %sms", wait_timeout_ms) worker.deleteLater() @@ -1585,16 +1536,12 @@ class TrainingTab(QWidget): self._set_training_state(False) self.training_progress_bar.setVisible(False) - def _on_training_progress( - self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any] - ): + def _on_training_progress(self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]): self.training_progress_bar.setMaximum(total_epochs) self.training_progress_bar.setValue(current_epoch) parts = [f"Epoch {current_epoch}/{total_epochs}"] if metrics: - metric_text = ", ".join( - f"{key}: {value:.4f}" for key, value in metrics.items() - ) + metric_text = ", ".join(f"{key}: {value:.4f}" for key, value in metrics.items()) parts.append(metric_text) self._append_training_log(" | ".join(parts)) @@ -1621,9 +1568,7 @@ class TrainingTab(QWidget): f"Model trained but not registered: {exc}", ) else: - QMessageBox.information( - self, "Training Complete", "Training finished successfully." - ) + QMessageBox.information(self, "Training Complete", "Training finished successfully.") def _on_training_error(self, message: str): self._cleanup_training_worker() @@ -1669,9 +1614,7 @@ class TrainingTab(QWidget): metrics=results.get("metrics"), ) - self._append_training_log( - f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}" - ) + self._append_training_log(f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}") self._active_training_params = None def _set_training_state(self, is_training: bool): @@ -1714,9 +1657,7 @@ class TrainingTab(QWidget): def _browse_save_dir(self): start_path = self.save_dir_edit.text().strip() or "data/models" - directory = QFileDialog.getExistingDirectory( - self, "Select Save Directory", start_path - ) + directory = QFileDialog.getExistingDirectory(self, "Select Save Directory", start_path) if directory: self.save_dir_edit.setText(directory)