Compare commits
3 Commits
506c74e53a
...
9c8931e6f3
| Author | SHA1 | Date | |
|---|---|---|---|
| 9c8931e6f3 | |||
| 20578c1fdf | |||
| 2c494dac49 |
@@ -462,6 +462,22 @@ class DatabaseManager:
|
|||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
def delete_all_detections(self) -> int:
|
||||||
|
"""Delete all detections from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of rows deleted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("DELETE FROM detections")
|
||||||
|
conn.commit()
|
||||||
|
return cursor.rowcount
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
# ==================== Statistics Operations ====================
|
# ==================== Statistics Operations ====================
|
||||||
|
|
||||||
def get_detection_statistics(
|
def get_detection_statistics(
|
||||||
|
|||||||
@@ -55,10 +55,7 @@ CREATE TABLE IF NOT EXISTS object_classes (
|
|||||||
|
|
||||||
-- Insert default object classes
|
-- Insert default object classes
|
||||||
INSERT OR IGNORE INTO object_classes (class_name, color, description) VALUES
|
INSERT OR IGNORE INTO object_classes (class_name, color, description) VALUES
|
||||||
('cell', '#FF0000', 'Cell object'),
|
('terminal', '#FFFF00', 'Axion terminal');
|
||||||
('nucleus', '#00FF00', 'Cell nucleus'),
|
|
||||||
('mitochondria', '#0000FF', 'Mitochondria'),
|
|
||||||
('vesicle', '#FFFF00', 'Vesicle');
|
|
||||||
|
|
||||||
-- Annotations table: stores manual annotations
|
-- Annotations table: stores manual annotations
|
||||||
CREATE TABLE IF NOT EXISTS annotations (
|
CREATE TABLE IF NOT EXISTS annotations (
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ Results tab for browsing stored detections and visualizing overlays.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from PySide6.QtWidgets import (
|
from PySide6.QtWidgets import (
|
||||||
QWidget,
|
QWidget,
|
||||||
@@ -65,6 +65,22 @@ class ResultsTab(QWidget):
|
|||||||
self.refresh_btn = QPushButton("Refresh")
|
self.refresh_btn = QPushButton("Refresh")
|
||||||
self.refresh_btn.clicked.connect(self.refresh)
|
self.refresh_btn.clicked.connect(self.refresh)
|
||||||
controls_layout.addWidget(self.refresh_btn)
|
controls_layout.addWidget(self.refresh_btn)
|
||||||
|
|
||||||
|
self.delete_all_btn = QPushButton("Delete All Detections")
|
||||||
|
self.delete_all_btn.setToolTip(
|
||||||
|
"Permanently delete ALL detections from the database.\n" "This cannot be undone."
|
||||||
|
)
|
||||||
|
self.delete_all_btn.clicked.connect(self._delete_all_detections)
|
||||||
|
controls_layout.addWidget(self.delete_all_btn)
|
||||||
|
|
||||||
|
self.export_labels_btn = QPushButton("Export Labels")
|
||||||
|
self.export_labels_btn.setToolTip(
|
||||||
|
"Export YOLO .txt labels for the selected image/model run.\n"
|
||||||
|
"Output path is inferred from the image path (images/ -> labels/)."
|
||||||
|
)
|
||||||
|
self.export_labels_btn.clicked.connect(self._export_labels_for_current_selection)
|
||||||
|
controls_layout.addWidget(self.export_labels_btn)
|
||||||
|
|
||||||
controls_layout.addStretch()
|
controls_layout.addStretch()
|
||||||
left_layout.addLayout(controls_layout)
|
left_layout.addLayout(controls_layout)
|
||||||
|
|
||||||
@@ -130,6 +146,41 @@ class ResultsTab(QWidget):
|
|||||||
layout.addWidget(splitter)
|
layout.addWidget(splitter)
|
||||||
self.setLayout(layout)
|
self.setLayout(layout)
|
||||||
|
|
||||||
|
def _delete_all_detections(self):
|
||||||
|
"""Delete all detections from the database after user confirmation."""
|
||||||
|
confirm = QMessageBox.warning(
|
||||||
|
self,
|
||||||
|
"Delete All Detections",
|
||||||
|
"This will permanently delete ALL detections from the database.\n\n"
|
||||||
|
"This action cannot be undone.\n\n"
|
||||||
|
"Do you want to continue?",
|
||||||
|
QMessageBox.Yes | QMessageBox.No,
|
||||||
|
QMessageBox.No,
|
||||||
|
)
|
||||||
|
|
||||||
|
if confirm != QMessageBox.Yes:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
deleted = self.db_manager.delete_all_detections()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Failed to delete all detections: {exc}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Error",
|
||||||
|
f"Failed to delete detections:\n{exc}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
QMessageBox.information(
|
||||||
|
self,
|
||||||
|
"Delete All Detections",
|
||||||
|
f"Deleted {deleted} detection(s) from the database.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reset UI state.
|
||||||
|
self.refresh()
|
||||||
|
|
||||||
def refresh(self):
|
def refresh(self):
|
||||||
"""Refresh the detection list and preview."""
|
"""Refresh the detection list and preview."""
|
||||||
self._load_detection_summary()
|
self._load_detection_summary()
|
||||||
@@ -139,6 +190,8 @@ class ResultsTab(QWidget):
|
|||||||
self.current_detections = []
|
self.current_detections = []
|
||||||
self.preview_canvas.clear()
|
self.preview_canvas.clear()
|
||||||
self.summary_label.setText("Select a detection result to preview.")
|
self.summary_label.setText("Select a detection result to preview.")
|
||||||
|
if hasattr(self, "export_labels_btn"):
|
||||||
|
self.export_labels_btn.setEnabled(False)
|
||||||
|
|
||||||
def _load_detection_summary(self):
|
def _load_detection_summary(self):
|
||||||
"""Load latest detection summaries grouped by image + model."""
|
"""Load latest detection summaries grouped by image + model."""
|
||||||
@@ -258,6 +311,231 @@ class ResultsTab(QWidget):
|
|||||||
self._load_detections_for_selection(entry)
|
self._load_detections_for_selection(entry)
|
||||||
self._apply_detection_overlays()
|
self._apply_detection_overlays()
|
||||||
self._update_summary_label(entry)
|
self._update_summary_label(entry)
|
||||||
|
if hasattr(self, "export_labels_btn"):
|
||||||
|
self.export_labels_btn.setEnabled(True)
|
||||||
|
|
||||||
|
def _export_labels_for_current_selection(self):
|
||||||
|
"""Export YOLO label file(s) for the currently selected image/model."""
|
||||||
|
if not self.current_selection:
|
||||||
|
QMessageBox.information(self, "Export Labels", "Select a detection result first.")
|
||||||
|
return
|
||||||
|
|
||||||
|
entry = self.current_selection
|
||||||
|
|
||||||
|
image_path_str = self._resolve_image_path(entry)
|
||||||
|
if not image_path_str:
|
||||||
|
QMessageBox.warning(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
"Unable to locate the image file for this detection; cannot infer labels path.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Ensure we have the detections for the selection.
|
||||||
|
if not self.current_detections:
|
||||||
|
self._load_detections_for_selection(entry)
|
||||||
|
|
||||||
|
if not self.current_detections:
|
||||||
|
QMessageBox.information(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
"No detections found for this image/model selection.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
image_path = Path(image_path_str)
|
||||||
|
try:
|
||||||
|
label_path = self._infer_yolo_label_path(image_path)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Failed to infer label path for {image_path}: {exc}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
f"Failed to infer export path for labels:\n{exc}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
class_map = self._build_detection_class_index_map(self.current_detections)
|
||||||
|
if not class_map:
|
||||||
|
QMessageBox.warning(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
"Unable to build class->index mapping (missing class names).",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
lines_written = 0
|
||||||
|
skipped = 0
|
||||||
|
label_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
try:
|
||||||
|
with open(label_path, "w", encoding="utf-8") as handle:
|
||||||
|
print("writing to", label_path)
|
||||||
|
for det in self.current_detections:
|
||||||
|
yolo_line = self._format_detection_as_yolo_line(det, class_map)
|
||||||
|
if not yolo_line:
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
handle.write(yolo_line + "\n")
|
||||||
|
lines_written += 1
|
||||||
|
except OSError as exc:
|
||||||
|
logger.error(f"Failed to write labels file {label_path}: {exc}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
f"Failed to write label file:\n{label_path}\n\n{exc}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
return
|
||||||
|
# Optional: write a classes.txt next to the labels root to make the mapping discoverable.
|
||||||
|
# This is not required by Ultralytics (data.yaml usually holds class names), but helps reuse.
|
||||||
|
try:
|
||||||
|
classes_txt = label_path.parent.parent / "classes.txt"
|
||||||
|
classes_txt.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
inv = {idx: name for name, idx in class_map.items()}
|
||||||
|
with open(classes_txt, "w", encoding="utf-8") as handle:
|
||||||
|
for idx in range(len(inv)):
|
||||||
|
handle.write(f"{inv[idx]}\n")
|
||||||
|
except Exception:
|
||||||
|
# Non-fatal
|
||||||
|
pass
|
||||||
|
|
||||||
|
QMessageBox.information(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
f"Exported {lines_written} label line(s) to:\n{label_path}\n\nSkipped {skipped} invalid detection(s).",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _infer_yolo_label_path(self, image_path: Path) -> Path:
|
||||||
|
"""Infer a YOLO label path from an image path.
|
||||||
|
|
||||||
|
If the image lives under an `images/` directory (anywhere in the path), we mirror the
|
||||||
|
subpath under a sibling `labels/` directory at the same level.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
/dataset/train/images/sub/img.jpg -> /dataset/train/labels/sub/img.txt
|
||||||
|
"""
|
||||||
|
|
||||||
|
resolved = image_path.expanduser().resolve()
|
||||||
|
|
||||||
|
# Find the nearest ancestor directory named 'images'
|
||||||
|
images_dir: Optional[Path] = None
|
||||||
|
for parent in [resolved.parent, *resolved.parents]:
|
||||||
|
if parent.name.lower() == "images":
|
||||||
|
images_dir = parent
|
||||||
|
break
|
||||||
|
|
||||||
|
if images_dir is not None:
|
||||||
|
rel = resolved.relative_to(images_dir)
|
||||||
|
labels_dir = images_dir.parent / "labels"
|
||||||
|
return (labels_dir / rel).with_suffix(".txt")
|
||||||
|
|
||||||
|
# Fallback: create a local sibling labels folder next to the image.
|
||||||
|
return (resolved.parent / "labels" / resolved.name).with_suffix(".txt")
|
||||||
|
|
||||||
|
def _build_detection_class_index_map(self, detections: List[Dict]) -> Dict[str, int]:
|
||||||
|
"""Build a stable class_name -> YOLO class index mapping.
|
||||||
|
|
||||||
|
Preference order:
|
||||||
|
1) Database object_classes table (alphabetical class_name order)
|
||||||
|
2) Fallback to class_name values present in the detections (alphabetical)
|
||||||
|
"""
|
||||||
|
|
||||||
|
names: List[str] = []
|
||||||
|
try:
|
||||||
|
db_classes = self.db_manager.get_object_classes() or []
|
||||||
|
names = [str(row.get("class_name")) for row in db_classes if row.get("class_name")]
|
||||||
|
except Exception:
|
||||||
|
names = []
|
||||||
|
|
||||||
|
if not names:
|
||||||
|
observed = sorted({str(det.get("class_name")) for det in detections if det.get("class_name")})
|
||||||
|
names = list(observed)
|
||||||
|
|
||||||
|
return {name: idx for idx, name in enumerate(names)}
|
||||||
|
|
||||||
|
def _format_detection_as_yolo_line(self, det: Dict, class_map: Dict[str, int]) -> Optional[str]:
|
||||||
|
"""Convert a detection row to a YOLO label line.
|
||||||
|
|
||||||
|
- If segmentation_mask is present, exports segmentation polygon format:
|
||||||
|
class x1 y1 x2 y2 ...
|
||||||
|
(normalized coordinates)
|
||||||
|
- Otherwise exports bbox format:
|
||||||
|
class x_center y_center width height
|
||||||
|
(normalized coordinates)
|
||||||
|
"""
|
||||||
|
|
||||||
|
class_name = det.get("class_name")
|
||||||
|
if not class_name or class_name not in class_map:
|
||||||
|
return None
|
||||||
|
class_idx = class_map[class_name]
|
||||||
|
|
||||||
|
mask = det.get("segmentation_mask")
|
||||||
|
polygon = self._convert_segmentation_mask_to_polygon(mask)
|
||||||
|
if polygon:
|
||||||
|
coords = " ".join(f"{value:.6f}" for value in polygon)
|
||||||
|
return f"{class_idx} {coords}".strip()
|
||||||
|
|
||||||
|
bbox = self._convert_bbox_to_yolo_xywh(det)
|
||||||
|
if bbox is None:
|
||||||
|
return None
|
||||||
|
x_center, y_center, width, height = bbox
|
||||||
|
return f"{class_idx} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
|
||||||
|
|
||||||
|
def _convert_bbox_to_yolo_xywh(self, det: Dict) -> Optional[Tuple[float, float, float, float]]:
|
||||||
|
"""Convert stored xyxy (normalized) bbox to YOLO xywh (normalized)."""
|
||||||
|
|
||||||
|
x_min = det.get("x_min")
|
||||||
|
y_min = det.get("y_min")
|
||||||
|
x_max = det.get("x_max")
|
||||||
|
y_max = det.get("y_max")
|
||||||
|
if any(v is None for v in (x_min, y_min, x_max, y_max)):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
x_min_f = self._clamp01(float(x_min))
|
||||||
|
y_min_f = self._clamp01(float(y_min))
|
||||||
|
x_max_f = self._clamp01(float(x_max))
|
||||||
|
y_max_f = self._clamp01(float(y_max))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
width = max(0.0, x_max_f - x_min_f)
|
||||||
|
height = max(0.0, y_max_f - y_min_f)
|
||||||
|
if width <= 0.0 or height <= 0.0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
x_center = x_min_f + width / 2.0
|
||||||
|
y_center = y_min_f + height / 2.0
|
||||||
|
return x_center, y_center, width, height
|
||||||
|
|
||||||
|
def _convert_segmentation_mask_to_polygon(self, mask_data) -> List[float]:
|
||||||
|
"""Convert stored segmentation_mask [[x,y], ...] to YOLO polygon coords [x1,y1,...]."""
|
||||||
|
|
||||||
|
if not isinstance(mask_data, list):
|
||||||
|
return []
|
||||||
|
|
||||||
|
coords: List[float] = []
|
||||||
|
for point in mask_data:
|
||||||
|
if not isinstance(point, (list, tuple)) or len(point) < 2:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
x = self._clamp01(float(point[0]))
|
||||||
|
y = self._clamp01(float(point[1]))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
coords.extend([x, y])
|
||||||
|
|
||||||
|
# Need at least 3 points => 6 values.
|
||||||
|
return coords if len(coords) >= 6 else []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _clamp01(value: float) -> float:
|
||||||
|
if value < 0.0:
|
||||||
|
return 0.0
|
||||||
|
if value > 1.0:
|
||||||
|
return 1.0
|
||||||
|
return value
|
||||||
|
|
||||||
def _load_detections_for_selection(self, entry: Dict):
|
def _load_detections_for_selection(self, entry: Dict):
|
||||||
"""Load detection records for the selected image/model pair."""
|
"""Load detection records for the selected image/model pair."""
|
||||||
|
|||||||
@@ -2,45 +2,554 @@
|
|||||||
Validation tab for the microscopy object detection application.
|
Validation tab for the microscopy object detection application.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
from PySide6.QtCore import Qt, QSize
|
||||||
|
from PySide6.QtGui import QPainter, QPixmap
|
||||||
|
from PySide6.QtWidgets import (
|
||||||
|
QWidget,
|
||||||
|
QVBoxLayout,
|
||||||
|
QLabel,
|
||||||
|
QGroupBox,
|
||||||
|
QHBoxLayout,
|
||||||
|
QPushButton,
|
||||||
|
QComboBox,
|
||||||
|
QFormLayout,
|
||||||
|
QScrollArea,
|
||||||
|
QGridLayout,
|
||||||
|
QFrame,
|
||||||
|
QTableWidget,
|
||||||
|
QTableWidgetItem,
|
||||||
|
QHeaderView,
|
||||||
|
QSplitter,
|
||||||
|
QListWidget,
|
||||||
|
QListWidgetItem,
|
||||||
|
QAbstractItemView,
|
||||||
|
QGraphicsView,
|
||||||
|
QGraphicsScene,
|
||||||
|
QGraphicsPixmapItem,
|
||||||
|
)
|
||||||
|
|
||||||
from src.database.db_manager import DatabaseManager
|
from src.database.db_manager import DatabaseManager
|
||||||
from src.utils.config_manager import ConfigManager
|
from src.utils.config_manager import ConfigManager
|
||||||
|
from src.utils.logger import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class _PlotItem:
|
||||||
|
label: str
|
||||||
|
path: Path
|
||||||
|
|
||||||
|
|
||||||
|
class _ZoomableImageView(QGraphicsView):
|
||||||
|
"""Zoomable image viewer.
|
||||||
|
|
||||||
|
- Mouse wheel: zoom in/out
|
||||||
|
- Left mouse drag: pan (ScrollHandDrag)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, parent: Optional[QWidget] = None):
|
||||||
|
super().__init__(parent)
|
||||||
|
self._scene = QGraphicsScene(self)
|
||||||
|
self.setScene(self._scene)
|
||||||
|
self._pixmap_item = QGraphicsPixmapItem()
|
||||||
|
self._scene.addItem(self._pixmap_item)
|
||||||
|
|
||||||
|
# QGraphicsView render hints are QPainter.RenderHints.
|
||||||
|
self.setRenderHints(self.renderHints() | QPainter.RenderHint.SmoothPixmapTransform)
|
||||||
|
self.setDragMode(QGraphicsView.DragMode.ScrollHandDrag)
|
||||||
|
self.setTransformationAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
|
||||||
|
self.setResizeAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
|
||||||
|
|
||||||
|
self._has_pixmap = False
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
self._pixmap_item.setPixmap(QPixmap())
|
||||||
|
self._scene.setSceneRect(0, 0, 1, 1)
|
||||||
|
self.resetTransform()
|
||||||
|
self._has_pixmap = False
|
||||||
|
|
||||||
|
def set_pixmap(self, pixmap: QPixmap, *, fit: bool = True) -> None:
|
||||||
|
self._pixmap_item.setPixmap(pixmap)
|
||||||
|
self._scene.setSceneRect(pixmap.rect())
|
||||||
|
self._has_pixmap = not pixmap.isNull()
|
||||||
|
self.resetTransform()
|
||||||
|
if fit and self._has_pixmap:
|
||||||
|
self.fitInView(self._pixmap_item, Qt.AspectRatioMode.KeepAspectRatio)
|
||||||
|
|
||||||
|
def wheelEvent(self, event) -> None: # type: ignore[override]
|
||||||
|
if not self._has_pixmap:
|
||||||
|
return
|
||||||
|
zoom_in_factor = 1.25
|
||||||
|
zoom_out_factor = 1.0 / zoom_in_factor
|
||||||
|
factor = zoom_in_factor if event.angleDelta().y() > 0 else zoom_out_factor
|
||||||
|
self.scale(factor, factor)
|
||||||
|
|
||||||
|
|
||||||
class ValidationTab(QWidget):
|
class ValidationTab(QWidget):
|
||||||
"""Validation tab placeholder."""
|
"""Validation tab that shows stored validation metrics + plots for a selected model."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
|
||||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
|
||||||
):
|
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
self.db_manager = db_manager
|
self.db_manager = db_manager
|
||||||
self.config_manager = config_manager
|
self.config_manager = config_manager
|
||||||
|
|
||||||
|
self._models: List[Dict[str, Any]] = []
|
||||||
|
self._selected_model_id: Optional[int] = None
|
||||||
|
self._plot_widgets: List[QWidget] = []
|
||||||
|
self._plot_items: List[_PlotItem] = []
|
||||||
|
|
||||||
self._setup_ui()
|
self._setup_ui()
|
||||||
|
self.refresh()
|
||||||
|
|
||||||
def _setup_ui(self):
|
def _setup_ui(self):
|
||||||
"""Setup user interface."""
|
"""Setup user interface."""
|
||||||
layout = QVBoxLayout()
|
layout = QVBoxLayout(self)
|
||||||
|
|
||||||
group = QGroupBox("Validation")
|
# ===== Header controls =====
|
||||||
group_layout = QVBoxLayout()
|
header = QGroupBox("Validation")
|
||||||
label = QLabel(
|
header_layout = QVBoxLayout()
|
||||||
"Validation functionality will be implemented here.\n\n"
|
header_row = QHBoxLayout()
|
||||||
"Features:\n"
|
|
||||||
"- Model validation\n"
|
|
||||||
"- Metrics visualization\n"
|
|
||||||
"- Confusion matrix\n"
|
|
||||||
"- Precision-Recall curves"
|
|
||||||
)
|
|
||||||
group_layout.addWidget(label)
|
|
||||||
group.setLayout(group_layout)
|
|
||||||
|
|
||||||
layout.addWidget(group)
|
header_row.addWidget(QLabel("Select model:"))
|
||||||
layout.addStretch()
|
|
||||||
self.setLayout(layout)
|
self.model_combo = QComboBox()
|
||||||
|
self.model_combo.setMinimumWidth(420)
|
||||||
|
self.model_combo.currentIndexChanged.connect(self._on_model_selected)
|
||||||
|
header_row.addWidget(self.model_combo, 1)
|
||||||
|
|
||||||
|
self.refresh_btn = QPushButton("Refresh")
|
||||||
|
self.refresh_btn.clicked.connect(self.refresh)
|
||||||
|
header_row.addWidget(self.refresh_btn)
|
||||||
|
header_row.addStretch()
|
||||||
|
|
||||||
|
header_layout.addLayout(header_row)
|
||||||
|
self.header_status = QLabel("No models loaded.")
|
||||||
|
self.header_status.setWordWrap(True)
|
||||||
|
header_layout.addWidget(self.header_status)
|
||||||
|
header.setLayout(header_layout)
|
||||||
|
layout.addWidget(header)
|
||||||
|
|
||||||
|
# ===== Metrics =====
|
||||||
|
metrics_group = QGroupBox("Validation Metrics")
|
||||||
|
metrics_layout = QVBoxLayout()
|
||||||
|
|
||||||
|
self.metrics_form = QFormLayout()
|
||||||
|
self.metric_labels: Dict[str, QLabel] = {}
|
||||||
|
for key in ("mAP50", "mAP50-95", "precision", "recall", "fitness"):
|
||||||
|
value_label = QLabel("–")
|
||||||
|
value_label.setTextInteractionFlags(Qt.TextSelectableByMouse)
|
||||||
|
self.metric_labels[key] = value_label
|
||||||
|
self.metrics_form.addRow(f"{key}:", value_label)
|
||||||
|
metrics_layout.addLayout(self.metrics_form)
|
||||||
|
|
||||||
|
self.per_class_table = QTableWidget(0, 3)
|
||||||
|
self.per_class_table.setHorizontalHeaderLabels(["Class", "AP", "AP50"])
|
||||||
|
self.per_class_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
|
||||||
|
self.per_class_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeToContents)
|
||||||
|
self.per_class_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeToContents)
|
||||||
|
self.per_class_table.setEditTriggers(QTableWidget.NoEditTriggers)
|
||||||
|
self.per_class_table.setMinimumHeight(160)
|
||||||
|
metrics_layout.addWidget(QLabel("Per-class metrics (if available):"))
|
||||||
|
metrics_layout.addWidget(self.per_class_table)
|
||||||
|
|
||||||
|
metrics_group.setLayout(metrics_layout)
|
||||||
|
layout.addWidget(metrics_group)
|
||||||
|
|
||||||
|
# ===== Plots =====
|
||||||
|
plots_group = QGroupBox("Validation Plots")
|
||||||
|
plots_layout = QVBoxLayout()
|
||||||
|
|
||||||
|
self.plots_status = QLabel("Select a model to see validation plots.")
|
||||||
|
self.plots_status.setWordWrap(True)
|
||||||
|
plots_layout.addWidget(self.plots_status)
|
||||||
|
|
||||||
|
self.plots_splitter = QSplitter(Qt.Orientation.Horizontal)
|
||||||
|
|
||||||
|
# Left: selected image viewer
|
||||||
|
left_widget = QWidget()
|
||||||
|
left_layout = QVBoxLayout(left_widget)
|
||||||
|
left_layout.setContentsMargins(0, 0, 0, 0)
|
||||||
|
|
||||||
|
self.selected_plot_title = QLabel("No image selected.")
|
||||||
|
self.selected_plot_title.setWordWrap(True)
|
||||||
|
self.selected_plot_title.setTextInteractionFlags(Qt.TextSelectableByMouse)
|
||||||
|
left_layout.addWidget(self.selected_plot_title)
|
||||||
|
|
||||||
|
self.plot_view = _ZoomableImageView()
|
||||||
|
self.plot_view.setMinimumHeight(360)
|
||||||
|
left_layout.addWidget(self.plot_view, 1)
|
||||||
|
|
||||||
|
self.selected_plot_path = QLabel("")
|
||||||
|
self.selected_plot_path.setWordWrap(True)
|
||||||
|
self.selected_plot_path.setStyleSheet("color: #888;")
|
||||||
|
self.selected_plot_path.setTextInteractionFlags(Qt.TextSelectableByMouse)
|
||||||
|
left_layout.addWidget(self.selected_plot_path)
|
||||||
|
|
||||||
|
# Right: scrollable list
|
||||||
|
right_widget = QWidget()
|
||||||
|
right_layout = QVBoxLayout(right_widget)
|
||||||
|
right_layout.setContentsMargins(0, 0, 0, 0)
|
||||||
|
right_layout.addWidget(QLabel("Images:"))
|
||||||
|
|
||||||
|
self.plots_list = QListWidget()
|
||||||
|
self.plots_list.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection)
|
||||||
|
self.plots_list.setIconSize(QSize(160, 160))
|
||||||
|
self.plots_list.itemSelectionChanged.connect(self._on_plot_item_selected)
|
||||||
|
right_layout.addWidget(self.plots_list, 1)
|
||||||
|
|
||||||
|
self.plots_splitter.addWidget(left_widget)
|
||||||
|
self.plots_splitter.addWidget(right_widget)
|
||||||
|
self.plots_splitter.setStretchFactor(0, 3)
|
||||||
|
self.plots_splitter.setStretchFactor(1, 1)
|
||||||
|
plots_layout.addWidget(self.plots_splitter, 1)
|
||||||
|
|
||||||
|
plots_group.setLayout(plots_layout)
|
||||||
|
layout.addWidget(plots_group, 1)
|
||||||
|
|
||||||
|
layout.addStretch(0)
|
||||||
|
|
||||||
|
self._clear_metrics()
|
||||||
|
self._clear_plots()
|
||||||
|
|
||||||
|
# ==================== Public API ====================
|
||||||
|
|
||||||
def refresh(self):
|
def refresh(self):
|
||||||
"""Refresh the tab."""
|
"""Refresh the tab."""
|
||||||
|
self._load_models()
|
||||||
|
self._populate_model_combo()
|
||||||
|
self._restore_or_select_default_model()
|
||||||
|
|
||||||
|
# ==================== Internal: models ====================
|
||||||
|
|
||||||
|
def _load_models(self) -> None:
|
||||||
|
try:
|
||||||
|
self._models = self.db_manager.get_models() or []
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to load models: %s", exc)
|
||||||
|
self._models = []
|
||||||
|
|
||||||
|
def _populate_model_combo(self) -> None:
|
||||||
|
self.model_combo.blockSignals(True)
|
||||||
|
self.model_combo.clear()
|
||||||
|
self.model_combo.addItem("Select a model…", None)
|
||||||
|
|
||||||
|
for model in self._models:
|
||||||
|
model_id = model.get("id")
|
||||||
|
name = (model.get("model_name") or "").strip()
|
||||||
|
version = (model.get("model_version") or "").strip()
|
||||||
|
created_at = model.get("created_at")
|
||||||
|
label = f"{name} {version}".strip()
|
||||||
|
if created_at:
|
||||||
|
label = f"{label} ({created_at})"
|
||||||
|
self.model_combo.addItem(label, model_id)
|
||||||
|
|
||||||
|
self.model_combo.blockSignals(False)
|
||||||
|
|
||||||
|
if self._models:
|
||||||
|
self.header_status.setText(f"Loaded {len(self._models)} model(s).")
|
||||||
|
else:
|
||||||
|
self.header_status.setText("No models found. Train a model first.")
|
||||||
|
|
||||||
|
def _restore_or_select_default_model(self) -> None:
|
||||||
|
if not self._models:
|
||||||
|
self._selected_model_id = None
|
||||||
|
self._clear_metrics()
|
||||||
|
self._clear_plots()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Keep selection if still present.
|
||||||
|
if self._selected_model_id is not None:
|
||||||
|
for idx in range(1, self.model_combo.count()):
|
||||||
|
if self.model_combo.itemData(idx) == self._selected_model_id:
|
||||||
|
self.model_combo.setCurrentIndex(idx)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Otherwise select the newest model (top of get_models ORDER BY created_at DESC).
|
||||||
|
first_model_id = self.model_combo.itemData(1) if self.model_combo.count() > 1 else None
|
||||||
|
if first_model_id is not None:
|
||||||
|
self.model_combo.setCurrentIndex(1)
|
||||||
|
|
||||||
|
def _on_model_selected(self, index: int) -> None:
|
||||||
|
model_id = self.model_combo.itemData(index)
|
||||||
|
if not model_id:
|
||||||
|
self._selected_model_id = None
|
||||||
|
self._clear_metrics()
|
||||||
|
self._clear_plots()
|
||||||
|
self.plots_status.setText("Select a model to see validation plots.")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._selected_model_id = int(model_id)
|
||||||
|
model = self._get_model_by_id(self._selected_model_id)
|
||||||
|
if not model:
|
||||||
|
self._clear_metrics()
|
||||||
|
self._clear_plots()
|
||||||
|
self.plots_status.setText("Selected model not found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._render_metrics(model)
|
||||||
|
self._render_plots(model)
|
||||||
|
|
||||||
|
def _get_model_by_id(self, model_id: int) -> Optional[Dict[str, Any]]:
|
||||||
|
for model in self._models:
|
||||||
|
if model.get("id") == model_id:
|
||||||
|
return model
|
||||||
|
try:
|
||||||
|
return self.db_manager.get_model_by_id(model_id)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ==================== Internal: metrics ====================
|
||||||
|
|
||||||
|
def _clear_metrics(self) -> None:
|
||||||
|
for label in self.metric_labels.values():
|
||||||
|
label.setText("–")
|
||||||
|
self.per_class_table.setRowCount(0)
|
||||||
|
|
||||||
|
def _render_metrics(self, model: Dict[str, Any]) -> None:
|
||||||
|
self._clear_metrics()
|
||||||
|
|
||||||
|
metrics: Dict[str, Any] = model.get("metrics") or {}
|
||||||
|
# Training tab stores metrics under results['metrics'] in training results payload.
|
||||||
|
if isinstance(metrics, dict) and "metrics" in metrics and isinstance(metrics.get("metrics"), dict):
|
||||||
|
metrics = metrics.get("metrics") or {}
|
||||||
|
|
||||||
|
def set_metric(key: str, value: Any) -> None:
|
||||||
|
if key not in self.metric_labels:
|
||||||
|
return
|
||||||
|
if value is None:
|
||||||
|
self.metric_labels[key].setText("–")
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
self.metric_labels[key].setText(f"{float(value):.4f}")
|
||||||
|
except Exception:
|
||||||
|
self.metric_labels[key].setText(str(value))
|
||||||
|
|
||||||
|
set_metric("mAP50", metrics.get("mAP50"))
|
||||||
|
set_metric("mAP50-95", metrics.get("mAP50-95") or metrics.get("mAP50_95") or metrics.get("mAP50-95"))
|
||||||
|
set_metric("precision", metrics.get("precision"))
|
||||||
|
set_metric("recall", metrics.get("recall"))
|
||||||
|
set_metric("fitness", metrics.get("fitness"))
|
||||||
|
|
||||||
|
# Optional per-class metrics
|
||||||
|
class_metrics = metrics.get("class_metrics") if isinstance(metrics, dict) else None
|
||||||
|
if isinstance(class_metrics, dict) and class_metrics:
|
||||||
|
items = sorted(class_metrics.items(), key=lambda kv: str(kv[0]))
|
||||||
|
self.per_class_table.setRowCount(len(items))
|
||||||
|
for row, (cls_name, cls_stats) in enumerate(items):
|
||||||
|
ap = (cls_stats or {}).get("ap")
|
||||||
|
ap50 = (cls_stats or {}).get("ap50")
|
||||||
|
self.per_class_table.setItem(row, 0, QTableWidgetItem(str(cls_name)))
|
||||||
|
self.per_class_table.setItem(row, 1, QTableWidgetItem(self._format_float(ap)))
|
||||||
|
self.per_class_table.setItem(row, 2, QTableWidgetItem(self._format_float(ap50)))
|
||||||
|
else:
|
||||||
|
self.per_class_table.setRowCount(0)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_float(value: Any) -> str:
|
||||||
|
if value is None:
|
||||||
|
return "–"
|
||||||
|
try:
|
||||||
|
return f"{float(value):.4f}"
|
||||||
|
except Exception:
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
# ==================== Internal: plots ====================
|
||||||
|
|
||||||
|
def _clear_plots(self) -> None:
|
||||||
|
# Remove legacy grid widgets (from the initial implementation).
|
||||||
|
for widget in self._plot_widgets:
|
||||||
|
widget.setParent(None)
|
||||||
|
widget.deleteLater()
|
||||||
|
self._plot_widgets = []
|
||||||
|
|
||||||
|
self._plot_items = []
|
||||||
|
|
||||||
|
if hasattr(self, "plots_list"):
|
||||||
|
self.plots_list.blockSignals(True)
|
||||||
|
self.plots_list.clear()
|
||||||
|
self.plots_list.blockSignals(False)
|
||||||
|
|
||||||
|
if hasattr(self, "plot_view"):
|
||||||
|
self.plot_view.clear()
|
||||||
|
if hasattr(self, "selected_plot_title"):
|
||||||
|
self.selected_plot_title.setText("No image selected.")
|
||||||
|
if hasattr(self, "selected_plot_path"):
|
||||||
|
self.selected_plot_path.setText("")
|
||||||
|
|
||||||
|
def _render_plots(self, model: Dict[str, Any]) -> None:
|
||||||
|
self._clear_plots()
|
||||||
|
|
||||||
|
plot_dirs = self._infer_run_directories(model)
|
||||||
|
plot_items = self._discover_plot_items(plot_dirs)
|
||||||
|
|
||||||
|
if not plot_items:
|
||||||
|
dirs_text = "\n".join(str(p) for p in plot_dirs if p)
|
||||||
|
self.plots_status.setText(
|
||||||
|
"No validation plot images found for this model.\n\n"
|
||||||
|
"Searched directories:\n" + (dirs_text or "(none)")
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
self._plot_items = list(plot_items)
|
||||||
|
self.plots_status.setText(f"Found {len(plot_items)} plot image(s). Select one to view/zoom.")
|
||||||
|
|
||||||
|
self.plots_list.blockSignals(True)
|
||||||
|
self.plots_list.clear()
|
||||||
|
for idx, item in enumerate(self._plot_items):
|
||||||
|
qitem = QListWidgetItem(item.label)
|
||||||
|
qitem.setData(Qt.ItemDataRole.UserRole, idx)
|
||||||
|
|
||||||
|
pix = QPixmap(str(item.path))
|
||||||
|
if not pix.isNull():
|
||||||
|
thumb = pix.scaled(
|
||||||
|
self.plots_list.iconSize(),
|
||||||
|
Qt.AspectRatioMode.KeepAspectRatio,
|
||||||
|
Qt.TransformationMode.SmoothTransformation,
|
||||||
|
)
|
||||||
|
qitem.setIcon(thumb)
|
||||||
|
self.plots_list.addItem(qitem)
|
||||||
|
self.plots_list.blockSignals(False)
|
||||||
|
|
||||||
|
if self.plots_list.count() > 0:
|
||||||
|
self.plots_list.setCurrentRow(0)
|
||||||
|
|
||||||
|
def _on_plot_item_selected(self) -> None:
|
||||||
|
if not self._plot_items:
|
||||||
|
return
|
||||||
|
|
||||||
|
selected = self.plots_list.selectedItems()
|
||||||
|
if not selected:
|
||||||
|
return
|
||||||
|
|
||||||
|
idx = selected[0].data(Qt.ItemDataRole.UserRole)
|
||||||
|
try:
|
||||||
|
idx_int = int(idx)
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
if idx_int < 0 or idx_int >= len(self._plot_items):
|
||||||
|
return
|
||||||
|
|
||||||
|
plot = self._plot_items[idx_int]
|
||||||
|
self.selected_plot_title.setText(plot.label)
|
||||||
|
self.selected_plot_path.setText(str(plot.path))
|
||||||
|
|
||||||
|
pix = QPixmap(str(plot.path))
|
||||||
|
if pix.isNull():
|
||||||
|
self.plot_view.clear()
|
||||||
|
return
|
||||||
|
self.plot_view.set_pixmap(pix, fit=True)
|
||||||
|
|
||||||
|
def _infer_run_directories(self, model: Dict[str, Any]) -> List[Path]:
|
||||||
|
dirs: List[Path] = []
|
||||||
|
|
||||||
|
# 1) Infer from model_path: .../<run>/weights/best.pt -> <run>
|
||||||
|
model_path = model.get("model_path")
|
||||||
|
if model_path:
|
||||||
|
try:
|
||||||
|
p = Path(str(model_path)).expanduser()
|
||||||
|
if p.name.lower().endswith(".pt"):
|
||||||
|
# If it lives under weights/, use parent.parent.
|
||||||
|
if p.parent.name == "weights" and p.parent.parent.exists():
|
||||||
|
dirs.append(p.parent.parent)
|
||||||
|
elif p.parent.exists():
|
||||||
|
dirs.append(p.parent)
|
||||||
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# 2) Look at training_params.stage_results[].results.save_dir
|
||||||
|
training_params = model.get("training_params") or {}
|
||||||
|
stage_results = None
|
||||||
|
if isinstance(training_params, dict):
|
||||||
|
stage_results = training_params.get("stage_results")
|
||||||
|
if isinstance(stage_results, list):
|
||||||
|
for stage in stage_results:
|
||||||
|
results = (stage or {}).get("results")
|
||||||
|
save_dir = (results or {}).get("save_dir") if isinstance(results, dict) else None
|
||||||
|
if save_dir:
|
||||||
|
try:
|
||||||
|
save_path = Path(str(save_dir)).expanduser()
|
||||||
|
if save_path.exists():
|
||||||
|
dirs.append(save_path)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Deduplicate while preserving order.
|
||||||
|
unique: List[Path] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for d in dirs:
|
||||||
|
try:
|
||||||
|
resolved = str(d.resolve())
|
||||||
|
except Exception:
|
||||||
|
resolved = str(d)
|
||||||
|
if resolved not in seen and d.exists() and d.is_dir():
|
||||||
|
seen.add(resolved)
|
||||||
|
unique.append(d)
|
||||||
|
return unique
|
||||||
|
|
||||||
|
def _discover_plot_items(self, directories: Sequence[Path]) -> List[_PlotItem]:
|
||||||
|
# Prefer canonical Ultralytics filenames first, then fall back to any png/jpg.
|
||||||
|
preferred_names = [
|
||||||
|
"results.png",
|
||||||
|
"results.jpg",
|
||||||
|
"confusion_matrix.png",
|
||||||
|
"confusion_matrix_normalized.png",
|
||||||
|
"labels.jpg",
|
||||||
|
"labels.png",
|
||||||
|
"BoxPR_curve.png",
|
||||||
|
"BoxP_curve.png",
|
||||||
|
"BoxR_curve.png",
|
||||||
|
"BoxF1_curve.png",
|
||||||
|
"MaskPR_curve.png",
|
||||||
|
"MaskP_curve.png",
|
||||||
|
"MaskR_curve.png",
|
||||||
|
"MaskF1_curve.png",
|
||||||
|
"val_batch0_pred.jpg",
|
||||||
|
"val_batch0_labels.jpg",
|
||||||
|
]
|
||||||
|
|
||||||
|
found: List[_PlotItem] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
|
||||||
|
for d in directories:
|
||||||
|
# 1) Preferred
|
||||||
|
for name in preferred_names:
|
||||||
|
p = d / name
|
||||||
|
if p.exists() and p.is_file():
|
||||||
|
key = str(p)
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
found.append(_PlotItem(label=f"{name} (from {d.name})", path=p))
|
||||||
|
|
||||||
|
# 2) Curated globs
|
||||||
|
for pattern in ("train_batch*.jpg", "val_batch*.jpg", "*curve*.png"):
|
||||||
|
for p in sorted(d.glob(pattern)):
|
||||||
|
if not p.is_file():
|
||||||
|
continue
|
||||||
|
key = str(p)
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
found.append(_PlotItem(label=f"{p.name} (from {d.name})", path=p))
|
||||||
|
|
||||||
|
# 3) Fallback: any top-level png/jpg (excluding weights dir contents)
|
||||||
|
for ext in ("*.png", "*.jpg", "*.jpeg", "*.webp"):
|
||||||
|
for p in sorted(d.glob(ext)):
|
||||||
|
if not p.is_file():
|
||||||
|
continue
|
||||||
|
key = str(p)
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
found.append(_PlotItem(label=f"{p.name} (from {d.name})", path=p))
|
||||||
|
|
||||||
|
# Keep list bounded to avoid UI overload for huge runs.
|
||||||
|
return found[:60]
|
||||||
|
|||||||
103
src/utils/create_mask_from_detection.py
Normal file
103
src/utils/create_mask_from_detection.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from skimage.draw import polygon
|
||||||
|
from tifffile import TiffFile
|
||||||
|
|
||||||
|
from src.database.db_manager import DatabaseManager
|
||||||
|
|
||||||
|
|
||||||
|
def read_image(image_path: Path) -> np.ndarray:
|
||||||
|
metadata = {}
|
||||||
|
with TiffFile(image_path) as tif:
|
||||||
|
image = tif.asarray()
|
||||||
|
metadata = tif.imagej_metadata
|
||||||
|
return image, metadata
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
polygon_vertices = np.array([[10, 10], [50, 10], [50, 50], [10, 50]])
|
||||||
|
image = np.zeros((100, 100), dtype=np.uint8)
|
||||||
|
rr, cc = polygon(polygon_vertices[:, 0], polygon_vertices[:, 1])
|
||||||
|
image[rr, cc] = 255
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
db = DatabaseManager()
|
||||||
|
model_name = "c17"
|
||||||
|
model_id = db.get_models(filters={"model_name": model_name})[0]["id"]
|
||||||
|
print(f"Model name {model_name}, id {model_id}")
|
||||||
|
detections = db.get_detections(filters={"model_id": model_id})
|
||||||
|
|
||||||
|
file_stems = set()
|
||||||
|
|
||||||
|
for detection in detections:
|
||||||
|
file_stems.add(detection["image_filename"].split("_")[0])
|
||||||
|
|
||||||
|
print("Files:", file_stems)
|
||||||
|
|
||||||
|
for stem in file_stems:
|
||||||
|
print(stem)
|
||||||
|
detections = db.get_detections(filters={"model_id": model_id, "i.filename": f"LIKE %{stem}%"})
|
||||||
|
annotations = []
|
||||||
|
for detection in detections:
|
||||||
|
source_path = Path(detection["metadata"]["source_path"])
|
||||||
|
image, metadata = read_image(source_path)
|
||||||
|
|
||||||
|
offset = np.array(list(map(int, metadata["tile_section"].split(","))))[::-1]
|
||||||
|
scale = np.array(list(map(int, metadata["patch_size"].split(","))))[::-1]
|
||||||
|
# tile_size = np.array(list(map(int, metadata["tile_size"].split(","))))
|
||||||
|
segmentation = np.array(detection["segmentation_mask"]) # * tile_size
|
||||||
|
|
||||||
|
# print(source_path, image, metadata, segmentation.shape)
|
||||||
|
# print(offset)
|
||||||
|
# print(scale)
|
||||||
|
# print(segmentation)
|
||||||
|
|
||||||
|
# segmentation = (segmentation + offset * tile_size) / (tile_size * scale)
|
||||||
|
segmentation = (segmentation + offset) / scale
|
||||||
|
|
||||||
|
yolo_annotation = f"{detection['metadata']['class_id']} " + " ".join(
|
||||||
|
[f"{x:.6f} {y:.6f}" for x, y in segmentation]
|
||||||
|
)
|
||||||
|
annotations.append(yolo_annotation)
|
||||||
|
# print(segmentation)
|
||||||
|
# print(yolo_annotation)
|
||||||
|
|
||||||
|
# aa
|
||||||
|
print(
|
||||||
|
" ",
|
||||||
|
detection["model_name"],
|
||||||
|
detection["image_id"],
|
||||||
|
detection["image_filename"],
|
||||||
|
source_path,
|
||||||
|
metadata["label_path"],
|
||||||
|
)
|
||||||
|
# section_i_section_j = detection["image_filename"].split("_")[1].split(".")[0]
|
||||||
|
# print(" ", section_i_section_j)
|
||||||
|
|
||||||
|
label_path = metadata["label_path"]
|
||||||
|
print(" ", label_path)
|
||||||
|
with open(label_path, "w") as f:
|
||||||
|
f.write("\n".join(annotations))
|
||||||
|
|
||||||
|
exit()
|
||||||
|
|
||||||
|
for detection in detections:
|
||||||
|
print(detection["model_name"], detection["image_id"], detection["image_filename"])
|
||||||
|
|
||||||
|
print(detections[0])
|
||||||
|
# polygon_vertices = np.array([[10, 10], [50, 10], [50, 50], [10, 50]])
|
||||||
|
|
||||||
|
# image = np.zeros((100, 100), dtype=np.uint8)
|
||||||
|
|
||||||
|
# rr, cc = polygon(polygon_vertices[:, 0], polygon_vertices[:, 1])
|
||||||
|
|
||||||
|
# image[rr, cc] = 255
|
||||||
|
|
||||||
|
# import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
# plt.imshow(image, cmap='gray')
|
||||||
|
# plt.show()
|
||||||
@@ -189,25 +189,30 @@ def main():
|
|||||||
# continue and just show image
|
# continue and just show image
|
||||||
out = draw_annotations(img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox))
|
out = draw_annotations(img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox))
|
||||||
|
|
||||||
lclass, coords = labels[0]
|
out_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
print(lclass, coords)
|
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
|
||||||
|
if 0:
|
||||||
|
plt.imshow(out_rgb.transpose(1, 0, 2))
|
||||||
|
else:
|
||||||
|
plt.imshow(out_rgb)
|
||||||
|
|
||||||
|
for label in labels:
|
||||||
|
lclass, coords = label
|
||||||
|
# print(lclass, coords)
|
||||||
bbox = coords[:4]
|
bbox = coords[:4]
|
||||||
print("bbox", bbox)
|
# print("bbox", bbox)
|
||||||
bbox = np.array(bbox) * np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
|
bbox = np.array(bbox) * np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
|
||||||
yc, xc, h, w = bbox
|
yc, xc, h, w = bbox
|
||||||
print("bbox", bbox)
|
# print("bbox", bbox)
|
||||||
|
|
||||||
# polyline = np.array(coords[4:]).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
|
# polyline = np.array(coords[4:]).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
|
||||||
polyline = np.array(coords).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
|
polyline = np.array(coords).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
|
||||||
print("pl", coords[4:])
|
# print("pl", coords[4:])
|
||||||
print("pl", polyline)
|
# print("pl", polyline)
|
||||||
|
|
||||||
# Convert BGR -> RGB for matplotlib display
|
# Convert BGR -> RGB for matplotlib display
|
||||||
# out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
|
# out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
|
||||||
out_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
||||||
# out_rgb = Image()
|
# out_rgb = Image()
|
||||||
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
|
|
||||||
plt.imshow(out_rgb)
|
|
||||||
plt.plot(polyline[:, 0], polyline[:, 1], "y", linewidth=2)
|
plt.plot(polyline[:, 0], polyline[:, 1], "y", linewidth=2)
|
||||||
if 0:
|
if 0:
|
||||||
plt.plot(
|
plt.plot(
|
||||||
|
|||||||
Reference in New Issue
Block a user