Adding export for labels in results
This commit is contained in:
@@ -3,7 +3,7 @@ Results tab for browsing stored detections and visualizing overlays.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from PySide6.QtWidgets import (
|
||||
QWidget,
|
||||
@@ -65,6 +65,15 @@ class ResultsTab(QWidget):
|
||||
self.refresh_btn = QPushButton("Refresh")
|
||||
self.refresh_btn.clicked.connect(self.refresh)
|
||||
controls_layout.addWidget(self.refresh_btn)
|
||||
|
||||
self.export_labels_btn = QPushButton("Export Labels")
|
||||
self.export_labels_btn.setToolTip(
|
||||
"Export YOLO .txt labels for the selected image/model run.\n"
|
||||
"Output path is inferred from the image path (images/ -> labels/)."
|
||||
)
|
||||
self.export_labels_btn.clicked.connect(self._export_labels_for_current_selection)
|
||||
controls_layout.addWidget(self.export_labels_btn)
|
||||
|
||||
controls_layout.addStretch()
|
||||
left_layout.addLayout(controls_layout)
|
||||
|
||||
@@ -139,6 +148,8 @@ class ResultsTab(QWidget):
|
||||
self.current_detections = []
|
||||
self.preview_canvas.clear()
|
||||
self.summary_label.setText("Select a detection result to preview.")
|
||||
if hasattr(self, "export_labels_btn"):
|
||||
self.export_labels_btn.setEnabled(False)
|
||||
|
||||
def _load_detection_summary(self):
|
||||
"""Load latest detection summaries grouped by image + model."""
|
||||
@@ -258,6 +269,231 @@ class ResultsTab(QWidget):
|
||||
self._load_detections_for_selection(entry)
|
||||
self._apply_detection_overlays()
|
||||
self._update_summary_label(entry)
|
||||
if hasattr(self, "export_labels_btn"):
|
||||
self.export_labels_btn.setEnabled(True)
|
||||
|
||||
def _export_labels_for_current_selection(self):
|
||||
"""Export YOLO label file(s) for the currently selected image/model."""
|
||||
if not self.current_selection:
|
||||
QMessageBox.information(self, "Export Labels", "Select a detection result first.")
|
||||
return
|
||||
|
||||
entry = self.current_selection
|
||||
|
||||
image_path_str = self._resolve_image_path(entry)
|
||||
if not image_path_str:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"Export Labels",
|
||||
"Unable to locate the image file for this detection; cannot infer labels path.",
|
||||
)
|
||||
return
|
||||
|
||||
# Ensure we have the detections for the selection.
|
||||
if not self.current_detections:
|
||||
self._load_detections_for_selection(entry)
|
||||
|
||||
if not self.current_detections:
|
||||
QMessageBox.information(
|
||||
self,
|
||||
"Export Labels",
|
||||
"No detections found for this image/model selection.",
|
||||
)
|
||||
return
|
||||
|
||||
image_path = Path(image_path_str)
|
||||
try:
|
||||
label_path = self._infer_yolo_label_path(image_path)
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to infer label path for {image_path}: {exc}")
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"Export Labels",
|
||||
f"Failed to infer export path for labels:\n{exc}",
|
||||
)
|
||||
return
|
||||
|
||||
class_map = self._build_detection_class_index_map(self.current_detections)
|
||||
if not class_map:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"Export Labels",
|
||||
"Unable to build class->index mapping (missing class names).",
|
||||
)
|
||||
return
|
||||
|
||||
lines_written = 0
|
||||
skipped = 0
|
||||
label_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
with open(label_path, "w", encoding="utf-8") as handle:
|
||||
print("writing to", label_path)
|
||||
for det in self.current_detections:
|
||||
yolo_line = self._format_detection_as_yolo_line(det, class_map)
|
||||
if not yolo_line:
|
||||
skipped += 1
|
||||
continue
|
||||
handle.write(yolo_line + "\n")
|
||||
lines_written += 1
|
||||
except OSError as exc:
|
||||
logger.error(f"Failed to write labels file {label_path}: {exc}")
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"Export Labels",
|
||||
f"Failed to write label file:\n{label_path}\n\n{exc}",
|
||||
)
|
||||
return
|
||||
|
||||
return
|
||||
# Optional: write a classes.txt next to the labels root to make the mapping discoverable.
|
||||
# This is not required by Ultralytics (data.yaml usually holds class names), but helps reuse.
|
||||
try:
|
||||
classes_txt = label_path.parent.parent / "classes.txt"
|
||||
classes_txt.parent.mkdir(parents=True, exist_ok=True)
|
||||
inv = {idx: name for name, idx in class_map.items()}
|
||||
with open(classes_txt, "w", encoding="utf-8") as handle:
|
||||
for idx in range(len(inv)):
|
||||
handle.write(f"{inv[idx]}\n")
|
||||
except Exception:
|
||||
# Non-fatal
|
||||
pass
|
||||
|
||||
QMessageBox.information(
|
||||
self,
|
||||
"Export Labels",
|
||||
f"Exported {lines_written} label line(s) to:\n{label_path}\n\nSkipped {skipped} invalid detection(s).",
|
||||
)
|
||||
|
||||
def _infer_yolo_label_path(self, image_path: Path) -> Path:
|
||||
"""Infer a YOLO label path from an image path.
|
||||
|
||||
If the image lives under an `images/` directory (anywhere in the path), we mirror the
|
||||
subpath under a sibling `labels/` directory at the same level.
|
||||
|
||||
Example:
|
||||
/dataset/train/images/sub/img.jpg -> /dataset/train/labels/sub/img.txt
|
||||
"""
|
||||
|
||||
resolved = image_path.expanduser().resolve()
|
||||
|
||||
# Find the nearest ancestor directory named 'images'
|
||||
images_dir: Optional[Path] = None
|
||||
for parent in [resolved.parent, *resolved.parents]:
|
||||
if parent.name.lower() == "images":
|
||||
images_dir = parent
|
||||
break
|
||||
|
||||
if images_dir is not None:
|
||||
rel = resolved.relative_to(images_dir)
|
||||
labels_dir = images_dir.parent / "labels"
|
||||
return (labels_dir / rel).with_suffix(".txt")
|
||||
|
||||
# Fallback: create a local sibling labels folder next to the image.
|
||||
return (resolved.parent / "labels" / resolved.name).with_suffix(".txt")
|
||||
|
||||
def _build_detection_class_index_map(self, detections: List[Dict]) -> Dict[str, int]:
|
||||
"""Build a stable class_name -> YOLO class index mapping.
|
||||
|
||||
Preference order:
|
||||
1) Database object_classes table (alphabetical class_name order)
|
||||
2) Fallback to class_name values present in the detections (alphabetical)
|
||||
"""
|
||||
|
||||
names: List[str] = []
|
||||
try:
|
||||
db_classes = self.db_manager.get_object_classes() or []
|
||||
names = [str(row.get("class_name")) for row in db_classes if row.get("class_name")]
|
||||
except Exception:
|
||||
names = []
|
||||
|
||||
if not names:
|
||||
observed = sorted({str(det.get("class_name")) for det in detections if det.get("class_name")})
|
||||
names = list(observed)
|
||||
|
||||
return {name: idx for idx, name in enumerate(names)}
|
||||
|
||||
def _format_detection_as_yolo_line(self, det: Dict, class_map: Dict[str, int]) -> Optional[str]:
|
||||
"""Convert a detection row to a YOLO label line.
|
||||
|
||||
- If segmentation_mask is present, exports segmentation polygon format:
|
||||
class x1 y1 x2 y2 ...
|
||||
(normalized coordinates)
|
||||
- Otherwise exports bbox format:
|
||||
class x_center y_center width height
|
||||
(normalized coordinates)
|
||||
"""
|
||||
|
||||
class_name = det.get("class_name")
|
||||
if not class_name or class_name not in class_map:
|
||||
return None
|
||||
class_idx = class_map[class_name]
|
||||
|
||||
mask = det.get("segmentation_mask")
|
||||
polygon = self._convert_segmentation_mask_to_polygon(mask)
|
||||
if polygon:
|
||||
coords = " ".join(f"{value:.6f}" for value in polygon)
|
||||
return f"{class_idx} {coords}".strip()
|
||||
|
||||
bbox = self._convert_bbox_to_yolo_xywh(det)
|
||||
if bbox is None:
|
||||
return None
|
||||
x_center, y_center, width, height = bbox
|
||||
return f"{class_idx} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
|
||||
|
||||
def _convert_bbox_to_yolo_xywh(self, det: Dict) -> Optional[Tuple[float, float, float, float]]:
|
||||
"""Convert stored xyxy (normalized) bbox to YOLO xywh (normalized)."""
|
||||
|
||||
x_min = det.get("x_min")
|
||||
y_min = det.get("y_min")
|
||||
x_max = det.get("x_max")
|
||||
y_max = det.get("y_max")
|
||||
if any(v is None for v in (x_min, y_min, x_max, y_max)):
|
||||
return None
|
||||
|
||||
try:
|
||||
x_min_f = self._clamp01(float(x_min))
|
||||
y_min_f = self._clamp01(float(y_min))
|
||||
x_max_f = self._clamp01(float(x_max))
|
||||
y_max_f = self._clamp01(float(y_max))
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
width = max(0.0, x_max_f - x_min_f)
|
||||
height = max(0.0, y_max_f - y_min_f)
|
||||
if width <= 0.0 or height <= 0.0:
|
||||
return None
|
||||
|
||||
x_center = x_min_f + width / 2.0
|
||||
y_center = y_min_f + height / 2.0
|
||||
return x_center, y_center, width, height
|
||||
|
||||
def _convert_segmentation_mask_to_polygon(self, mask_data) -> List[float]:
|
||||
"""Convert stored segmentation_mask [[x,y], ...] to YOLO polygon coords [x1,y1,...]."""
|
||||
|
||||
if not isinstance(mask_data, list):
|
||||
return []
|
||||
|
||||
coords: List[float] = []
|
||||
for point in mask_data:
|
||||
if not isinstance(point, (list, tuple)) or len(point) < 2:
|
||||
continue
|
||||
try:
|
||||
x = self._clamp01(float(point[0]))
|
||||
y = self._clamp01(float(point[1]))
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
coords.extend([x, y])
|
||||
|
||||
# Need at least 3 points => 6 values.
|
||||
return coords if len(coords) >= 6 else []
|
||||
|
||||
@staticmethod
|
||||
def _clamp01(value: float) -> float:
|
||||
if value < 0.0:
|
||||
return 0.0
|
||||
if value > 1.0:
|
||||
return 1.0
|
||||
return value
|
||||
|
||||
def _load_detections_for_selection(self, entry: Dict):
|
||||
"""Load detection records for the selected image/model pair."""
|
||||
|
||||
Reference in New Issue
Block a user