Adding result shower

This commit is contained in:
2025-12-10 16:55:28 +02:00
parent 5370d31dce
commit 833b222fad
7 changed files with 672 additions and 71 deletions

View File

@@ -450,6 +450,25 @@ class DatabaseManager:
filters["model_id"] = model_id filters["model_id"] = model_id
return self.get_detections(filters) return self.get_detections(filters)
def delete_detections_for_image(
self, image_id: int, model_id: Optional[int] = None
) -> int:
"""Delete detections tied to a specific image and optional model."""
conn = self.get_connection()
try:
cursor = conn.cursor()
if model_id is not None:
cursor.execute(
"DELETE FROM detections WHERE image_id = ? AND model_id = ?",
(image_id, model_id),
)
else:
cursor.execute("DELETE FROM detections WHERE image_id = ?", (image_id,))
conn.commit()
return cursor.rowcount
finally:
conn.close()
def delete_detections_for_model(self, model_id: int) -> int: def delete_detections_for_model(self, model_id: int) -> int:
"""Delete all detections for a specific model.""" """Delete all detections for a specific model."""
conn = self.get_connection() conn = self.get_connection()

View File

@@ -20,6 +20,7 @@ from PySide6.QtWidgets import (
) )
from PySide6.QtCore import Qt, QThread, Signal from PySide6.QtCore import Qt, QThread, Signal
from pathlib import Path from pathlib import Path
from typing import Optional
from src.database.db_manager import DatabaseManager from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager from src.utils.config_manager import ConfigManager
@@ -147,30 +148,66 @@ class DetectionTab(QWidget):
self.model_combo.currentIndexChanged.connect(self._on_model_changed) self.model_combo.currentIndexChanged.connect(self._on_model_changed)
def _load_models(self): def _load_models(self):
"""Load available models from database.""" """Load available models from database and local storage."""
try: try:
models = self.db_manager.get_models()
self.model_combo.clear() self.model_combo.clear()
models = self.db_manager.get_models()
has_models = False
if not models: known_paths = set()
self.model_combo.addItem("No models available", None)
self._set_buttons_enabled(False)
return
# Add base model option # Add base model option first (always available)
base_model = self.config_manager.get( base_model = self.config_manager.get(
"models.default_base_model", "yolov8s-seg.pt" "models.default_base_model", "yolov8s-seg.pt"
) )
self.model_combo.addItem( if base_model:
f"Base Model ({base_model})", {"id": 0, "path": base_model} base_data = {
) "id": 0,
"path": base_model,
"model_name": Path(base_model).stem or "Base Model",
"model_version": "pretrained",
"base_model": base_model,
"source": "base",
}
self.model_combo.addItem(f"Base Model ({base_model})", base_data)
known_paths.add(self._normalize_model_path(base_model))
has_models = True
# Add trained models # Add trained models from database
for model in models: for model in models:
display_name = f"{model['model_name']} v{model['model_version']}" display_name = f"{model['model_name']} v{model['model_version']}"
self.model_combo.addItem(display_name, model) model_data = {**model, "path": model.get("model_path")}
normalized = self._normalize_model_path(model_data.get("path"))
if normalized:
known_paths.add(normalized)
self.model_combo.addItem(display_name, model_data)
has_models = True
self._set_buttons_enabled(True) # Discover local model files not yet in the database
local_models = self._discover_local_models()
for model_path in local_models:
normalized = self._normalize_model_path(model_path)
if normalized in known_paths:
continue
display_name = f"Local Model ({Path(model_path).stem})"
model_data = {
"id": None,
"path": str(model_path),
"model_name": Path(model_path).stem,
"model_version": "local",
"base_model": Path(model_path).stem,
"source": "local",
}
self.model_combo.addItem(display_name, model_data)
known_paths.add(normalized)
has_models = True
if not has_models:
self.model_combo.addItem("No models available", None)
self._set_buttons_enabled(False)
else:
self._set_buttons_enabled(True)
except Exception as e: except Exception as e:
logger.error(f"Error loading models: {e}") logger.error(f"Error loading models: {e}")
@@ -249,25 +286,39 @@ class DetectionTab(QWidget):
QMessageBox.warning(self, "No Model", "Please select a model first.") QMessageBox.warning(self, "No Model", "Please select a model first.")
return return
model_path = model_data["path"] model_path = model_data.get("path")
model_id = model_data["id"] if not model_path:
QMessageBox.warning(
self, "Invalid Model", "Selected model is missing a file path."
)
return
# Ensure we have a valid model ID (create entry for base model if needed) if not Path(model_path).exists():
if model_id == 0: QMessageBox.critical(
# Create database entry for base model self,
base_model = self.config_manager.get( "Model Not Found",
"models.default_base_model", "yolov8s-seg.pt" f"The selected model file could not be found:\n{model_path}",
)
model_id = self.db_manager.add_model(
model_name="Base Model",
model_version="pretrained",
model_path=base_model,
base_model=base_model,
) )
return
model_id = model_data.get("id")
# Ensure we have a database entry for the selected model
if model_id in (None, 0):
model_id = self._ensure_model_record(model_data)
if not model_id:
QMessageBox.critical(
self,
"Model Registration Failed",
"Unable to register the selected model in the database.",
)
return
normalized_model_path = self._normalize_model_path(model_path) or model_path
# Create inference engine # Create inference engine
self.inference_engine = InferenceEngine( self.inference_engine = InferenceEngine(
model_path, self.db_manager, model_id normalized_model_path, self.db_manager, model_id
) )
# Get confidence threshold # Get confidence threshold
@@ -338,6 +389,76 @@ class DetectionTab(QWidget):
self.batch_btn.setEnabled(enabled) self.batch_btn.setEnabled(enabled)
self.model_combo.setEnabled(enabled) self.model_combo.setEnabled(enabled)
def _discover_local_models(self) -> list:
"""Scan the models directory for standalone .pt files."""
models_dir = self.config_manager.get_models_directory()
if not models_dir:
return []
models_path = Path(models_dir)
if not models_path.exists():
return []
try:
return sorted(
[p for p in models_path.rglob("*.pt") if p.is_file()],
key=lambda p: str(p).lower(),
)
except Exception as e:
logger.warning(f"Error discovering local models: {e}")
return []
def _normalize_model_path(self, path_value) -> str:
"""Return a normalized absolute path string for comparison."""
if not path_value:
return ""
try:
return str(Path(path_value).resolve())
except Exception:
return str(path_value)
def _ensure_model_record(self, model_data: dict) -> Optional[int]:
"""Ensure a database record exists for the selected model."""
model_path = model_data.get("path")
if not model_path:
return None
normalized_target = self._normalize_model_path(model_path)
try:
existing_models = self.db_manager.get_models()
for model in existing_models:
existing_path = model.get("model_path")
if not existing_path:
continue
normalized_existing = self._normalize_model_path(existing_path)
if (
normalized_existing == normalized_target
or existing_path == model_path
):
return model["id"]
model_name = (
model_data.get("model_name") or Path(model_path).stem or "Custom Model"
)
model_version = (
model_data.get("model_version") or model_data.get("source") or "local"
)
base_model = model_data.get(
"base_model",
self.config_manager.get("models.default_base_model", "yolov8s-seg.pt"),
)
return self.db_manager.add_model(
model_name=model_name,
model_version=model_version,
model_path=normalized_target,
base_model=base_model,
)
except Exception as e:
logger.error(f"Failed to ensure model record for {model_path}: {e}")
return None
def refresh(self): def refresh(self):
"""Refresh the tab.""" """Refresh the tab."""
self._load_models() self._load_models()

View File

@@ -1,15 +1,39 @@
""" """
Results tab for the microscopy object detection application. Results tab for browsing stored detections and visualizing overlays.
""" """
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox from pathlib import Path
from typing import Dict, List, Optional
from PySide6.QtWidgets import (
QWidget,
QVBoxLayout,
QHBoxLayout,
QLabel,
QGroupBox,
QPushButton,
QSplitter,
QTableWidget,
QTableWidgetItem,
QHeaderView,
QAbstractItemView,
QMessageBox,
QCheckBox,
)
from PySide6.QtCore import Qt
from src.database.db_manager import DatabaseManager from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
from src.utils.image import Image, ImageLoadError
from src.gui.widgets import AnnotationCanvasWidget
logger = get_logger(__name__)
class ResultsTab(QWidget): class ResultsTab(QWidget):
"""Results tab placeholder.""" """Results tab showing detection history and preview overlays."""
def __init__( def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
@@ -18,29 +42,387 @@ class ResultsTab(QWidget):
self.db_manager = db_manager self.db_manager = db_manager
self.config_manager = config_manager self.config_manager = config_manager
self.detection_summary: List[Dict] = []
self.current_selection: Optional[Dict] = None
self.current_image: Optional[Image] = None
self.current_detections: List[Dict] = []
self._image_path_cache: Dict[str, str] = {}
self._setup_ui() self._setup_ui()
self.refresh()
def _setup_ui(self): def _setup_ui(self):
"""Setup user interface.""" """Setup user interface."""
layout = QVBoxLayout() layout = QVBoxLayout()
group = QGroupBox("Results") # Splitter for list + preview
group_layout = QVBoxLayout() splitter = QSplitter(Qt.Horizontal)
label = QLabel(
"Results viewer will be implemented here.\n\n"
"Features:\n"
"- Detection history browser\n"
"- Advanced filtering\n"
"- Statistics dashboard\n"
"- Export functionality"
)
group_layout.addWidget(label)
group.setLayout(group_layout)
layout.addWidget(group) # Left pane: detection list
layout.addStretch() left_container = QWidget()
left_layout = QVBoxLayout()
left_layout.setContentsMargins(0, 0, 0, 0)
controls_layout = QHBoxLayout()
self.refresh_btn = QPushButton("Refresh")
self.refresh_btn.clicked.connect(self.refresh)
controls_layout.addWidget(self.refresh_btn)
controls_layout.addStretch()
left_layout.addLayout(controls_layout)
self.results_table = QTableWidget(0, 5)
self.results_table.setHorizontalHeaderLabels(
["Image", "Model", "Detections", "Classes", "Last Updated"]
)
self.results_table.horizontalHeader().setSectionResizeMode(
0, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
1, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
2, QHeaderView.ResizeToContents
)
self.results_table.horizontalHeader().setSectionResizeMode(
3, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
4, QHeaderView.ResizeToContents
)
self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows)
self.results_table.setSelectionMode(QAbstractItemView.SingleSelection)
self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
self.results_table.itemSelectionChanged.connect(self._on_result_selected)
left_layout.addWidget(self.results_table)
left_container.setLayout(left_layout)
# Right pane: preview canvas and controls
right_container = QWidget()
right_layout = QVBoxLayout()
right_layout.setContentsMargins(0, 0, 0, 0)
preview_group = QGroupBox("Detection Preview")
preview_layout = QVBoxLayout()
self.preview_canvas = AnnotationCanvasWidget()
self.preview_canvas.set_polyline_enabled(False)
self.preview_canvas.set_show_bboxes(True)
preview_layout.addWidget(self.preview_canvas)
toggles_layout = QHBoxLayout()
self.show_masks_checkbox = QCheckBox("Show Masks")
self.show_masks_checkbox.setChecked(False)
self.show_masks_checkbox.stateChanged.connect(self._apply_detection_overlays)
self.show_bboxes_checkbox = QCheckBox("Show Bounding Boxes")
self.show_bboxes_checkbox.setChecked(True)
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
toggles_layout.addWidget(self.show_masks_checkbox)
toggles_layout.addWidget(self.show_bboxes_checkbox)
toggles_layout.addStretch()
preview_layout.addLayout(toggles_layout)
self.summary_label = QLabel("Select a detection result to preview.")
self.summary_label.setWordWrap(True)
preview_layout.addWidget(self.summary_label)
preview_group.setLayout(preview_layout)
right_layout.addWidget(preview_group)
right_container.setLayout(right_layout)
splitter.addWidget(left_container)
splitter.addWidget(right_container)
splitter.setStretchFactor(0, 1)
splitter.setStretchFactor(1, 2)
layout.addWidget(splitter)
self.setLayout(layout) self.setLayout(layout)
def refresh(self): def refresh(self):
"""Refresh the tab.""" """Refresh the detection list and preview."""
pass self._load_detection_summary()
self._populate_results_table()
self.current_selection = None
self.current_image = None
self.current_detections = []
self.preview_canvas.clear()
self.summary_label.setText("Select a detection result to preview.")
def _load_detection_summary(self):
"""Load latest detection summaries grouped by image + model."""
try:
detections = self.db_manager.get_detections(limit=500)
summary_map: Dict[tuple, Dict] = {}
for det in detections:
key = (det["image_id"], det["model_id"])
metadata = det.get("metadata") or {}
entry = summary_map.setdefault(
key,
{
"image_id": det["image_id"],
"model_id": det["model_id"],
"image_path": det.get("image_path"),
"image_filename": det.get("image_filename")
or det.get("image_path"),
"model_name": det.get("model_name", ""),
"model_version": det.get("model_version", ""),
"last_detected": det.get("detected_at"),
"count": 0,
"classes": set(),
"source_path": metadata.get("source_path"),
"repository_root": metadata.get("repository_root"),
},
)
entry["count"] += 1
if det.get("detected_at") and (
not entry.get("last_detected")
or str(det.get("detected_at")) > str(entry.get("last_detected"))
):
entry["last_detected"] = det.get("detected_at")
if det.get("class_name"):
entry["classes"].add(det["class_name"])
if metadata.get("source_path") and not entry.get("source_path"):
entry["source_path"] = metadata.get("source_path")
if metadata.get("repository_root") and not entry.get("repository_root"):
entry["repository_root"] = metadata.get("repository_root")
self.detection_summary = sorted(
summary_map.values(),
key=lambda x: str(x.get("last_detected") or ""),
reverse=True,
)
except Exception as e:
logger.error(f"Failed to load detection summary: {e}")
QMessageBox.critical(
self,
"Error",
f"Failed to load detection results:\n{str(e)}",
)
self.detection_summary = []
def _populate_results_table(self):
"""Populate the table widget with detection summaries."""
self.results_table.setRowCount(len(self.detection_summary))
for row, entry in enumerate(self.detection_summary):
model_label = f"{entry['model_name']} {entry['model_version']}".strip()
class_list = (
", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
)
items = [
QTableWidgetItem(entry.get("image_filename", "")),
QTableWidgetItem(model_label),
QTableWidgetItem(str(entry.get("count", 0))),
QTableWidgetItem(class_list),
QTableWidgetItem(str(entry.get("last_detected") or "")),
]
for col, item in enumerate(items):
item.setData(Qt.UserRole, row)
self.results_table.setItem(row, col, item)
self.results_table.clearSelection()
def _on_result_selected(self):
"""Handle selection changes in the detection table."""
selected_items = self.results_table.selectedItems()
if not selected_items:
return
row = selected_items[0].data(Qt.UserRole)
if row is None or row >= len(self.detection_summary):
return
entry = self.detection_summary[row]
if (
self.current_selection
and self.current_selection.get("image_id") == entry["image_id"]
and self.current_selection.get("model_id") == entry["model_id"]
):
return
self.current_selection = entry
image_path = self._resolve_image_path(entry)
if not image_path:
QMessageBox.warning(
self,
"Image Not Found",
"Unable to locate the image file for this detection.",
)
return
try:
self.current_image = Image(image_path)
self.preview_canvas.load_image(self.current_image)
except ImageLoadError as e:
logger.error(f"Failed to load image '{image_path}': {e}")
QMessageBox.critical(
self,
"Image Error",
f"Failed to load image for preview:\n{str(e)}",
)
return
self._load_detections_for_selection(entry)
self._apply_detection_overlays()
self._update_summary_label(entry)
def _load_detections_for_selection(self, entry: Dict):
"""Load detection records for the selected image/model pair."""
self.current_detections = []
if not entry:
return
try:
filters = {"image_id": entry["image_id"], "model_id": entry["model_id"]}
self.current_detections = self.db_manager.get_detections(filters)
except Exception as e:
logger.error(f"Failed to load detections for preview: {e}")
QMessageBox.critical(
self,
"Error",
f"Failed to load detections for this image:\n{str(e)}",
)
self.current_detections = []
def _apply_detection_overlays(self):
"""Draw detections onto the preview canvas based on current toggles."""
self.preview_canvas.clear_annotations()
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
if not self.current_detections or not self.current_image:
return
for det in self.current_detections:
color = self._get_class_color(det.get("class_name"))
if self.show_masks_checkbox.isChecked() and det.get("segmentation_mask"):
mask_points = self._convert_mask(det["segmentation_mask"])
if mask_points:
self.preview_canvas.draw_saved_polyline(mask_points, color)
bbox = [
det.get("x_min"),
det.get("y_min"),
det.get("x_max"),
det.get("y_max"),
]
if all(v is not None for v in bbox):
self.preview_canvas.draw_saved_bbox(bbox, color)
def _convert_mask(self, mask_points: List[List[float]]) -> List[List[float]]:
"""Convert stored [x, y] masks to [y, x] format for the canvas."""
converted = []
for point in mask_points:
if len(point) >= 2:
x, y = point[0], point[1]
converted.append([y, x])
return converted
def _toggle_bboxes(self):
"""Update bounding box visibility on the canvas."""
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
# Re-render to respect show/hide when toggled
self._apply_detection_overlays()
def _update_summary_label(self, entry: Dict):
"""Display textual summary for the selected detection run."""
classes = ", ".join(sorted(entry.get("classes", []))) or "-"
summary_text = (
f"Image: {entry.get('image_filename', 'unknown')}\n"
f"Model: {entry.get('model_name', '')} {entry.get('model_version', '')}\n"
f"Detections: {entry.get('count', 0)}\n"
f"Classes: {classes}\n"
f"Last Updated: {entry.get('last_detected', 'n/a')}"
)
self.summary_label.setText(summary_text)
def _resolve_image_path(self, entry: Dict) -> Optional[str]:
"""Resolve an image path using metadata, cache, and repository hints."""
relative_path = entry.get("image_path") if entry else None
cache_key = relative_path or entry.get("source_path")
if cache_key and cache_key in self._image_path_cache:
cached = Path(self._image_path_cache[cache_key])
if cached.exists():
return self._image_path_cache[cache_key]
del self._image_path_cache[cache_key]
candidates = []
source_path = entry.get("source_path") if entry else None
if source_path:
candidates.append(Path(source_path))
repo_roots = []
if entry.get("repository_root"):
repo_roots.append(entry["repository_root"])
config_repo = self.config_manager.get_image_repository_path()
if config_repo:
repo_roots.append(config_repo)
for root in repo_roots:
if relative_path:
candidates.append(Path(root) / relative_path)
if relative_path:
candidates.append(Path(relative_path))
for candidate in candidates:
try:
if candidate and candidate.exists():
resolved = str(candidate.resolve())
if cache_key:
self._image_path_cache[cache_key] = resolved
return resolved
except Exception:
continue
# Fallback: search by filename in known roots
filename = Path(relative_path).name if relative_path else None
if filename:
search_roots = [Path(root) for root in repo_roots if root]
if not search_roots:
search_roots = [Path("data")]
match = self._search_in_roots(filename, search_roots)
if match:
resolved = str(match.resolve())
if cache_key:
self._image_path_cache[cache_key] = resolved
return resolved
return None
def _search_in_roots(self, filename: str, roots: List[Path]) -> Optional[Path]:
"""Search for a file name within a list of root directories."""
for root in roots:
try:
if not root.exists():
continue
for candidate in root.rglob(filename):
return candidate
except Exception as e:
logger.debug(f"Error searching for {filename} in {root}: {e}")
return None
def _get_class_color(self, class_name: Optional[str]) -> str:
"""Return consistent color hex for a class name."""
if not class_name:
return "#FF6B6B"
color_map = self.config_manager.get_bbox_colors()
if class_name in color_map:
return color_map[class_name]
# Deterministic fallback color based on hash
palette = [
"#FF6B6B",
"#4ECDC4",
"#FFD166",
"#1D3557",
"#F4A261",
"#E76F51",
]
return palette[hash(class_name) % len(palette)]

View File

@@ -263,7 +263,7 @@ class AnnotationCanvasWidget(QWidget):
height, height,
bytes_per_line, bytes_per_line,
self.current_image.qtimage_format, self.current_image.qtimage_format,
) ).copy() # Copy so Qt owns the buffer even after numpy array goes out of scope
self.original_pixmap = QPixmap.fromImage(qimage) self.original_pixmap = QPixmap.fromImage(qimage)

View File

@@ -137,7 +137,7 @@ class ImageDisplayWidget(QWidget):
height, height,
bytes_per_line, bytes_per_line,
self.current_image.qtimage_format, self.current_image.qtimage_format,
) ).copy() # Copy to ensure Qt owns its memory after this scope
# Convert to pixmap # Convert to pixmap
pixmap = QPixmap.fromImage(qimage) pixmap = QPixmap.fromImage(qimage)

View File

@@ -42,6 +42,7 @@ class InferenceEngine:
relative_path: str, relative_path: str,
conf: float = 0.25, conf: float = 0.25,
save_to_db: bool = True, save_to_db: bool = True,
repository_root: Optional[str] = None,
) -> Dict: ) -> Dict:
""" """
Detect objects in a single image. Detect objects in a single image.
@@ -51,11 +52,17 @@ class InferenceEngine:
relative_path: Relative path from repository root relative_path: Relative path from repository root
conf: Confidence threshold conf: Confidence threshold
save_to_db: Whether to save results to database save_to_db: Whether to save results to database
repository_root: Base directory used to compute relative_path (if known)
Returns: Returns:
Dictionary with detection results Dictionary with detection results
""" """
try: try:
# Normalize storage path (fall back to absolute path when repo root is unknown)
stored_relative_path = relative_path
if not repository_root:
stored_relative_path = str(Path(image_path).resolve())
# Get image dimensions # Get image dimensions
img = Image.open(image_path) img = Image.open(image_path)
width, height = img.size width, height = img.size
@@ -66,34 +73,58 @@ class InferenceEngine:
# Add/get image in database # Add/get image in database
image_id = self.db_manager.get_or_create_image( image_id = self.db_manager.get_or_create_image(
relative_path=relative_path, relative_path=stored_relative_path,
filename=Path(image_path).name, filename=Path(image_path).name,
width=width, width=width,
height=height, height=height,
) )
# Save detections to database inserted_count = 0
if save_to_db and detections: deleted_count = 0
detection_records = []
for det in detections:
# Use normalized bbox from detection
bbox_normalized = det[
"bbox_normalized"
] # [x_min, y_min, x_max, y_max]
record = { # Save detections to database, replacing any previous results for this image/model
"image_id": image_id, if save_to_db:
"model_id": self.model_id, deleted_count = self.db_manager.delete_detections_for_image(
"class_name": det["class_name"], image_id, self.model_id
"bbox": tuple(bbox_normalized), )
"confidence": det["confidence"], if detections:
"segmentation_mask": det.get("segmentation_mask"), detection_records = []
"metadata": {"class_id": det["class_id"]}, for det in detections:
} # Use normalized bbox from detection
detection_records.append(record) bbox_normalized = det[
"bbox_normalized"
] # [x_min, y_min, x_max, y_max]
self.db_manager.add_detections_batch(detection_records) metadata = {
logger.info(f"Saved {len(detection_records)} detections to database") "class_id": det["class_id"],
"source_path": str(Path(image_path).resolve()),
}
if repository_root:
metadata["repository_root"] = str(
Path(repository_root).resolve()
)
record = {
"image_id": image_id,
"model_id": self.model_id,
"class_name": det["class_name"],
"bbox": tuple(bbox_normalized),
"confidence": det["confidence"],
"segmentation_mask": det.get("segmentation_mask"),
"metadata": metadata,
}
detection_records.append(record)
inserted_count = self.db_manager.add_detections_batch(
detection_records
)
logger.info(
f"Saved {inserted_count} detections to database (replaced {deleted_count})"
)
else:
logger.info(
f"Detection run removed {deleted_count} stale entries but produced no new detections"
)
return { return {
"success": True, "success": True,
@@ -142,7 +173,12 @@ class InferenceEngine:
rel_path = get_relative_path(image_path, repository_root) rel_path = get_relative_path(image_path, repository_root)
# Perform detection # Perform detection
result = self.detect_single(image_path, rel_path, conf) result = self.detect_single(
image_path,
rel_path,
conf=conf,
repository_root=repository_root,
)
results.append(result) results.append(result)
# Update progress # Update progress

View File

@@ -7,6 +7,9 @@ from ultralytics import YOLO
from pathlib import Path from pathlib import Path
from typing import Optional, List, Dict, Callable, Any from typing import Optional, List, Dict, Callable, Any
import torch import torch
from PIL import Image
import tempfile
import os
from src.utils.logger import get_logger from src.utils.logger import get_logger
@@ -162,10 +165,12 @@ class YOLOWrapper:
if self.model is None: if self.model is None:
self.load_model() self.load_model()
prepared_source, cleanup_path = self._prepare_source(source)
try: try:
logger.info(f"Running inference on {source}") logger.info(f"Running inference on {source}")
results = self.model.predict( results = self.model.predict(
source=source, source=prepared_source,
conf=conf, conf=conf,
iou=iou, iou=iou,
save=save, save=save,
@@ -182,6 +187,14 @@ class YOLOWrapper:
except Exception as e: except Exception as e:
logger.error(f"Error during inference: {e}") logger.error(f"Error during inference: {e}")
raise raise
finally:
if cleanup_path:
try:
os.remove(cleanup_path)
except OSError as cleanup_error:
logger.warning(
f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
)
def export( def export(
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
@@ -210,6 +223,36 @@ class YOLOWrapper:
logger.error(f"Error exporting model: {e}") logger.error(f"Error exporting model: {e}")
raise raise
def _prepare_source(self, source):
"""Convert single-channel images to RGB temporarily for inference."""
cleanup_path = None
if isinstance(source, (str, Path)):
source_path = Path(source)
if source_path.is_file():
try:
with Image.open(source_path) as img:
if len(img.getbands()) == 1:
rgb_img = img.convert("RGB")
suffix = source_path.suffix or ".png"
tmp = tempfile.NamedTemporaryFile(
suffix=suffix, delete=False
)
tmp_path = tmp.name
tmp.close()
rgb_img.save(tmp_path)
cleanup_path = tmp_path
logger.info(
f"Converted single-channel image {source_path} to RGB for inference"
)
return tmp_path, cleanup_path
except Exception as convert_error:
logger.warning(
f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"
)
return source, cleanup_path
def _format_training_results(self, results) -> Dict[str, Any]: def _format_training_results(self, results) -> Dict[str, Any]:
"""Format training results into dictionary.""" """Format training results into dictionary."""
try: try: