correcting label writing and formatting code
This commit is contained in:
@@ -10,6 +10,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import yaml
|
||||
import numpy as np
|
||||
from PySide6.QtCore import Qt, QThread, Signal
|
||||
from PySide6.QtWidgets import (
|
||||
QWidget,
|
||||
@@ -91,10 +92,7 @@ class TrainingWorker(QThread):
|
||||
},
|
||||
}
|
||||
]
|
||||
computed_total = sum(
|
||||
max(0, int((stage.get("params") or {}).get("epochs", 0)))
|
||||
for stage in self.stage_plan
|
||||
)
|
||||
computed_total = sum(max(0, int((stage.get("params") or {}).get("epochs", 0))) for stage in self.stage_plan)
|
||||
self.total_epochs = total_epochs if total_epochs else computed_total or epochs
|
||||
self._stop_requested = False
|
||||
|
||||
@@ -201,9 +199,7 @@ class TrainingWorker(QThread):
|
||||
class TrainingTab(QWidget):
|
||||
"""Training tab for model training."""
|
||||
|
||||
def __init__(
|
||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
||||
):
|
||||
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
|
||||
super().__init__(parent)
|
||||
self.db_manager = db_manager
|
||||
self.config_manager = config_manager
|
||||
@@ -337,18 +333,14 @@ class TrainingTab(QWidget):
|
||||
self.model_version_edit = QLineEdit("v1")
|
||||
form_layout.addRow("Version:", self.model_version_edit)
|
||||
|
||||
default_base_model = self.config_manager.get(
|
||||
"models.default_base_model", "yolov8s-seg.pt"
|
||||
)
|
||||
default_base_model = self.config_manager.get("models.default_base_model", "yolov8s-seg.pt")
|
||||
base_model_choices = self.config_manager.get("models.base_model_choices", [])
|
||||
|
||||
self.base_model_combo = QComboBox()
|
||||
self.base_model_combo.addItem("Custom path…", "")
|
||||
for choice in base_model_choices:
|
||||
self.base_model_combo.addItem(choice, choice)
|
||||
self.base_model_combo.currentIndexChanged.connect(
|
||||
self._on_base_model_preset_changed
|
||||
)
|
||||
self.base_model_combo.currentIndexChanged.connect(self._on_base_model_preset_changed)
|
||||
form_layout.addRow("Base Model Preset:", self.base_model_combo)
|
||||
|
||||
base_model_layout = QHBoxLayout()
|
||||
@@ -434,12 +426,8 @@ class TrainingTab(QWidget):
|
||||
group_layout = QVBoxLayout()
|
||||
|
||||
self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
|
||||
two_stage_defaults = (
|
||||
training_defaults.get("two_stage", {}) if training_defaults else {}
|
||||
)
|
||||
self.two_stage_checkbox.setChecked(
|
||||
bool(two_stage_defaults.get("enabled", False))
|
||||
)
|
||||
two_stage_defaults = training_defaults.get("two_stage", {}) if training_defaults else {}
|
||||
self.two_stage_checkbox.setChecked(bool(two_stage_defaults.get("enabled", False)))
|
||||
self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
|
||||
group_layout.addWidget(self.two_stage_checkbox)
|
||||
|
||||
@@ -501,9 +489,7 @@ class TrainingTab(QWidget):
|
||||
stage2_group.setLayout(stage2_form)
|
||||
controls_layout.addWidget(stage2_group)
|
||||
|
||||
helper_label = QLabel(
|
||||
"When enabled, staged hyperparameters override the global epochs/patience/lr."
|
||||
)
|
||||
helper_label = QLabel("When enabled, staged hyperparameters override the global epochs/patience/lr.")
|
||||
helper_label.setWordWrap(True)
|
||||
controls_layout.addWidget(helper_label)
|
||||
|
||||
@@ -548,9 +534,7 @@ class TrainingTab(QWidget):
|
||||
if normalized == preset_value:
|
||||
target_index = idx
|
||||
break
|
||||
if normalized.endswith(f"/{preset_value}") or normalized.endswith(
|
||||
f"\\{preset_value}"
|
||||
):
|
||||
if normalized.endswith(f"/{preset_value}") or normalized.endswith(f"\\{preset_value}"):
|
||||
target_index = idx
|
||||
break
|
||||
self.base_model_combo.blockSignals(True)
|
||||
@@ -638,9 +622,7 @@ class TrainingTab(QWidget):
|
||||
|
||||
def _browse_dataset(self):
|
||||
"""Open a file dialog to manually select data.yaml."""
|
||||
start_dir = self.config_manager.get(
|
||||
"training.last_dataset_dir", "data/datasets"
|
||||
)
|
||||
start_dir = self.config_manager.get("training.last_dataset_dir", "data/datasets")
|
||||
start_path = Path(start_dir).expanduser()
|
||||
if not start_path.exists():
|
||||
start_path = Path.cwd()
|
||||
@@ -676,9 +658,7 @@ class TrainingTab(QWidget):
|
||||
return
|
||||
except Exception as exc:
|
||||
logger.exception("Unexpected error while generating data.yaml")
|
||||
self._display_dataset_error(
|
||||
"Unexpected error while generating data.yaml. Check logs for details."
|
||||
)
|
||||
self._display_dataset_error("Unexpected error while generating data.yaml. Check logs for details.")
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"data.yaml Generation Failed",
|
||||
@@ -755,13 +735,9 @@ class TrainingTab(QWidget):
|
||||
self.selected_dataset = info
|
||||
|
||||
self.dataset_root_label.setText(info["root"]) # type: ignore[arg-type]
|
||||
self.train_count_label.setText(
|
||||
self._format_split_info(info["splits"].get("train"))
|
||||
)
|
||||
self.train_count_label.setText(self._format_split_info(info["splits"].get("train")))
|
||||
self.val_count_label.setText(self._format_split_info(info["splits"].get("val")))
|
||||
self.test_count_label.setText(
|
||||
self._format_split_info(info["splits"].get("test"))
|
||||
)
|
||||
self.test_count_label.setText(self._format_split_info(info["splits"].get("test")))
|
||||
self.num_classes_label.setText(str(info["num_classes"]))
|
||||
class_names = ", ".join(info["class_names"]) or "–"
|
||||
self.class_names_label.setText(class_names)
|
||||
@@ -815,18 +791,12 @@ class TrainingTab(QWidget):
|
||||
if split_path.exists():
|
||||
split_info["count"] = self._count_images(split_path)
|
||||
if split_info["count"] == 0:
|
||||
warnings.append(
|
||||
f"No images found for {split_name} split at {split_path}"
|
||||
)
|
||||
warnings.append(f"No images found for {split_name} split at {split_path}")
|
||||
else:
|
||||
warnings.append(
|
||||
f"{split_name.capitalize()} path does not exist: {split_path}"
|
||||
)
|
||||
warnings.append(f"{split_name.capitalize()} path does not exist: {split_path}")
|
||||
else:
|
||||
if split_name in ("train", "val"):
|
||||
warnings.append(
|
||||
f"{split_name.capitalize()} split missing in data.yaml"
|
||||
)
|
||||
warnings.append(f"{split_name.capitalize()} split missing in data.yaml")
|
||||
splits[split_name] = split_info
|
||||
|
||||
names_list = self._normalize_class_names(data.get("names"))
|
||||
@@ -844,9 +814,7 @@ class TrainingTab(QWidget):
|
||||
if not names_list and nc_value:
|
||||
names_list = [f"class_{idx}" for idx in range(int(nc_value))]
|
||||
elif nc_value and len(names_list) not in (0, int(nc_value)):
|
||||
warnings.append(
|
||||
f"Number of class names ({len(names_list)}) does not match nc={nc_value}"
|
||||
)
|
||||
warnings.append(f"Number of class names ({len(names_list)}) does not match nc={nc_value}")
|
||||
|
||||
dataset_name = data.get("name") or base_path.name
|
||||
|
||||
@@ -898,16 +866,12 @@ class TrainingTab(QWidget):
|
||||
|
||||
class_index_map = self._build_class_index_map(dataset_info)
|
||||
if not class_index_map:
|
||||
self._append_training_log(
|
||||
"Skipping label export: dataset classes do not match database entries."
|
||||
)
|
||||
self._append_training_log("Skipping label export: dataset classes do not match database entries.")
|
||||
return
|
||||
|
||||
dataset_root_str = dataset_info.get("root")
|
||||
dataset_yaml_path = dataset_info.get("yaml_path")
|
||||
dataset_yaml = (
|
||||
Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
|
||||
)
|
||||
dataset_yaml = Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
|
||||
dataset_root: Optional[Path]
|
||||
if dataset_root_str:
|
||||
dataset_root = Path(dataset_root_str).resolve()
|
||||
@@ -941,7 +905,9 @@ class TrainingTab(QWidget):
|
||||
if stats["registered_images"]:
|
||||
message += f" {stats['registered_images']} image(s) had database-backed annotations."
|
||||
if stats["missing_records"]:
|
||||
message += f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
|
||||
message += (
|
||||
f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
|
||||
)
|
||||
split_messages.append(message)
|
||||
|
||||
for msg in split_messages:
|
||||
@@ -973,9 +939,7 @@ class TrainingTab(QWidget):
|
||||
continue
|
||||
|
||||
processed_images += 1
|
||||
label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(
|
||||
".txt"
|
||||
)
|
||||
label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(".txt")
|
||||
label_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
found, annotation_entries = self._fetch_annotations_for_image(
|
||||
@@ -991,25 +955,23 @@ class TrainingTab(QWidget):
|
||||
for entry in annotation_entries:
|
||||
polygon = entry.get("polygon") or []
|
||||
if polygon:
|
||||
print(image_file, polygon[:4], polygon[-2:], entry.get("bbox"))
|
||||
# coords = " ".join(f"{value:.6f}" for value in entry.get("bbox"))
|
||||
# coords += " "
|
||||
coords = " ".join(f"{value:.6f}" for value in polygon)
|
||||
handle.write(f"{entry['class_idx']} {coords}\n")
|
||||
annotations_written += 1
|
||||
elif entry.get("bbox"):
|
||||
x_center, y_center, width, height = entry["bbox"]
|
||||
handle.write(
|
||||
f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
|
||||
)
|
||||
handle.write(f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")
|
||||
annotations_written += 1
|
||||
|
||||
total_annotations += annotations_written
|
||||
|
||||
cache_reset_root = labels_dir.parent
|
||||
self._invalidate_split_cache(cache_reset_root)
|
||||
|
||||
if processed_images == 0:
|
||||
self._append_training_log(
|
||||
f"[{split_name}] No images found to export labels for."
|
||||
)
|
||||
self._append_training_log(f"[{split_name}] No images found to export labels for.")
|
||||
return None
|
||||
|
||||
return {
|
||||
@@ -1135,6 +1097,10 @@ class TrainingTab(QWidget):
|
||||
xs.append(x_val)
|
||||
ys.append(y_val)
|
||||
|
||||
if any(np.abs(np.array(coords[:2]) - np.array(coords[-2:])) < 1e-5):
|
||||
print("Closing polygon")
|
||||
coords.extend(coords[:2])
|
||||
|
||||
if len(coords) < 6:
|
||||
continue
|
||||
|
||||
@@ -1147,6 +1113,11 @@ class TrainingTab(QWidget):
|
||||
+ abs((min(ys) if ys else 0.0) - y_min)
|
||||
+ abs((max(ys) if ys else 0.0) - y_max)
|
||||
)
|
||||
width = max(0.0, x_max - x_min)
|
||||
height = max(0.0, y_max - y_min)
|
||||
x_center = x_min + width / 2.0
|
||||
y_center = y_min + height / 2.0
|
||||
score = (x_center, y_center, width, height)
|
||||
|
||||
candidates.append((score, coords))
|
||||
|
||||
@@ -1164,13 +1135,10 @@ class TrainingTab(QWidget):
|
||||
return 1.0
|
||||
return value
|
||||
|
||||
def _prepare_dataset_for_training(
|
||||
self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
|
||||
) -> Path:
|
||||
def _prepare_dataset_for_training(self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None) -> Path:
|
||||
dataset_info = dataset_info or (
|
||||
self.selected_dataset
|
||||
if self.selected_dataset
|
||||
and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
|
||||
if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
|
||||
else self._parse_dataset_yaml(dataset_yaml)
|
||||
)
|
||||
|
||||
@@ -1189,14 +1157,10 @@ class TrainingTab(QWidget):
|
||||
cache_root = self._get_rgb_cache_root(dataset_yaml)
|
||||
rgb_yaml = cache_root / "data.yaml"
|
||||
if rgb_yaml.exists():
|
||||
self._append_training_log(
|
||||
f"Detected grayscale dataset; reusing RGB cache at {cache_root}"
|
||||
)
|
||||
self._append_training_log(f"Detected grayscale dataset; reusing RGB cache at {cache_root}")
|
||||
return rgb_yaml
|
||||
|
||||
self._append_training_log(
|
||||
f"Detected grayscale dataset; creating RGB cache at {cache_root}"
|
||||
)
|
||||
self._append_training_log(f"Detected grayscale dataset; creating RGB cache at {cache_root}")
|
||||
self._build_rgb_dataset(cache_root, dataset_info)
|
||||
return rgb_yaml
|
||||
|
||||
@@ -1463,15 +1427,12 @@ class TrainingTab(QWidget):
|
||||
|
||||
dataset_path = Path(dataset_yaml).expanduser()
|
||||
if not dataset_path.exists():
|
||||
QMessageBox.warning(
|
||||
self, "Invalid Dataset", "Selected data.yaml file does not exist."
|
||||
)
|
||||
QMessageBox.warning(self, "Invalid Dataset", "Selected data.yaml file does not exist.")
|
||||
return
|
||||
|
||||
dataset_info = (
|
||||
self.selected_dataset
|
||||
if self.selected_dataset
|
||||
and self.selected_dataset.get("yaml_path") == str(dataset_path)
|
||||
if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_path)
|
||||
else self._parse_dataset_yaml(dataset_path)
|
||||
)
|
||||
|
||||
@@ -1480,16 +1441,12 @@ class TrainingTab(QWidget):
|
||||
|
||||
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
|
||||
if dataset_to_use != dataset_path:
|
||||
self._append_training_log(
|
||||
f"Using RGB-converted dataset at {dataset_to_use.parent}"
|
||||
)
|
||||
self._append_training_log(f"Using RGB-converted dataset at {dataset_to_use.parent}")
|
||||
|
||||
params = self._collect_training_params()
|
||||
stage_plan = self._compose_stage_plan(params)
|
||||
params["stage_plan"] = stage_plan
|
||||
total_planned_epochs = (
|
||||
self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
|
||||
)
|
||||
total_planned_epochs = self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
|
||||
params["total_planned_epochs"] = total_planned_epochs
|
||||
self._active_training_params = params
|
||||
self._training_cancelled = False
|
||||
@@ -1498,9 +1455,7 @@ class TrainingTab(QWidget):
|
||||
self._append_training_log("Two-stage fine-tuning schedule:")
|
||||
self._log_stage_plan(stage_plan)
|
||||
|
||||
self._append_training_log(
|
||||
f"Starting training run '{params['run_name']}' using {params['base_model']}"
|
||||
)
|
||||
self._append_training_log(f"Starting training run '{params['run_name']}' using {params['base_model']}")
|
||||
|
||||
self.training_progress_bar.setVisible(True)
|
||||
self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
|
||||
@@ -1528,9 +1483,7 @@ class TrainingTab(QWidget):
|
||||
def _stop_training(self):
|
||||
if self.training_worker and self.training_worker.isRunning():
|
||||
self._training_cancelled = True
|
||||
self._append_training_log(
|
||||
"Stop requested. Waiting for the current epoch to finish..."
|
||||
)
|
||||
self._append_training_log("Stop requested. Waiting for the current epoch to finish...")
|
||||
self.training_worker.stop()
|
||||
self.stop_training_button.setEnabled(False)
|
||||
|
||||
@@ -1566,9 +1519,7 @@ class TrainingTab(QWidget):
|
||||
|
||||
if worker.isRunning():
|
||||
if not worker.wait(wait_timeout_ms):
|
||||
logger.warning(
|
||||
"Training worker did not finish within %sms", wait_timeout_ms
|
||||
)
|
||||
logger.warning("Training worker did not finish within %sms", wait_timeout_ms)
|
||||
|
||||
worker.deleteLater()
|
||||
|
||||
@@ -1585,16 +1536,12 @@ class TrainingTab(QWidget):
|
||||
self._set_training_state(False)
|
||||
self.training_progress_bar.setVisible(False)
|
||||
|
||||
def _on_training_progress(
|
||||
self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]
|
||||
):
|
||||
def _on_training_progress(self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]):
|
||||
self.training_progress_bar.setMaximum(total_epochs)
|
||||
self.training_progress_bar.setValue(current_epoch)
|
||||
parts = [f"Epoch {current_epoch}/{total_epochs}"]
|
||||
if metrics:
|
||||
metric_text = ", ".join(
|
||||
f"{key}: {value:.4f}" for key, value in metrics.items()
|
||||
)
|
||||
metric_text = ", ".join(f"{key}: {value:.4f}" for key, value in metrics.items())
|
||||
parts.append(metric_text)
|
||||
self._append_training_log(" | ".join(parts))
|
||||
|
||||
@@ -1621,9 +1568,7 @@ class TrainingTab(QWidget):
|
||||
f"Model trained but not registered: {exc}",
|
||||
)
|
||||
else:
|
||||
QMessageBox.information(
|
||||
self, "Training Complete", "Training finished successfully."
|
||||
)
|
||||
QMessageBox.information(self, "Training Complete", "Training finished successfully.")
|
||||
|
||||
def _on_training_error(self, message: str):
|
||||
self._cleanup_training_worker()
|
||||
@@ -1669,9 +1614,7 @@ class TrainingTab(QWidget):
|
||||
metrics=results.get("metrics"),
|
||||
)
|
||||
|
||||
self._append_training_log(
|
||||
f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}"
|
||||
)
|
||||
self._append_training_log(f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}")
|
||||
self._active_training_params = None
|
||||
|
||||
def _set_training_state(self, is_training: bool):
|
||||
@@ -1714,9 +1657,7 @@ class TrainingTab(QWidget):
|
||||
|
||||
def _browse_save_dir(self):
|
||||
start_path = self.save_dir_edit.text().strip() or "data/models"
|
||||
directory = QFileDialog.getExistingDirectory(
|
||||
self, "Select Save Directory", start_path
|
||||
)
|
||||
directory = QFileDialog.getExistingDirectory(self, "Select Save Directory", start_path)
|
||||
if directory:
|
||||
self.save_dir_edit.setText(directory)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user