From 2c494dac492e51474ea7330da2a81b54bca8fc96 Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Fri, 16 Jan 2026 11:15:12 +0200 Subject: [PATCH] Adding export for labels in results --- src/gui/tabs/results_tab.py | 238 +++++++++++++++++++++++++++++++++++- 1 file changed, 237 insertions(+), 1 deletion(-) diff --git a/src/gui/tabs/results_tab.py b/src/gui/tabs/results_tab.py index 97edfec..8ac28d5 100644 --- a/src/gui/tabs/results_tab.py +++ b/src/gui/tabs/results_tab.py @@ -3,7 +3,7 @@ Results tab for browsing stored detections and visualizing overlays. """ from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple from PySide6.QtWidgets import ( QWidget, @@ -65,6 +65,15 @@ class ResultsTab(QWidget): self.refresh_btn = QPushButton("Refresh") self.refresh_btn.clicked.connect(self.refresh) controls_layout.addWidget(self.refresh_btn) + + self.export_labels_btn = QPushButton("Export Labels") + self.export_labels_btn.setToolTip( + "Export YOLO .txt labels for the selected image/model run.\n" + "Output path is inferred from the image path (images/ -> labels/)." + ) + self.export_labels_btn.clicked.connect(self._export_labels_for_current_selection) + controls_layout.addWidget(self.export_labels_btn) + controls_layout.addStretch() left_layout.addLayout(controls_layout) @@ -139,6 +148,8 @@ class ResultsTab(QWidget): self.current_detections = [] self.preview_canvas.clear() self.summary_label.setText("Select a detection result to preview.") + if hasattr(self, "export_labels_btn"): + self.export_labels_btn.setEnabled(False) def _load_detection_summary(self): """Load latest detection summaries grouped by image + model.""" @@ -258,6 +269,231 @@ class ResultsTab(QWidget): self._load_detections_for_selection(entry) self._apply_detection_overlays() self._update_summary_label(entry) + if hasattr(self, "export_labels_btn"): + self.export_labels_btn.setEnabled(True) + + def _export_labels_for_current_selection(self): + """Export YOLO label file(s) for the currently selected image/model.""" + if not self.current_selection: + QMessageBox.information(self, "Export Labels", "Select a detection result first.") + return + + entry = self.current_selection + + image_path_str = self._resolve_image_path(entry) + if not image_path_str: + QMessageBox.warning( + self, + "Export Labels", + "Unable to locate the image file for this detection; cannot infer labels path.", + ) + return + + # Ensure we have the detections for the selection. + if not self.current_detections: + self._load_detections_for_selection(entry) + + if not self.current_detections: + QMessageBox.information( + self, + "Export Labels", + "No detections found for this image/model selection.", + ) + return + + image_path = Path(image_path_str) + try: + label_path = self._infer_yolo_label_path(image_path) + except Exception as exc: + logger.error(f"Failed to infer label path for {image_path}: {exc}") + QMessageBox.critical( + self, + "Export Labels", + f"Failed to infer export path for labels:\n{exc}", + ) + return + + class_map = self._build_detection_class_index_map(self.current_detections) + if not class_map: + QMessageBox.warning( + self, + "Export Labels", + "Unable to build class->index mapping (missing class names).", + ) + return + + lines_written = 0 + skipped = 0 + label_path.parent.mkdir(parents=True, exist_ok=True) + try: + with open(label_path, "w", encoding="utf-8") as handle: + print("writing to", label_path) + for det in self.current_detections: + yolo_line = self._format_detection_as_yolo_line(det, class_map) + if not yolo_line: + skipped += 1 + continue + handle.write(yolo_line + "\n") + lines_written += 1 + except OSError as exc: + logger.error(f"Failed to write labels file {label_path}: {exc}") + QMessageBox.critical( + self, + "Export Labels", + f"Failed to write label file:\n{label_path}\n\n{exc}", + ) + return + + return + # Optional: write a classes.txt next to the labels root to make the mapping discoverable. + # This is not required by Ultralytics (data.yaml usually holds class names), but helps reuse. + try: + classes_txt = label_path.parent.parent / "classes.txt" + classes_txt.parent.mkdir(parents=True, exist_ok=True) + inv = {idx: name for name, idx in class_map.items()} + with open(classes_txt, "w", encoding="utf-8") as handle: + for idx in range(len(inv)): + handle.write(f"{inv[idx]}\n") + except Exception: + # Non-fatal + pass + + QMessageBox.information( + self, + "Export Labels", + f"Exported {lines_written} label line(s) to:\n{label_path}\n\nSkipped {skipped} invalid detection(s).", + ) + + def _infer_yolo_label_path(self, image_path: Path) -> Path: + """Infer a YOLO label path from an image path. + + If the image lives under an `images/` directory (anywhere in the path), we mirror the + subpath under a sibling `labels/` directory at the same level. + + Example: + /dataset/train/images/sub/img.jpg -> /dataset/train/labels/sub/img.txt + """ + + resolved = image_path.expanduser().resolve() + + # Find the nearest ancestor directory named 'images' + images_dir: Optional[Path] = None + for parent in [resolved.parent, *resolved.parents]: + if parent.name.lower() == "images": + images_dir = parent + break + + if images_dir is not None: + rel = resolved.relative_to(images_dir) + labels_dir = images_dir.parent / "labels" + return (labels_dir / rel).with_suffix(".txt") + + # Fallback: create a local sibling labels folder next to the image. + return (resolved.parent / "labels" / resolved.name).with_suffix(".txt") + + def _build_detection_class_index_map(self, detections: List[Dict]) -> Dict[str, int]: + """Build a stable class_name -> YOLO class index mapping. + + Preference order: + 1) Database object_classes table (alphabetical class_name order) + 2) Fallback to class_name values present in the detections (alphabetical) + """ + + names: List[str] = [] + try: + db_classes = self.db_manager.get_object_classes() or [] + names = [str(row.get("class_name")) for row in db_classes if row.get("class_name")] + except Exception: + names = [] + + if not names: + observed = sorted({str(det.get("class_name")) for det in detections if det.get("class_name")}) + names = list(observed) + + return {name: idx for idx, name in enumerate(names)} + + def _format_detection_as_yolo_line(self, det: Dict, class_map: Dict[str, int]) -> Optional[str]: + """Convert a detection row to a YOLO label line. + + - If segmentation_mask is present, exports segmentation polygon format: + class x1 y1 x2 y2 ... + (normalized coordinates) + - Otherwise exports bbox format: + class x_center y_center width height + (normalized coordinates) + """ + + class_name = det.get("class_name") + if not class_name or class_name not in class_map: + return None + class_idx = class_map[class_name] + + mask = det.get("segmentation_mask") + polygon = self._convert_segmentation_mask_to_polygon(mask) + if polygon: + coords = " ".join(f"{value:.6f}" for value in polygon) + return f"{class_idx} {coords}".strip() + + bbox = self._convert_bbox_to_yolo_xywh(det) + if bbox is None: + return None + x_center, y_center, width, height = bbox + return f"{class_idx} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}" + + def _convert_bbox_to_yolo_xywh(self, det: Dict) -> Optional[Tuple[float, float, float, float]]: + """Convert stored xyxy (normalized) bbox to YOLO xywh (normalized).""" + + x_min = det.get("x_min") + y_min = det.get("y_min") + x_max = det.get("x_max") + y_max = det.get("y_max") + if any(v is None for v in (x_min, y_min, x_max, y_max)): + return None + + try: + x_min_f = self._clamp01(float(x_min)) + y_min_f = self._clamp01(float(y_min)) + x_max_f = self._clamp01(float(x_max)) + y_max_f = self._clamp01(float(y_max)) + except (TypeError, ValueError): + return None + + width = max(0.0, x_max_f - x_min_f) + height = max(0.0, y_max_f - y_min_f) + if width <= 0.0 or height <= 0.0: + return None + + x_center = x_min_f + width / 2.0 + y_center = y_min_f + height / 2.0 + return x_center, y_center, width, height + + def _convert_segmentation_mask_to_polygon(self, mask_data) -> List[float]: + """Convert stored segmentation_mask [[x,y], ...] to YOLO polygon coords [x1,y1,...].""" + + if not isinstance(mask_data, list): + return [] + + coords: List[float] = [] + for point in mask_data: + if not isinstance(point, (list, tuple)) or len(point) < 2: + continue + try: + x = self._clamp01(float(point[0])) + y = self._clamp01(float(point[1])) + except (TypeError, ValueError): + continue + coords.extend([x, y]) + + # Need at least 3 points => 6 values. + return coords if len(coords) >= 6 else [] + + @staticmethod + def _clamp01(value: float) -> float: + if value < 0.0: + return 0.0 + if value > 1.0: + return 1.0 + return value def _load_detections_for_selection(self, entry: Dict): """Load detection records for the selected image/model pair."""