Adding result shower
This commit is contained in:
@@ -450,6 +450,25 @@ class DatabaseManager:
|
||||
filters["model_id"] = model_id
|
||||
return self.get_detections(filters)
|
||||
|
||||
def delete_detections_for_image(
|
||||
self, image_id: int, model_id: Optional[int] = None
|
||||
) -> int:
|
||||
"""Delete detections tied to a specific image and optional model."""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
if model_id is not None:
|
||||
cursor.execute(
|
||||
"DELETE FROM detections WHERE image_id = ? AND model_id = ?",
|
||||
(image_id, model_id),
|
||||
)
|
||||
else:
|
||||
cursor.execute("DELETE FROM detections WHERE image_id = ?", (image_id,))
|
||||
conn.commit()
|
||||
return cursor.rowcount
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def delete_detections_for_model(self, model_id: int) -> int:
|
||||
"""Delete all detections for a specific model."""
|
||||
conn = self.get_connection()
|
||||
|
||||
@@ -20,6 +20,7 @@ from PySide6.QtWidgets import (
|
||||
)
|
||||
from PySide6.QtCore import Qt, QThread, Signal
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.config_manager import ConfigManager
|
||||
@@ -147,30 +148,66 @@ class DetectionTab(QWidget):
|
||||
self.model_combo.currentIndexChanged.connect(self._on_model_changed)
|
||||
|
||||
def _load_models(self):
|
||||
"""Load available models from database."""
|
||||
"""Load available models from database and local storage."""
|
||||
try:
|
||||
models = self.db_manager.get_models()
|
||||
self.model_combo.clear()
|
||||
models = self.db_manager.get_models()
|
||||
has_models = False
|
||||
|
||||
if not models:
|
||||
self.model_combo.addItem("No models available", None)
|
||||
self._set_buttons_enabled(False)
|
||||
return
|
||||
known_paths = set()
|
||||
|
||||
# Add base model option
|
||||
# Add base model option first (always available)
|
||||
base_model = self.config_manager.get(
|
||||
"models.default_base_model", "yolov8s-seg.pt"
|
||||
)
|
||||
self.model_combo.addItem(
|
||||
f"Base Model ({base_model})", {"id": 0, "path": base_model}
|
||||
)
|
||||
if base_model:
|
||||
base_data = {
|
||||
"id": 0,
|
||||
"path": base_model,
|
||||
"model_name": Path(base_model).stem or "Base Model",
|
||||
"model_version": "pretrained",
|
||||
"base_model": base_model,
|
||||
"source": "base",
|
||||
}
|
||||
self.model_combo.addItem(f"Base Model ({base_model})", base_data)
|
||||
known_paths.add(self._normalize_model_path(base_model))
|
||||
has_models = True
|
||||
|
||||
# Add trained models
|
||||
# Add trained models from database
|
||||
for model in models:
|
||||
display_name = f"{model['model_name']} v{model['model_version']}"
|
||||
self.model_combo.addItem(display_name, model)
|
||||
model_data = {**model, "path": model.get("model_path")}
|
||||
normalized = self._normalize_model_path(model_data.get("path"))
|
||||
if normalized:
|
||||
known_paths.add(normalized)
|
||||
self.model_combo.addItem(display_name, model_data)
|
||||
has_models = True
|
||||
|
||||
self._set_buttons_enabled(True)
|
||||
# Discover local model files not yet in the database
|
||||
local_models = self._discover_local_models()
|
||||
for model_path in local_models:
|
||||
normalized = self._normalize_model_path(model_path)
|
||||
if normalized in known_paths:
|
||||
continue
|
||||
|
||||
display_name = f"Local Model ({Path(model_path).stem})"
|
||||
model_data = {
|
||||
"id": None,
|
||||
"path": str(model_path),
|
||||
"model_name": Path(model_path).stem,
|
||||
"model_version": "local",
|
||||
"base_model": Path(model_path).stem,
|
||||
"source": "local",
|
||||
}
|
||||
self.model_combo.addItem(display_name, model_data)
|
||||
known_paths.add(normalized)
|
||||
has_models = True
|
||||
|
||||
if not has_models:
|
||||
self.model_combo.addItem("No models available", None)
|
||||
self._set_buttons_enabled(False)
|
||||
else:
|
||||
self._set_buttons_enabled(True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading models: {e}")
|
||||
@@ -249,25 +286,39 @@ class DetectionTab(QWidget):
|
||||
QMessageBox.warning(self, "No Model", "Please select a model first.")
|
||||
return
|
||||
|
||||
model_path = model_data["path"]
|
||||
model_id = model_data["id"]
|
||||
model_path = model_data.get("path")
|
||||
if not model_path:
|
||||
QMessageBox.warning(
|
||||
self, "Invalid Model", "Selected model is missing a file path."
|
||||
)
|
||||
return
|
||||
|
||||
# Ensure we have a valid model ID (create entry for base model if needed)
|
||||
if model_id == 0:
|
||||
# Create database entry for base model
|
||||
base_model = self.config_manager.get(
|
||||
"models.default_base_model", "yolov8s-seg.pt"
|
||||
)
|
||||
model_id = self.db_manager.add_model(
|
||||
model_name="Base Model",
|
||||
model_version="pretrained",
|
||||
model_path=base_model,
|
||||
base_model=base_model,
|
||||
if not Path(model_path).exists():
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"Model Not Found",
|
||||
f"The selected model file could not be found:\n{model_path}",
|
||||
)
|
||||
return
|
||||
|
||||
model_id = model_data.get("id")
|
||||
|
||||
# Ensure we have a database entry for the selected model
|
||||
if model_id in (None, 0):
|
||||
model_id = self._ensure_model_record(model_data)
|
||||
if not model_id:
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"Model Registration Failed",
|
||||
"Unable to register the selected model in the database.",
|
||||
)
|
||||
return
|
||||
|
||||
normalized_model_path = self._normalize_model_path(model_path) or model_path
|
||||
|
||||
# Create inference engine
|
||||
self.inference_engine = InferenceEngine(
|
||||
model_path, self.db_manager, model_id
|
||||
normalized_model_path, self.db_manager, model_id
|
||||
)
|
||||
|
||||
# Get confidence threshold
|
||||
@@ -338,6 +389,76 @@ class DetectionTab(QWidget):
|
||||
self.batch_btn.setEnabled(enabled)
|
||||
self.model_combo.setEnabled(enabled)
|
||||
|
||||
def _discover_local_models(self) -> list:
|
||||
"""Scan the models directory for standalone .pt files."""
|
||||
models_dir = self.config_manager.get_models_directory()
|
||||
if not models_dir:
|
||||
return []
|
||||
|
||||
models_path = Path(models_dir)
|
||||
if not models_path.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
return sorted(
|
||||
[p for p in models_path.rglob("*.pt") if p.is_file()],
|
||||
key=lambda p: str(p).lower(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error discovering local models: {e}")
|
||||
return []
|
||||
|
||||
def _normalize_model_path(self, path_value) -> str:
|
||||
"""Return a normalized absolute path string for comparison."""
|
||||
if not path_value:
|
||||
return ""
|
||||
try:
|
||||
return str(Path(path_value).resolve())
|
||||
except Exception:
|
||||
return str(path_value)
|
||||
|
||||
def _ensure_model_record(self, model_data: dict) -> Optional[int]:
|
||||
"""Ensure a database record exists for the selected model."""
|
||||
model_path = model_data.get("path")
|
||||
if not model_path:
|
||||
return None
|
||||
|
||||
normalized_target = self._normalize_model_path(model_path)
|
||||
|
||||
try:
|
||||
existing_models = self.db_manager.get_models()
|
||||
for model in existing_models:
|
||||
existing_path = model.get("model_path")
|
||||
if not existing_path:
|
||||
continue
|
||||
normalized_existing = self._normalize_model_path(existing_path)
|
||||
if (
|
||||
normalized_existing == normalized_target
|
||||
or existing_path == model_path
|
||||
):
|
||||
return model["id"]
|
||||
|
||||
model_name = (
|
||||
model_data.get("model_name") or Path(model_path).stem or "Custom Model"
|
||||
)
|
||||
model_version = (
|
||||
model_data.get("model_version") or model_data.get("source") or "local"
|
||||
)
|
||||
base_model = model_data.get(
|
||||
"base_model",
|
||||
self.config_manager.get("models.default_base_model", "yolov8s-seg.pt"),
|
||||
)
|
||||
|
||||
return self.db_manager.add_model(
|
||||
model_name=model_name,
|
||||
model_version=model_version,
|
||||
model_path=normalized_target,
|
||||
base_model=base_model,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ensure model record for {model_path}: {e}")
|
||||
return None
|
||||
|
||||
def refresh(self):
|
||||
"""Refresh the tab."""
|
||||
self._load_models()
|
||||
|
||||
@@ -1,15 +1,39 @@
|
||||
"""
|
||||
Results tab for the microscopy object detection application.
|
||||
Results tab for browsing stored detections and visualizing overlays.
|
||||
"""
|
||||
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from PySide6.QtWidgets import (
|
||||
QWidget,
|
||||
QVBoxLayout,
|
||||
QHBoxLayout,
|
||||
QLabel,
|
||||
QGroupBox,
|
||||
QPushButton,
|
||||
QSplitter,
|
||||
QTableWidget,
|
||||
QTableWidgetItem,
|
||||
QHeaderView,
|
||||
QAbstractItemView,
|
||||
QMessageBox,
|
||||
QCheckBox,
|
||||
)
|
||||
from PySide6.QtCore import Qt
|
||||
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.config_manager import ConfigManager
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.image import Image, ImageLoadError
|
||||
from src.gui.widgets import AnnotationCanvasWidget
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ResultsTab(QWidget):
|
||||
"""Results tab placeholder."""
|
||||
"""Results tab showing detection history and preview overlays."""
|
||||
|
||||
def __init__(
|
||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
||||
@@ -18,29 +42,387 @@ class ResultsTab(QWidget):
|
||||
self.db_manager = db_manager
|
||||
self.config_manager = config_manager
|
||||
|
||||
self.detection_summary: List[Dict] = []
|
||||
self.current_selection: Optional[Dict] = None
|
||||
self.current_image: Optional[Image] = None
|
||||
self.current_detections: List[Dict] = []
|
||||
self._image_path_cache: Dict[str, str] = {}
|
||||
|
||||
self._setup_ui()
|
||||
self.refresh()
|
||||
|
||||
def _setup_ui(self):
|
||||
"""Setup user interface."""
|
||||
layout = QVBoxLayout()
|
||||
|
||||
group = QGroupBox("Results")
|
||||
group_layout = QVBoxLayout()
|
||||
label = QLabel(
|
||||
"Results viewer will be implemented here.\n\n"
|
||||
"Features:\n"
|
||||
"- Detection history browser\n"
|
||||
"- Advanced filtering\n"
|
||||
"- Statistics dashboard\n"
|
||||
"- Export functionality"
|
||||
)
|
||||
group_layout.addWidget(label)
|
||||
group.setLayout(group_layout)
|
||||
# Splitter for list + preview
|
||||
splitter = QSplitter(Qt.Horizontal)
|
||||
|
||||
layout.addWidget(group)
|
||||
layout.addStretch()
|
||||
# Left pane: detection list
|
||||
left_container = QWidget()
|
||||
left_layout = QVBoxLayout()
|
||||
left_layout.setContentsMargins(0, 0, 0, 0)
|
||||
|
||||
controls_layout = QHBoxLayout()
|
||||
self.refresh_btn = QPushButton("Refresh")
|
||||
self.refresh_btn.clicked.connect(self.refresh)
|
||||
controls_layout.addWidget(self.refresh_btn)
|
||||
controls_layout.addStretch()
|
||||
left_layout.addLayout(controls_layout)
|
||||
|
||||
self.results_table = QTableWidget(0, 5)
|
||||
self.results_table.setHorizontalHeaderLabels(
|
||||
["Image", "Model", "Detections", "Classes", "Last Updated"]
|
||||
)
|
||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||
0, QHeaderView.Stretch
|
||||
)
|
||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||
1, QHeaderView.Stretch
|
||||
)
|
||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||
2, QHeaderView.ResizeToContents
|
||||
)
|
||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||
3, QHeaderView.Stretch
|
||||
)
|
||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||
4, QHeaderView.ResizeToContents
|
||||
)
|
||||
self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows)
|
||||
self.results_table.setSelectionMode(QAbstractItemView.SingleSelection)
|
||||
self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
|
||||
self.results_table.itemSelectionChanged.connect(self._on_result_selected)
|
||||
|
||||
left_layout.addWidget(self.results_table)
|
||||
left_container.setLayout(left_layout)
|
||||
|
||||
# Right pane: preview canvas and controls
|
||||
right_container = QWidget()
|
||||
right_layout = QVBoxLayout()
|
||||
right_layout.setContentsMargins(0, 0, 0, 0)
|
||||
|
||||
preview_group = QGroupBox("Detection Preview")
|
||||
preview_layout = QVBoxLayout()
|
||||
|
||||
self.preview_canvas = AnnotationCanvasWidget()
|
||||
self.preview_canvas.set_polyline_enabled(False)
|
||||
self.preview_canvas.set_show_bboxes(True)
|
||||
preview_layout.addWidget(self.preview_canvas)
|
||||
|
||||
toggles_layout = QHBoxLayout()
|
||||
self.show_masks_checkbox = QCheckBox("Show Masks")
|
||||
self.show_masks_checkbox.setChecked(False)
|
||||
self.show_masks_checkbox.stateChanged.connect(self._apply_detection_overlays)
|
||||
self.show_bboxes_checkbox = QCheckBox("Show Bounding Boxes")
|
||||
self.show_bboxes_checkbox.setChecked(True)
|
||||
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
|
||||
toggles_layout.addWidget(self.show_masks_checkbox)
|
||||
toggles_layout.addWidget(self.show_bboxes_checkbox)
|
||||
toggles_layout.addStretch()
|
||||
preview_layout.addLayout(toggles_layout)
|
||||
|
||||
self.summary_label = QLabel("Select a detection result to preview.")
|
||||
self.summary_label.setWordWrap(True)
|
||||
preview_layout.addWidget(self.summary_label)
|
||||
|
||||
preview_group.setLayout(preview_layout)
|
||||
right_layout.addWidget(preview_group)
|
||||
right_container.setLayout(right_layout)
|
||||
|
||||
splitter.addWidget(left_container)
|
||||
splitter.addWidget(right_container)
|
||||
splitter.setStretchFactor(0, 1)
|
||||
splitter.setStretchFactor(1, 2)
|
||||
|
||||
layout.addWidget(splitter)
|
||||
self.setLayout(layout)
|
||||
|
||||
def refresh(self):
|
||||
"""Refresh the tab."""
|
||||
pass
|
||||
"""Refresh the detection list and preview."""
|
||||
self._load_detection_summary()
|
||||
self._populate_results_table()
|
||||
self.current_selection = None
|
||||
self.current_image = None
|
||||
self.current_detections = []
|
||||
self.preview_canvas.clear()
|
||||
self.summary_label.setText("Select a detection result to preview.")
|
||||
|
||||
def _load_detection_summary(self):
|
||||
"""Load latest detection summaries grouped by image + model."""
|
||||
try:
|
||||
detections = self.db_manager.get_detections(limit=500)
|
||||
summary_map: Dict[tuple, Dict] = {}
|
||||
|
||||
for det in detections:
|
||||
key = (det["image_id"], det["model_id"])
|
||||
metadata = det.get("metadata") or {}
|
||||
entry = summary_map.setdefault(
|
||||
key,
|
||||
{
|
||||
"image_id": det["image_id"],
|
||||
"model_id": det["model_id"],
|
||||
"image_path": det.get("image_path"),
|
||||
"image_filename": det.get("image_filename")
|
||||
or det.get("image_path"),
|
||||
"model_name": det.get("model_name", ""),
|
||||
"model_version": det.get("model_version", ""),
|
||||
"last_detected": det.get("detected_at"),
|
||||
"count": 0,
|
||||
"classes": set(),
|
||||
"source_path": metadata.get("source_path"),
|
||||
"repository_root": metadata.get("repository_root"),
|
||||
},
|
||||
)
|
||||
|
||||
entry["count"] += 1
|
||||
if det.get("detected_at") and (
|
||||
not entry.get("last_detected")
|
||||
or str(det.get("detected_at")) > str(entry.get("last_detected"))
|
||||
):
|
||||
entry["last_detected"] = det.get("detected_at")
|
||||
if det.get("class_name"):
|
||||
entry["classes"].add(det["class_name"])
|
||||
if metadata.get("source_path") and not entry.get("source_path"):
|
||||
entry["source_path"] = metadata.get("source_path")
|
||||
if metadata.get("repository_root") and not entry.get("repository_root"):
|
||||
entry["repository_root"] = metadata.get("repository_root")
|
||||
|
||||
self.detection_summary = sorted(
|
||||
summary_map.values(),
|
||||
key=lambda x: str(x.get("last_detected") or ""),
|
||||
reverse=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load detection summary: {e}")
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"Error",
|
||||
f"Failed to load detection results:\n{str(e)}",
|
||||
)
|
||||
self.detection_summary = []
|
||||
|
||||
def _populate_results_table(self):
|
||||
"""Populate the table widget with detection summaries."""
|
||||
self.results_table.setRowCount(len(self.detection_summary))
|
||||
|
||||
for row, entry in enumerate(self.detection_summary):
|
||||
model_label = f"{entry['model_name']} {entry['model_version']}".strip()
|
||||
class_list = (
|
||||
", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
|
||||
)
|
||||
|
||||
items = [
|
||||
QTableWidgetItem(entry.get("image_filename", "")),
|
||||
QTableWidgetItem(model_label),
|
||||
QTableWidgetItem(str(entry.get("count", 0))),
|
||||
QTableWidgetItem(class_list),
|
||||
QTableWidgetItem(str(entry.get("last_detected") or "")),
|
||||
]
|
||||
|
||||
for col, item in enumerate(items):
|
||||
item.setData(Qt.UserRole, row)
|
||||
self.results_table.setItem(row, col, item)
|
||||
|
||||
self.results_table.clearSelection()
|
||||
|
||||
def _on_result_selected(self):
|
||||
"""Handle selection changes in the detection table."""
|
||||
selected_items = self.results_table.selectedItems()
|
||||
if not selected_items:
|
||||
return
|
||||
|
||||
row = selected_items[0].data(Qt.UserRole)
|
||||
if row is None or row >= len(self.detection_summary):
|
||||
return
|
||||
|
||||
entry = self.detection_summary[row]
|
||||
if (
|
||||
self.current_selection
|
||||
and self.current_selection.get("image_id") == entry["image_id"]
|
||||
and self.current_selection.get("model_id") == entry["model_id"]
|
||||
):
|
||||
return
|
||||
|
||||
self.current_selection = entry
|
||||
|
||||
image_path = self._resolve_image_path(entry)
|
||||
if not image_path:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"Image Not Found",
|
||||
"Unable to locate the image file for this detection.",
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
self.current_image = Image(image_path)
|
||||
self.preview_canvas.load_image(self.current_image)
|
||||
except ImageLoadError as e:
|
||||
logger.error(f"Failed to load image '{image_path}': {e}")
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"Image Error",
|
||||
f"Failed to load image for preview:\n{str(e)}",
|
||||
)
|
||||
return
|
||||
|
||||
self._load_detections_for_selection(entry)
|
||||
self._apply_detection_overlays()
|
||||
self._update_summary_label(entry)
|
||||
|
||||
def _load_detections_for_selection(self, entry: Dict):
|
||||
"""Load detection records for the selected image/model pair."""
|
||||
self.current_detections = []
|
||||
if not entry:
|
||||
return
|
||||
|
||||
try:
|
||||
filters = {"image_id": entry["image_id"], "model_id": entry["model_id"]}
|
||||
self.current_detections = self.db_manager.get_detections(filters)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load detections for preview: {e}")
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"Error",
|
||||
f"Failed to load detections for this image:\n{str(e)}",
|
||||
)
|
||||
self.current_detections = []
|
||||
|
||||
def _apply_detection_overlays(self):
|
||||
"""Draw detections onto the preview canvas based on current toggles."""
|
||||
self.preview_canvas.clear_annotations()
|
||||
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
|
||||
|
||||
if not self.current_detections or not self.current_image:
|
||||
return
|
||||
|
||||
for det in self.current_detections:
|
||||
color = self._get_class_color(det.get("class_name"))
|
||||
|
||||
if self.show_masks_checkbox.isChecked() and det.get("segmentation_mask"):
|
||||
mask_points = self._convert_mask(det["segmentation_mask"])
|
||||
if mask_points:
|
||||
self.preview_canvas.draw_saved_polyline(mask_points, color)
|
||||
|
||||
bbox = [
|
||||
det.get("x_min"),
|
||||
det.get("y_min"),
|
||||
det.get("x_max"),
|
||||
det.get("y_max"),
|
||||
]
|
||||
if all(v is not None for v in bbox):
|
||||
self.preview_canvas.draw_saved_bbox(bbox, color)
|
||||
|
||||
def _convert_mask(self, mask_points: List[List[float]]) -> List[List[float]]:
|
||||
"""Convert stored [x, y] masks to [y, x] format for the canvas."""
|
||||
converted = []
|
||||
for point in mask_points:
|
||||
if len(point) >= 2:
|
||||
x, y = point[0], point[1]
|
||||
converted.append([y, x])
|
||||
return converted
|
||||
|
||||
def _toggle_bboxes(self):
|
||||
"""Update bounding box visibility on the canvas."""
|
||||
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
|
||||
# Re-render to respect show/hide when toggled
|
||||
self._apply_detection_overlays()
|
||||
|
||||
def _update_summary_label(self, entry: Dict):
|
||||
"""Display textual summary for the selected detection run."""
|
||||
classes = ", ".join(sorted(entry.get("classes", []))) or "-"
|
||||
summary_text = (
|
||||
f"Image: {entry.get('image_filename', 'unknown')}\n"
|
||||
f"Model: {entry.get('model_name', '')} {entry.get('model_version', '')}\n"
|
||||
f"Detections: {entry.get('count', 0)}\n"
|
||||
f"Classes: {classes}\n"
|
||||
f"Last Updated: {entry.get('last_detected', 'n/a')}"
|
||||
)
|
||||
self.summary_label.setText(summary_text)
|
||||
|
||||
def _resolve_image_path(self, entry: Dict) -> Optional[str]:
|
||||
"""Resolve an image path using metadata, cache, and repository hints."""
|
||||
relative_path = entry.get("image_path") if entry else None
|
||||
cache_key = relative_path or entry.get("source_path")
|
||||
if cache_key and cache_key in self._image_path_cache:
|
||||
cached = Path(self._image_path_cache[cache_key])
|
||||
if cached.exists():
|
||||
return self._image_path_cache[cache_key]
|
||||
del self._image_path_cache[cache_key]
|
||||
|
||||
candidates = []
|
||||
source_path = entry.get("source_path") if entry else None
|
||||
if source_path:
|
||||
candidates.append(Path(source_path))
|
||||
|
||||
repo_roots = []
|
||||
if entry.get("repository_root"):
|
||||
repo_roots.append(entry["repository_root"])
|
||||
config_repo = self.config_manager.get_image_repository_path()
|
||||
if config_repo:
|
||||
repo_roots.append(config_repo)
|
||||
|
||||
for root in repo_roots:
|
||||
if relative_path:
|
||||
candidates.append(Path(root) / relative_path)
|
||||
|
||||
if relative_path:
|
||||
candidates.append(Path(relative_path))
|
||||
|
||||
for candidate in candidates:
|
||||
try:
|
||||
if candidate and candidate.exists():
|
||||
resolved = str(candidate.resolve())
|
||||
if cache_key:
|
||||
self._image_path_cache[cache_key] = resolved
|
||||
return resolved
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Fallback: search by filename in known roots
|
||||
filename = Path(relative_path).name if relative_path else None
|
||||
if filename:
|
||||
search_roots = [Path(root) for root in repo_roots if root]
|
||||
if not search_roots:
|
||||
search_roots = [Path("data")]
|
||||
match = self._search_in_roots(filename, search_roots)
|
||||
if match:
|
||||
resolved = str(match.resolve())
|
||||
if cache_key:
|
||||
self._image_path_cache[cache_key] = resolved
|
||||
return resolved
|
||||
|
||||
return None
|
||||
|
||||
def _search_in_roots(self, filename: str, roots: List[Path]) -> Optional[Path]:
|
||||
"""Search for a file name within a list of root directories."""
|
||||
for root in roots:
|
||||
try:
|
||||
if not root.exists():
|
||||
continue
|
||||
for candidate in root.rglob(filename):
|
||||
return candidate
|
||||
except Exception as e:
|
||||
logger.debug(f"Error searching for {filename} in {root}: {e}")
|
||||
return None
|
||||
|
||||
def _get_class_color(self, class_name: Optional[str]) -> str:
|
||||
"""Return consistent color hex for a class name."""
|
||||
if not class_name:
|
||||
return "#FF6B6B"
|
||||
|
||||
color_map = self.config_manager.get_bbox_colors()
|
||||
if class_name in color_map:
|
||||
return color_map[class_name]
|
||||
|
||||
# Deterministic fallback color based on hash
|
||||
palette = [
|
||||
"#FF6B6B",
|
||||
"#4ECDC4",
|
||||
"#FFD166",
|
||||
"#1D3557",
|
||||
"#F4A261",
|
||||
"#E76F51",
|
||||
]
|
||||
return palette[hash(class_name) % len(palette)]
|
||||
|
||||
@@ -263,7 +263,7 @@ class AnnotationCanvasWidget(QWidget):
|
||||
height,
|
||||
bytes_per_line,
|
||||
self.current_image.qtimage_format,
|
||||
)
|
||||
).copy() # Copy so Qt owns the buffer even after numpy array goes out of scope
|
||||
|
||||
self.original_pixmap = QPixmap.fromImage(qimage)
|
||||
|
||||
|
||||
@@ -137,7 +137,7 @@ class ImageDisplayWidget(QWidget):
|
||||
height,
|
||||
bytes_per_line,
|
||||
self.current_image.qtimage_format,
|
||||
)
|
||||
).copy() # Copy to ensure Qt owns its memory after this scope
|
||||
|
||||
# Convert to pixmap
|
||||
pixmap = QPixmap.fromImage(qimage)
|
||||
|
||||
@@ -42,6 +42,7 @@ class InferenceEngine:
|
||||
relative_path: str,
|
||||
conf: float = 0.25,
|
||||
save_to_db: bool = True,
|
||||
repository_root: Optional[str] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
Detect objects in a single image.
|
||||
@@ -51,11 +52,17 @@ class InferenceEngine:
|
||||
relative_path: Relative path from repository root
|
||||
conf: Confidence threshold
|
||||
save_to_db: Whether to save results to database
|
||||
repository_root: Base directory used to compute relative_path (if known)
|
||||
|
||||
Returns:
|
||||
Dictionary with detection results
|
||||
"""
|
||||
try:
|
||||
# Normalize storage path (fall back to absolute path when repo root is unknown)
|
||||
stored_relative_path = relative_path
|
||||
if not repository_root:
|
||||
stored_relative_path = str(Path(image_path).resolve())
|
||||
|
||||
# Get image dimensions
|
||||
img = Image.open(image_path)
|
||||
width, height = img.size
|
||||
@@ -66,34 +73,58 @@ class InferenceEngine:
|
||||
|
||||
# Add/get image in database
|
||||
image_id = self.db_manager.get_or_create_image(
|
||||
relative_path=relative_path,
|
||||
relative_path=stored_relative_path,
|
||||
filename=Path(image_path).name,
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
|
||||
# Save detections to database
|
||||
if save_to_db and detections:
|
||||
detection_records = []
|
||||
for det in detections:
|
||||
# Use normalized bbox from detection
|
||||
bbox_normalized = det[
|
||||
"bbox_normalized"
|
||||
] # [x_min, y_min, x_max, y_max]
|
||||
inserted_count = 0
|
||||
deleted_count = 0
|
||||
|
||||
record = {
|
||||
"image_id": image_id,
|
||||
"model_id": self.model_id,
|
||||
"class_name": det["class_name"],
|
||||
"bbox": tuple(bbox_normalized),
|
||||
"confidence": det["confidence"],
|
||||
"segmentation_mask": det.get("segmentation_mask"),
|
||||
"metadata": {"class_id": det["class_id"]},
|
||||
}
|
||||
detection_records.append(record)
|
||||
# Save detections to database, replacing any previous results for this image/model
|
||||
if save_to_db:
|
||||
deleted_count = self.db_manager.delete_detections_for_image(
|
||||
image_id, self.model_id
|
||||
)
|
||||
if detections:
|
||||
detection_records = []
|
||||
for det in detections:
|
||||
# Use normalized bbox from detection
|
||||
bbox_normalized = det[
|
||||
"bbox_normalized"
|
||||
] # [x_min, y_min, x_max, y_max]
|
||||
|
||||
self.db_manager.add_detections_batch(detection_records)
|
||||
logger.info(f"Saved {len(detection_records)} detections to database")
|
||||
metadata = {
|
||||
"class_id": det["class_id"],
|
||||
"source_path": str(Path(image_path).resolve()),
|
||||
}
|
||||
if repository_root:
|
||||
metadata["repository_root"] = str(
|
||||
Path(repository_root).resolve()
|
||||
)
|
||||
|
||||
record = {
|
||||
"image_id": image_id,
|
||||
"model_id": self.model_id,
|
||||
"class_name": det["class_name"],
|
||||
"bbox": tuple(bbox_normalized),
|
||||
"confidence": det["confidence"],
|
||||
"segmentation_mask": det.get("segmentation_mask"),
|
||||
"metadata": metadata,
|
||||
}
|
||||
detection_records.append(record)
|
||||
|
||||
inserted_count = self.db_manager.add_detections_batch(
|
||||
detection_records
|
||||
)
|
||||
logger.info(
|
||||
f"Saved {inserted_count} detections to database (replaced {deleted_count})"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Detection run removed {deleted_count} stale entries but produced no new detections"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
@@ -142,7 +173,12 @@ class InferenceEngine:
|
||||
rel_path = get_relative_path(image_path, repository_root)
|
||||
|
||||
# Perform detection
|
||||
result = self.detect_single(image_path, rel_path, conf)
|
||||
result = self.detect_single(
|
||||
image_path,
|
||||
rel_path,
|
||||
conf=conf,
|
||||
repository_root=repository_root,
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
# Update progress
|
||||
|
||||
@@ -7,6 +7,9 @@ from ultralytics import YOLO
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Callable, Any
|
||||
import torch
|
||||
from PIL import Image
|
||||
import tempfile
|
||||
import os
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
|
||||
@@ -162,10 +165,12 @@ class YOLOWrapper:
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
|
||||
prepared_source, cleanup_path = self._prepare_source(source)
|
||||
|
||||
try:
|
||||
logger.info(f"Running inference on {source}")
|
||||
results = self.model.predict(
|
||||
source=source,
|
||||
source=prepared_source,
|
||||
conf=conf,
|
||||
iou=iou,
|
||||
save=save,
|
||||
@@ -182,6 +187,14 @@ class YOLOWrapper:
|
||||
except Exception as e:
|
||||
logger.error(f"Error during inference: {e}")
|
||||
raise
|
||||
finally:
|
||||
if cleanup_path:
|
||||
try:
|
||||
os.remove(cleanup_path)
|
||||
except OSError as cleanup_error:
|
||||
logger.warning(
|
||||
f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
|
||||
)
|
||||
|
||||
def export(
|
||||
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
|
||||
@@ -210,6 +223,36 @@ class YOLOWrapper:
|
||||
logger.error(f"Error exporting model: {e}")
|
||||
raise
|
||||
|
||||
def _prepare_source(self, source):
|
||||
"""Convert single-channel images to RGB temporarily for inference."""
|
||||
cleanup_path = None
|
||||
|
||||
if isinstance(source, (str, Path)):
|
||||
source_path = Path(source)
|
||||
if source_path.is_file():
|
||||
try:
|
||||
with Image.open(source_path) as img:
|
||||
if len(img.getbands()) == 1:
|
||||
rgb_img = img.convert("RGB")
|
||||
suffix = source_path.suffix or ".png"
|
||||
tmp = tempfile.NamedTemporaryFile(
|
||||
suffix=suffix, delete=False
|
||||
)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
rgb_img.save(tmp_path)
|
||||
cleanup_path = tmp_path
|
||||
logger.info(
|
||||
f"Converted single-channel image {source_path} to RGB for inference"
|
||||
)
|
||||
return tmp_path, cleanup_path
|
||||
except Exception as convert_error:
|
||||
logger.warning(
|
||||
f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"
|
||||
)
|
||||
|
||||
return source, cleanup_path
|
||||
|
||||
def _format_training_results(self, results) -> Dict[str, Any]:
|
||||
"""Format training results into dictionary."""
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user