correcting label writing and formatting code

This commit is contained in:
2026-01-16 10:24:19 +02:00
parent ca52312925
commit fcbd5fb16d

View File

@@ -10,6 +10,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import yaml import yaml
import numpy as np
from PySide6.QtCore import Qt, QThread, Signal from PySide6.QtCore import Qt, QThread, Signal
from PySide6.QtWidgets import ( from PySide6.QtWidgets import (
QWidget, QWidget,
@@ -91,10 +92,7 @@ class TrainingWorker(QThread):
}, },
} }
] ]
computed_total = sum( computed_total = sum(max(0, int((stage.get("params") or {}).get("epochs", 0))) for stage in self.stage_plan)
max(0, int((stage.get("params") or {}).get("epochs", 0)))
for stage in self.stage_plan
)
self.total_epochs = total_epochs if total_epochs else computed_total or epochs self.total_epochs = total_epochs if total_epochs else computed_total or epochs
self._stop_requested = False self._stop_requested = False
@@ -201,9 +199,7 @@ class TrainingWorker(QThread):
class TrainingTab(QWidget): class TrainingTab(QWidget):
"""Training tab for model training.""" """Training tab for model training."""
def __init__( def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
super().__init__(parent) super().__init__(parent)
self.db_manager = db_manager self.db_manager = db_manager
self.config_manager = config_manager self.config_manager = config_manager
@@ -337,18 +333,14 @@ class TrainingTab(QWidget):
self.model_version_edit = QLineEdit("v1") self.model_version_edit = QLineEdit("v1")
form_layout.addRow("Version:", self.model_version_edit) form_layout.addRow("Version:", self.model_version_edit)
default_base_model = self.config_manager.get( default_base_model = self.config_manager.get("models.default_base_model", "yolov8s-seg.pt")
"models.default_base_model", "yolov8s-seg.pt"
)
base_model_choices = self.config_manager.get("models.base_model_choices", []) base_model_choices = self.config_manager.get("models.base_model_choices", [])
self.base_model_combo = QComboBox() self.base_model_combo = QComboBox()
self.base_model_combo.addItem("Custom path…", "") self.base_model_combo.addItem("Custom path…", "")
for choice in base_model_choices: for choice in base_model_choices:
self.base_model_combo.addItem(choice, choice) self.base_model_combo.addItem(choice, choice)
self.base_model_combo.currentIndexChanged.connect( self.base_model_combo.currentIndexChanged.connect(self._on_base_model_preset_changed)
self._on_base_model_preset_changed
)
form_layout.addRow("Base Model Preset:", self.base_model_combo) form_layout.addRow("Base Model Preset:", self.base_model_combo)
base_model_layout = QHBoxLayout() base_model_layout = QHBoxLayout()
@@ -434,12 +426,8 @@ class TrainingTab(QWidget):
group_layout = QVBoxLayout() group_layout = QVBoxLayout()
self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune") self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
two_stage_defaults = ( two_stage_defaults = training_defaults.get("two_stage", {}) if training_defaults else {}
training_defaults.get("two_stage", {}) if training_defaults else {} self.two_stage_checkbox.setChecked(bool(two_stage_defaults.get("enabled", False)))
)
self.two_stage_checkbox.setChecked(
bool(two_stage_defaults.get("enabled", False))
)
self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled) self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
group_layout.addWidget(self.two_stage_checkbox) group_layout.addWidget(self.two_stage_checkbox)
@@ -501,9 +489,7 @@ class TrainingTab(QWidget):
stage2_group.setLayout(stage2_form) stage2_group.setLayout(stage2_form)
controls_layout.addWidget(stage2_group) controls_layout.addWidget(stage2_group)
helper_label = QLabel( helper_label = QLabel("When enabled, staged hyperparameters override the global epochs/patience/lr.")
"When enabled, staged hyperparameters override the global epochs/patience/lr."
)
helper_label.setWordWrap(True) helper_label.setWordWrap(True)
controls_layout.addWidget(helper_label) controls_layout.addWidget(helper_label)
@@ -548,9 +534,7 @@ class TrainingTab(QWidget):
if normalized == preset_value: if normalized == preset_value:
target_index = idx target_index = idx
break break
if normalized.endswith(f"/{preset_value}") or normalized.endswith( if normalized.endswith(f"/{preset_value}") or normalized.endswith(f"\\{preset_value}"):
f"\\{preset_value}"
):
target_index = idx target_index = idx
break break
self.base_model_combo.blockSignals(True) self.base_model_combo.blockSignals(True)
@@ -638,9 +622,7 @@ class TrainingTab(QWidget):
def _browse_dataset(self): def _browse_dataset(self):
"""Open a file dialog to manually select data.yaml.""" """Open a file dialog to manually select data.yaml."""
start_dir = self.config_manager.get( start_dir = self.config_manager.get("training.last_dataset_dir", "data/datasets")
"training.last_dataset_dir", "data/datasets"
)
start_path = Path(start_dir).expanduser() start_path = Path(start_dir).expanduser()
if not start_path.exists(): if not start_path.exists():
start_path = Path.cwd() start_path = Path.cwd()
@@ -676,9 +658,7 @@ class TrainingTab(QWidget):
return return
except Exception as exc: except Exception as exc:
logger.exception("Unexpected error while generating data.yaml") logger.exception("Unexpected error while generating data.yaml")
self._display_dataset_error( self._display_dataset_error("Unexpected error while generating data.yaml. Check logs for details.")
"Unexpected error while generating data.yaml. Check logs for details."
)
QMessageBox.critical( QMessageBox.critical(
self, self,
"data.yaml Generation Failed", "data.yaml Generation Failed",
@@ -755,13 +735,9 @@ class TrainingTab(QWidget):
self.selected_dataset = info self.selected_dataset = info
self.dataset_root_label.setText(info["root"]) # type: ignore[arg-type] self.dataset_root_label.setText(info["root"]) # type: ignore[arg-type]
self.train_count_label.setText( self.train_count_label.setText(self._format_split_info(info["splits"].get("train")))
self._format_split_info(info["splits"].get("train"))
)
self.val_count_label.setText(self._format_split_info(info["splits"].get("val"))) self.val_count_label.setText(self._format_split_info(info["splits"].get("val")))
self.test_count_label.setText( self.test_count_label.setText(self._format_split_info(info["splits"].get("test")))
self._format_split_info(info["splits"].get("test"))
)
self.num_classes_label.setText(str(info["num_classes"])) self.num_classes_label.setText(str(info["num_classes"]))
class_names = ", ".join(info["class_names"]) or "" class_names = ", ".join(info["class_names"]) or ""
self.class_names_label.setText(class_names) self.class_names_label.setText(class_names)
@@ -815,18 +791,12 @@ class TrainingTab(QWidget):
if split_path.exists(): if split_path.exists():
split_info["count"] = self._count_images(split_path) split_info["count"] = self._count_images(split_path)
if split_info["count"] == 0: if split_info["count"] == 0:
warnings.append( warnings.append(f"No images found for {split_name} split at {split_path}")
f"No images found for {split_name} split at {split_path}"
)
else: else:
warnings.append( warnings.append(f"{split_name.capitalize()} path does not exist: {split_path}")
f"{split_name.capitalize()} path does not exist: {split_path}"
)
else: else:
if split_name in ("train", "val"): if split_name in ("train", "val"):
warnings.append( warnings.append(f"{split_name.capitalize()} split missing in data.yaml")
f"{split_name.capitalize()} split missing in data.yaml"
)
splits[split_name] = split_info splits[split_name] = split_info
names_list = self._normalize_class_names(data.get("names")) names_list = self._normalize_class_names(data.get("names"))
@@ -844,9 +814,7 @@ class TrainingTab(QWidget):
if not names_list and nc_value: if not names_list and nc_value:
names_list = [f"class_{idx}" for idx in range(int(nc_value))] names_list = [f"class_{idx}" for idx in range(int(nc_value))]
elif nc_value and len(names_list) not in (0, int(nc_value)): elif nc_value and len(names_list) not in (0, int(nc_value)):
warnings.append( warnings.append(f"Number of class names ({len(names_list)}) does not match nc={nc_value}")
f"Number of class names ({len(names_list)}) does not match nc={nc_value}"
)
dataset_name = data.get("name") or base_path.name dataset_name = data.get("name") or base_path.name
@@ -898,16 +866,12 @@ class TrainingTab(QWidget):
class_index_map = self._build_class_index_map(dataset_info) class_index_map = self._build_class_index_map(dataset_info)
if not class_index_map: if not class_index_map:
self._append_training_log( self._append_training_log("Skipping label export: dataset classes do not match database entries.")
"Skipping label export: dataset classes do not match database entries."
)
return return
dataset_root_str = dataset_info.get("root") dataset_root_str = dataset_info.get("root")
dataset_yaml_path = dataset_info.get("yaml_path") dataset_yaml_path = dataset_info.get("yaml_path")
dataset_yaml = ( dataset_yaml = Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
)
dataset_root: Optional[Path] dataset_root: Optional[Path]
if dataset_root_str: if dataset_root_str:
dataset_root = Path(dataset_root_str).resolve() dataset_root = Path(dataset_root_str).resolve()
@@ -941,7 +905,9 @@ class TrainingTab(QWidget):
if stats["registered_images"]: if stats["registered_images"]:
message += f" {stats['registered_images']} image(s) had database-backed annotations." message += f" {stats['registered_images']} image(s) had database-backed annotations."
if stats["missing_records"]: if stats["missing_records"]:
message += f" {stats['missing_records']} image(s) had no database entry; empty label files were written." message += (
f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
)
split_messages.append(message) split_messages.append(message)
for msg in split_messages: for msg in split_messages:
@@ -973,9 +939,7 @@ class TrainingTab(QWidget):
continue continue
processed_images += 1 processed_images += 1
label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix( label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(".txt")
".txt"
)
label_path.parent.mkdir(parents=True, exist_ok=True) label_path.parent.mkdir(parents=True, exist_ok=True)
found, annotation_entries = self._fetch_annotations_for_image( found, annotation_entries = self._fetch_annotations_for_image(
@@ -991,25 +955,23 @@ class TrainingTab(QWidget):
for entry in annotation_entries: for entry in annotation_entries:
polygon = entry.get("polygon") or [] polygon = entry.get("polygon") or []
if polygon: if polygon:
print(image_file, polygon[:4], polygon[-2:], entry.get("bbox"))
# coords = " ".join(f"{value:.6f}" for value in entry.get("bbox"))
# coords += " "
coords = " ".join(f"{value:.6f}" for value in polygon) coords = " ".join(f"{value:.6f}" for value in polygon)
handle.write(f"{entry['class_idx']} {coords}\n") handle.write(f"{entry['class_idx']} {coords}\n")
annotations_written += 1 annotations_written += 1
elif entry.get("bbox"): elif entry.get("bbox"):
x_center, y_center, width, height = entry["bbox"] x_center, y_center, width, height = entry["bbox"]
handle.write( handle.write(f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")
f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
)
annotations_written += 1 annotations_written += 1
total_annotations += annotations_written total_annotations += annotations_written
cache_reset_root = labels_dir.parent cache_reset_root = labels_dir.parent
self._invalidate_split_cache(cache_reset_root) self._invalidate_split_cache(cache_reset_root)
if processed_images == 0: if processed_images == 0:
self._append_training_log( self._append_training_log(f"[{split_name}] No images found to export labels for.")
f"[{split_name}] No images found to export labels for."
)
return None return None
return { return {
@@ -1135,6 +1097,10 @@ class TrainingTab(QWidget):
xs.append(x_val) xs.append(x_val)
ys.append(y_val) ys.append(y_val)
if any(np.abs(np.array(coords[:2]) - np.array(coords[-2:])) < 1e-5):
print("Closing polygon")
coords.extend(coords[:2])
if len(coords) < 6: if len(coords) < 6:
continue continue
@@ -1147,6 +1113,11 @@ class TrainingTab(QWidget):
+ abs((min(ys) if ys else 0.0) - y_min) + abs((min(ys) if ys else 0.0) - y_min)
+ abs((max(ys) if ys else 0.0) - y_max) + abs((max(ys) if ys else 0.0) - y_max)
) )
width = max(0.0, x_max - x_min)
height = max(0.0, y_max - y_min)
x_center = x_min + width / 2.0
y_center = y_min + height / 2.0
score = (x_center, y_center, width, height)
candidates.append((score, coords)) candidates.append((score, coords))
@@ -1164,13 +1135,10 @@ class TrainingTab(QWidget):
return 1.0 return 1.0
return value return value
def _prepare_dataset_for_training( def _prepare_dataset_for_training(self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None) -> Path:
self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
) -> Path:
dataset_info = dataset_info or ( dataset_info = dataset_info or (
self.selected_dataset self.selected_dataset
if self.selected_dataset if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
else self._parse_dataset_yaml(dataset_yaml) else self._parse_dataset_yaml(dataset_yaml)
) )
@@ -1189,14 +1157,10 @@ class TrainingTab(QWidget):
cache_root = self._get_rgb_cache_root(dataset_yaml) cache_root = self._get_rgb_cache_root(dataset_yaml)
rgb_yaml = cache_root / "data.yaml" rgb_yaml = cache_root / "data.yaml"
if rgb_yaml.exists(): if rgb_yaml.exists():
self._append_training_log( self._append_training_log(f"Detected grayscale dataset; reusing RGB cache at {cache_root}")
f"Detected grayscale dataset; reusing RGB cache at {cache_root}"
)
return rgb_yaml return rgb_yaml
self._append_training_log( self._append_training_log(f"Detected grayscale dataset; creating RGB cache at {cache_root}")
f"Detected grayscale dataset; creating RGB cache at {cache_root}"
)
self._build_rgb_dataset(cache_root, dataset_info) self._build_rgb_dataset(cache_root, dataset_info)
return rgb_yaml return rgb_yaml
@@ -1463,15 +1427,12 @@ class TrainingTab(QWidget):
dataset_path = Path(dataset_yaml).expanduser() dataset_path = Path(dataset_yaml).expanduser()
if not dataset_path.exists(): if not dataset_path.exists():
QMessageBox.warning( QMessageBox.warning(self, "Invalid Dataset", "Selected data.yaml file does not exist.")
self, "Invalid Dataset", "Selected data.yaml file does not exist."
)
return return
dataset_info = ( dataset_info = (
self.selected_dataset self.selected_dataset
if self.selected_dataset if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_path)
and self.selected_dataset.get("yaml_path") == str(dataset_path)
else self._parse_dataset_yaml(dataset_path) else self._parse_dataset_yaml(dataset_path)
) )
@@ -1480,16 +1441,12 @@ class TrainingTab(QWidget):
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info) dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
if dataset_to_use != dataset_path: if dataset_to_use != dataset_path:
self._append_training_log( self._append_training_log(f"Using RGB-converted dataset at {dataset_to_use.parent}")
f"Using RGB-converted dataset at {dataset_to_use.parent}"
)
params = self._collect_training_params() params = self._collect_training_params()
stage_plan = self._compose_stage_plan(params) stage_plan = self._compose_stage_plan(params)
params["stage_plan"] = stage_plan params["stage_plan"] = stage_plan
total_planned_epochs = ( total_planned_epochs = self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
)
params["total_planned_epochs"] = total_planned_epochs params["total_planned_epochs"] = total_planned_epochs
self._active_training_params = params self._active_training_params = params
self._training_cancelled = False self._training_cancelled = False
@@ -1498,9 +1455,7 @@ class TrainingTab(QWidget):
self._append_training_log("Two-stage fine-tuning schedule:") self._append_training_log("Two-stage fine-tuning schedule:")
self._log_stage_plan(stage_plan) self._log_stage_plan(stage_plan)
self._append_training_log( self._append_training_log(f"Starting training run '{params['run_name']}' using {params['base_model']}")
f"Starting training run '{params['run_name']}' using {params['base_model']}"
)
self.training_progress_bar.setVisible(True) self.training_progress_bar.setVisible(True)
self.training_progress_bar.setMaximum(max(1, total_planned_epochs)) self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
@@ -1528,9 +1483,7 @@ class TrainingTab(QWidget):
def _stop_training(self): def _stop_training(self):
if self.training_worker and self.training_worker.isRunning(): if self.training_worker and self.training_worker.isRunning():
self._training_cancelled = True self._training_cancelled = True
self._append_training_log( self._append_training_log("Stop requested. Waiting for the current epoch to finish...")
"Stop requested. Waiting for the current epoch to finish..."
)
self.training_worker.stop() self.training_worker.stop()
self.stop_training_button.setEnabled(False) self.stop_training_button.setEnabled(False)
@@ -1566,9 +1519,7 @@ class TrainingTab(QWidget):
if worker.isRunning(): if worker.isRunning():
if not worker.wait(wait_timeout_ms): if not worker.wait(wait_timeout_ms):
logger.warning( logger.warning("Training worker did not finish within %sms", wait_timeout_ms)
"Training worker did not finish within %sms", wait_timeout_ms
)
worker.deleteLater() worker.deleteLater()
@@ -1585,16 +1536,12 @@ class TrainingTab(QWidget):
self._set_training_state(False) self._set_training_state(False)
self.training_progress_bar.setVisible(False) self.training_progress_bar.setVisible(False)
def _on_training_progress( def _on_training_progress(self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]):
self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]
):
self.training_progress_bar.setMaximum(total_epochs) self.training_progress_bar.setMaximum(total_epochs)
self.training_progress_bar.setValue(current_epoch) self.training_progress_bar.setValue(current_epoch)
parts = [f"Epoch {current_epoch}/{total_epochs}"] parts = [f"Epoch {current_epoch}/{total_epochs}"]
if metrics: if metrics:
metric_text = ", ".join( metric_text = ", ".join(f"{key}: {value:.4f}" for key, value in metrics.items())
f"{key}: {value:.4f}" for key, value in metrics.items()
)
parts.append(metric_text) parts.append(metric_text)
self._append_training_log(" | ".join(parts)) self._append_training_log(" | ".join(parts))
@@ -1621,9 +1568,7 @@ class TrainingTab(QWidget):
f"Model trained but not registered: {exc}", f"Model trained but not registered: {exc}",
) )
else: else:
QMessageBox.information( QMessageBox.information(self, "Training Complete", "Training finished successfully.")
self, "Training Complete", "Training finished successfully."
)
def _on_training_error(self, message: str): def _on_training_error(self, message: str):
self._cleanup_training_worker() self._cleanup_training_worker()
@@ -1669,9 +1614,7 @@ class TrainingTab(QWidget):
metrics=results.get("metrics"), metrics=results.get("metrics"),
) )
self._append_training_log( self._append_training_log(f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}")
f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}"
)
self._active_training_params = None self._active_training_params = None
def _set_training_state(self, is_training: bool): def _set_training_state(self, is_training: bool):
@@ -1714,9 +1657,7 @@ class TrainingTab(QWidget):
def _browse_save_dir(self): def _browse_save_dir(self):
start_path = self.save_dir_edit.text().strip() or "data/models" start_path = self.save_dir_edit.text().strip() or "data/models"
directory = QFileDialog.getExistingDirectory( directory = QFileDialog.getExistingDirectory(self, "Select Save Directory", start_path)
self, "Select Save Directory", start_path
)
if directory: if directory:
self.save_dir_edit.setText(directory) self.save_dir_edit.setText(directory)