Major changes to annotation and to how results are displayed and stored

2026-01-23 08:39:33 +02:00
parent 3c8247b3bc
commit 98bc89691b
7 changed files with 370 additions and 77 deletions


@@ -105,6 +105,103 @@ class DatabaseManager:
conn.execute("PRAGMA foreign_keys = ON") # Enable foreign keys
return conn
# ==================== Detection Run Operations ====================
def upsert_detection_run(
self,
image_id: int,
model_id: int,
count: int,
metadata: Optional[Dict] = None,
) -> bool:
"""Insert/update a per-image per-model detection run summary.
This enables the UI to show runs even when zero detections were produced.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO detection_runs (image_id, model_id, detected_at, count, metadata)
VALUES (?, ?, CURRENT_TIMESTAMP, ?, ?)
ON CONFLICT(image_id, model_id) DO UPDATE SET
detected_at = CURRENT_TIMESTAMP,
count = excluded.count,
metadata = excluded.metadata
""",
(
int(image_id),
int(model_id),
int(count),
json.dumps(metadata) if metadata else None,
),
)
conn.commit()
return True
finally:
conn.close()
def get_detection_run_summaries(self, limit: int = 500, offset: int = 0) -> List[Dict]:
"""Return latest detection run summaries grouped by image+model.
Includes runs with 0 detections.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"""
SELECT
dr.image_id,
dr.model_id,
dr.detected_at,
dr.count,
dr.metadata,
i.relative_path AS image_path,
i.filename AS image_filename,
m.model_name,
m.model_version,
GROUP_CONCAT(DISTINCT d.class_name) AS classes
FROM detection_runs dr
JOIN images i ON dr.image_id = i.id
JOIN models m ON dr.model_id = m.id
LEFT JOIN detections d
ON d.image_id = dr.image_id AND d.model_id = dr.model_id
GROUP BY dr.image_id, dr.model_id
ORDER BY dr.detected_at DESC
LIMIT ? OFFSET ?
""",
(int(limit), int(offset)),
)
rows: List[Dict] = []
for row in cursor.fetchall():
item = dict(row)
if item.get("metadata"):
try:
item["metadata"] = json.loads(item["metadata"])
except Exception:
item["metadata"] = None
rows.append(item)
return rows
finally:
conn.close()
def get_detection_run_total(self) -> int:
"""Return total number of detection_runs rows."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) AS cnt FROM detection_runs")
row = cursor.fetchone()
return int(row["cnt"] if row and row["cnt"] is not None else 0)
finally:
conn.close()
# ==================== Model Operations ====================
def add_model(
@@ -527,6 +624,14 @@ class DatabaseManager:
conn = self.get_connection()
try:
cursor = conn.cursor()
# Also clear detection run summaries so the Results tab does not continue
# to show historical runs after detections have been wiped.
try:
cursor.execute("DELETE FROM detection_runs")
except sqlite3.OperationalError:
# Backwards-compatible: table may not exist on older DB files.
pass
cursor.execute("DELETE FROM detections")
conn.commit()
return cursor.rowcount
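
The run-summary API added above might be exercised roughly as follows. This is a minimal sketch, assuming a constructed DatabaseManager and existing image/model rows; the constructor arguments, ids and paths are placeholders, not part of this diff.

# Hedged usage sketch of the new detection-run API (ids and paths are placeholders).
db = DatabaseManager()  # construction details depend on the application's config

# Record/refresh the per-image, per-model run summary after a detection pass.
db.upsert_detection_run(
    image_id=image_id,
    model_id=model_id,
    count=3,
    metadata={"source_path": "/data/images/field_0001.tif"},
)

# Page through the latest run summaries, 200 at a time (matches ResultsTab.page_size).
total = db.get_detection_run_total()
for run in db.get_detection_run_summaries(limit=200, offset=0):
    print(run["image_filename"], run["count"], run["classes"])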


@@ -45,6 +45,19 @@ CREATE TABLE IF NOT EXISTS detections (
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
);
-- Detection runs table: stores per-image per-model run summaries (including 0 detections)
CREATE TABLE IF NOT EXISTS detection_runs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
image_id INTEGER NOT NULL,
model_id INTEGER NOT NULL,
detected_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
count INTEGER NOT NULL DEFAULT 0,
metadata TEXT,
UNIQUE(image_id, model_id),
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
);
-- Object classes table: stores annotation class definitions with colors
CREATE TABLE IF NOT EXISTS object_classes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -81,6 +94,9 @@ CREATE INDEX IF NOT EXISTS idx_detections_model_id ON detections(model_id);
CREATE INDEX IF NOT EXISTS idx_detections_class_name ON detections(class_name);
CREATE INDEX IF NOT EXISTS idx_detections_detected_at ON detections(detected_at);
CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence);
CREATE INDEX IF NOT EXISTS idx_detection_runs_image_id ON detection_runs(image_id);
CREATE INDEX IF NOT EXISTS idx_detection_runs_model_id ON detection_runs(model_id);
CREATE INDEX IF NOT EXISTS idx_detection_runs_detected_at ON detection_runs(detected_at);
CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path);
CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at);
CREATE INDEX IF NOT EXISTS idx_images_source ON images(source);
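
The UNIQUE(image_id, model_id) constraint is what the Python-side ON CONFLICT upsert targets. A self-contained sketch (in-memory database, trimmed-down columns; the real schema is the one above) shows that re-running detection for the same image/model pair updates the existing row instead of adding a new one.

# Illustrative sketch of the upsert behaviour enabled by UNIQUE(image_id, model_id).
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE detection_runs (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    image_id INTEGER NOT NULL,
    model_id INTEGER NOT NULL,
    detected_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    count INTEGER NOT NULL DEFAULT 0,
    metadata TEXT,
    UNIQUE(image_id, model_id)
);
""")
upsert = """
INSERT INTO detection_runs (image_id, model_id, detected_at, count)
VALUES (?, ?, CURRENT_TIMESTAMP, ?)
ON CONFLICT(image_id, model_id) DO UPDATE SET
    detected_at = CURRENT_TIMESTAMP,
    count = excluded.count
"""
conn.execute(upsert, (1, 1, 5))   # first run: 5 detections
conn.execute(upsert, (1, 1, 0))   # re-run: zero detections, same row is updated
print(conn.execute("SELECT image_id, model_id, count FROM detection_runs").fetchall())
# [(1, 1, 0)] -- still a single row per image/model pair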


@@ -40,6 +40,12 @@ class ResultsTab(QWidget):
self.db_manager = db_manager
self.config_manager = config_manager
# Pagination
self.page_size = 200
self.current_page = 0 # 0-based
self.total_runs = 0
self.total_pages = 0
self.detection_summary: List[Dict] = []
self.current_selection: Optional[Dict] = None
self.current_image: Optional[Image] = None
@@ -66,6 +72,20 @@ class ResultsTab(QWidget):
self.refresh_btn.clicked.connect(self.refresh)
controls_layout.addWidget(self.refresh_btn)
self.prev_page_btn = QPushButton("◀ Prev")
self.prev_page_btn.setToolTip("Previous page")
self.prev_page_btn.clicked.connect(self._prev_page)
controls_layout.addWidget(self.prev_page_btn)
self.next_page_btn = QPushButton("Next ▶")
self.next_page_btn.setToolTip("Next page")
self.next_page_btn.clicked.connect(self._next_page)
controls_layout.addWidget(self.next_page_btn)
self.page_label = QLabel("Page 1/1")
self.page_label.setMinimumWidth(100)
controls_layout.addWidget(self.page_label)
self.delete_all_btn = QPushButton("Delete All Detections")
self.delete_all_btn.setToolTip(
"Permanently delete ALL detections from the database.\n" "This cannot be undone."
@@ -183,6 +203,8 @@ class ResultsTab(QWidget):
def refresh(self):
"""Refresh the detection list and preview."""
# Reset to first page on refresh.
self.current_page = 0
self._load_detection_summary()
self._populate_results_table()
self.current_selection = None
@@ -193,57 +215,135 @@ class ResultsTab(QWidget):
if hasattr(self, "export_labels_btn"):
self.export_labels_btn.setEnabled(False)
self._update_pagination_controls()
def _prev_page(self):
"""Go to previous results page."""
if self.current_page <= 0:
return
self.current_page -= 1
self._load_detection_summary()
self._populate_results_table()
self._update_pagination_controls()
def _next_page(self):
"""Go to next results page."""
if self.total_pages and self.current_page >= (self.total_pages - 1):
return
self.current_page += 1
self._load_detection_summary()
self._populate_results_table()
self._update_pagination_controls()
def _update_pagination_controls(self):
"""Update pagination label/button enabled state."""
# Clamp the page index into the valid range before updating the label and buttons.
total_pages = max(int(self.total_pages or 0), 1)
current_page = min(max(int(self.current_page or 0), 0), total_pages - 1)
self.current_page = current_page
self.page_label.setText(f"Page {current_page + 1}/{total_pages}")
self.prev_page_btn.setEnabled(current_page > 0)
self.next_page_btn.setEnabled(current_page < (total_pages - 1))
def _load_detection_summary(self):
"""Load latest detection summaries grouped by image + model."""
try:
detections = self.db_manager.get_detections(limit=500)
summary_map: Dict[tuple, Dict] = {}
# Prefer run summaries (supports zero-detection runs). Fall back to legacy
# detection aggregation if the DB/table isn't available.
try:
self.total_runs = int(self.db_manager.get_detection_run_total())
except Exception:
self.total_runs = 0
for det in detections:
key = (det["image_id"], det["model_id"])
metadata = det.get("metadata") or {}
entry = summary_map.setdefault(
key,
self.total_pages = (self.total_runs + self.page_size - 1) // self.page_size if self.total_runs > 0 else 1
offset = int(self.current_page) * int(self.page_size)
runs = self.db_manager.get_detection_run_summaries(
limit=int(self.page_size),
offset=offset,
)
summary: List[Dict] = []
for run in runs:
meta = run.get("metadata") or {}
classes_raw = run.get("classes")
classes = set()
if isinstance(classes_raw, str) and classes_raw.strip():
classes = {c.strip() for c in classes_raw.split(",") if c.strip()}
summary.append(
{
"image_id": det["image_id"],
"model_id": det["model_id"],
"image_path": det.get("image_path"),
"image_filename": det.get("image_filename") or det.get("image_path"),
"model_name": det.get("model_name", ""),
"model_version": det.get("model_version", ""),
"last_detected": det.get("detected_at"),
"count": 0,
"classes": set(),
"source_path": metadata.get("source_path"),
"repository_root": metadata.get("repository_root"),
},
"image_id": run.get("image_id"),
"model_id": run.get("model_id"),
"image_path": run.get("image_path"),
"image_filename": run.get("image_filename") or run.get("image_path"),
"model_name": run.get("model_name", ""),
"model_version": run.get("model_version", ""),
"last_detected": run.get("detected_at"),
"count": int(run.get("count") or 0),
"classes": classes,
"source_path": meta.get("source_path"),
"repository_root": meta.get("repository_root"),
}
)
entry["count"] += 1
if det.get("detected_at") and (
not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
):
entry["last_detected"] = det.get("detected_at")
if det.get("class_name"):
entry["classes"].add(det["class_name"])
if metadata.get("source_path") and not entry.get("source_path"):
entry["source_path"] = metadata.get("source_path")
if metadata.get("repository_root") and not entry.get("repository_root"):
entry["repository_root"] = metadata.get("repository_root")
self.detection_summary = sorted(
summary_map.values(),
key=lambda x: str(x.get("last_detected") or ""),
reverse=True,
)
self.detection_summary = summary
except Exception as e:
logger.error(f"Failed to load detection summary: {e}")
QMessageBox.critical(
self,
"Error",
f"Failed to load detection results:\n{str(e)}",
)
self.detection_summary = []
logger.error(f"Failed to load detection run summaries, falling back: {e}")
# Disable pagination if we can't page via detection_runs.
self.total_runs = 0
self.total_pages = 1
self.current_page = 0
try:
detections = self.db_manager.get_detections(limit=int(self.page_size))
summary_map: Dict[tuple, Dict] = {}
for det in detections:
key = (det["image_id"], det["model_id"])
metadata = det.get("metadata") or {}
entry = summary_map.setdefault(
key,
{
"image_id": det["image_id"],
"model_id": det["model_id"],
"image_path": det.get("image_path"),
"image_filename": det.get("image_filename") or det.get("image_path"),
"model_name": det.get("model_name", ""),
"model_version": det.get("model_version", ""),
"last_detected": det.get("detected_at"),
"count": 0,
"classes": set(),
"source_path": metadata.get("source_path"),
"repository_root": metadata.get("repository_root"),
},
)
entry["count"] += 1
if det.get("detected_at") and (
not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
):
entry["last_detected"] = det.get("detected_at")
if det.get("class_name"):
entry["classes"].add(det["class_name"])
if metadata.get("source_path") and not entry.get("source_path"):
entry["source_path"] = metadata.get("source_path")
if metadata.get("repository_root") and not entry.get("repository_root"):
entry["repository_root"] = metadata.get("repository_root")
self.detection_summary = sorted(
summary_map.values(),
key=lambda x: str(x.get("last_detected") or ""),
reverse=True,
)
except Exception as inner:
logger.error(f"Failed to load detection summary: {inner}")
QMessageBox.critical(
self,
"Error",
f"Failed to load detection results:\n{str(inner)}",
)
self.detection_summary = []
def _populate_results_table(self):
"""Populate the table widget with detection summaries."""

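The paging arithmetic above is ceiling division plus a clamp. A small sketch with hypothetical helper names, shown only to make the edge cases explicit:

# Sketch of the paging arithmetic used by ResultsTab (hypothetical standalone helpers).
def page_count(total_runs: int, page_size: int) -> int:
    return (total_runs + page_size - 1) // page_size if total_runs > 0 else 1

def clamp_page(page: int, total_pages: int) -> int:
    return min(max(page, 0), max(total_pages, 1) - 1)

assert page_count(0, 200) == 1      # an empty DB still renders "Page 1/1"
assert page_count(401, 200) == 3    # 401 runs -> pages of 200, 200 and 1
assert clamp_page(5, 3) == 2        # a stale page index snaps back to the last page
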

@@ -84,25 +84,19 @@ class InferenceEngine:
# Save detections to database, replacing any previous results for this image/model
if save_to_db:
deleted_count = self.db_manager.delete_detections_for_image(
image_id, self.model_id
)
deleted_count = self.db_manager.delete_detections_for_image(image_id, self.model_id)
if detections:
detection_records = []
for det in detections:
# Use normalized bbox from detection
bbox_normalized = det[
"bbox_normalized"
] # [x_min, y_min, x_max, y_max]
bbox_normalized = det["bbox_normalized"] # [x_min, y_min, x_max, y_max]
metadata = {
"class_id": det["class_id"],
"source_path": str(Path(image_path).resolve()),
}
if repository_root:
metadata["repository_root"] = str(
Path(repository_root).resolve()
)
metadata["repository_root"] = str(Path(repository_root).resolve())
record = {
"image_id": image_id,
@@ -115,16 +109,27 @@ class InferenceEngine:
}
detection_records.append(record)
inserted_count = self.db_manager.add_detections_batch(
detection_records
)
logger.info(
f"Saved {inserted_count} detections to database (replaced {deleted_count})"
)
inserted_count = self.db_manager.add_detections_batch(detection_records)
logger.info(f"Saved {inserted_count} detections to database (replaced {deleted_count})")
else:
logger.info(
f"Detection run removed {deleted_count} stale entries but produced no new detections"
logger.info(f"Detection run removed {deleted_count} stale entries but produced no new detections")
# Always store a run summary so the Results tab can show zero-detection runs.
try:
run_metadata = {
"source_path": str(Path(image_path).resolve()),
}
if repository_root:
run_metadata["repository_root"] = str(Path(repository_root).resolve())
self.db_manager.upsert_detection_run(
image_id=image_id,
model_id=self.model_id,
count=len(detections),
metadata=run_metadata,
)
except Exception as exc:
# Non-fatal: detection records may still be present.
logger.warning(f"Failed to store detection run summary: {exc}")
return {
"success": True,
@@ -232,9 +237,7 @@ class InferenceEngine:
for det in detections:
# Get color for this class
class_name = det["class_name"]
color_hex = bbox_colors.get(
class_name, bbox_colors.get("default", "#00FF00")
)
color_hex = bbox_colors.get(class_name, bbox_colors.get("default", "#00FF00"))
color = self._hex_to_bgr(color_hex)
# Draw segmentation mask if available and requested
@@ -243,10 +246,7 @@ class InferenceEngine:
if mask_normalized and len(mask_normalized) > 0:
# Convert normalized coordinates to absolute pixels
mask_points = np.array(
[
[int(pt[0] * width), int(pt[1] * height)]
for pt in mask_normalized
],
[[int(pt[0] * width), int(pt[1] * height)] for pt in mask_normalized],
dtype=np.int32,
)
@@ -270,9 +270,7 @@ class InferenceEngine:
label = f"{class_name} {det['confidence']:.2f}"
# Draw label background
(label_w, label_h), baseline = cv2.getTextSize(
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
)
(label_w, label_h), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
cv2.rectangle(
img,
(x1, y1 - label_h - baseline - 5),

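Because upsert_detection_run is called whether or not any detections were produced, a zero-detection run still yields a summary row. A hedged check of that behaviour (fixture names such as db, image_id and model_id are assumptions, not part of this diff):

# Hypothetical check that a zero-detection run is still visible to the Results tab.
db.upsert_detection_run(image_id=image_id, model_id=model_id, count=0,
                        metadata={"source_path": "/data/images/empty_field.tif"})
runs = db.get_detection_run_summaries(limit=10, offset=0)
assert any(r["image_id"] == image_id and r["count"] == 0 for r in runs)
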

@@ -186,7 +186,7 @@ class YOLOWrapper:
raise RuntimeError(f"Failed to load model from {self.model_path}")
prepared_source, cleanup_path = self._prepare_source(source)
imgsz = 1088
imgsz = 640  # reduced from 1088
try:
logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
results = self.model.predict(


@@ -6,6 +6,9 @@ import cv2
import numpy as np
from pathlib import Path
from typing import Optional, Tuple, Union
from skimage.restoration import rolling_ball
from skimage.filters import threshold_otsu
from scipy.ndimage import median_filter, gaussian_filter
from src.utils.logger import get_logger
from src.utils.file_utils import validate_file_path, is_image_file
@@ -37,6 +40,8 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
a1[a1 > p999] = p999
a1 /= a1.max()
# print("Using get_pseudo_rgb")
if 1:
a2 = a1.copy()
a2 = a2**gamma
@@ -46,14 +51,75 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
p9999 = np.percentile(a3, 99.99)
a3[a3 > p9999] = p9999
a3 /= a3.max()
out = np.stack([a1, a2, a3], axis=0)
else:
out = np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
return out
# return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
out = np.stack([a1, a2, a3], axis=0)
# print(any(np.isnan(out).flatten()))
def _get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.3) -> np.ndarray:
"""
Convert a grayscale image to a pseudo-RGB image using a gamma correction.
Args:
arr: Input grayscale image as numpy array
Returns:
Pseudo-RGB image as numpy array
"""
if arr.ndim != 2:
raise ValueError("Input array must be a grayscale image with shape (H, W)")
radius = 80
# bg = rolling_ball(arr, radius=radius)
a1 = arr.copy().astype(np.float32)
a1 = a1.astype(np.float32)
# a1 -= bg
# a1[a1 < 0] = 0
# a1 -= np.percentile(a1, 2)
# a1[a1 < 0] = 0
p999 = np.percentile(a1, 99.99)
a1[a1 > p999] = p999
a1 /= a1.max()
print("Using get_pseudo_rgb")
if 1:
a2 = a1.copy()
_a2 = a2**gamma
thr = threshold_otsu(_a2)
mask = gaussian_filter((_a2 > thr).astype(np.float32), sigma=5)
mask[mask > 0.0001] = 1
mask[mask <= 0.0001] = 0
a2 *= mask
# bg2 = rolling_ball(a2, radius=radius)
# a2 -= bg2
a2 -= np.percentile(a2, 2)
a2[a2 < 0] = 0
a2 /= a2.max()
a3 = a1.copy()
p9999 = np.percentile(a3, 99.99)
a3[a3 > p9999] = p9999
a3 /= a3.max()
out = np.stack([a1, a2, _a2], axis=0)
else:
out = np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
return out
# return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# print(any(np.isnan(out).flatten()))
class ImageLoadError(Exception):
"""Exception raised when an image cannot be loaded."""
@@ -330,17 +396,20 @@ class Image:
return self._channels >= 3
def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None:
# print("Image.save", self.data.shape)
if self.channels == 1:
# print("Image.save grayscale")
if pseudo_rgb:
img = get_pseudo_rgb(self.data)
print("Image.save", img.shape)
# print("Image.save", img.shape)
else:
img = np.repeat(self.data, 3, axis=2)
# print("Image.save no pseudo", img.shape)
else:
raise NotImplementedError("Only grayscale images are supported for now.")
# print("Image.save imwrite", img.shape)
imwrite(path, data=img)
def __repr__(self) -> str:

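For reference, a minimal sketch of the pseudo-RGB path on synthetic data; get_pseudo_rgb stacks three differently stretched copies of the single input channel along axis 0 (the frame below is a placeholder, not project data):

# Hedged sketch: run a synthetic 16-bit grayscale frame through get_pseudo_rgb.
import numpy as np

frame = (np.random.default_rng(0).random((256, 256)) * 65535).astype(np.uint16)
rgb = get_pseudo_rgb(frame.astype(np.float32), gamma=0.5)
print(rgb.shape)                           # (3, 256, 256)
print(float(rgb.min()), float(rgb.max()))  # channels normalised into [0, 1]
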

@@ -5,7 +5,7 @@ from tifffile import imread, imwrite
from shapely.geometry import LineString
from copy import deepcopy
from scipy.ndimage import zoom
from skimage.restoration import rolling_ball
# debug
from src.utils.image import Image
@@ -160,8 +160,11 @@ class YoloLabelReader:
class ImageSplitter:
def __init__(self, image_path: Path, label_path: Path):
def __init__(self, image_path: Path, label_path: Path, subtract_bg: bool = False):
self.image = imread(image_path)
if subtract_bg:
# Subtract the rolling-ball background in float so unsigned TIFF data cannot wrap around.
background = rolling_ball(self.image, radius=12)
self.image = self.image.astype("float32") - background
self.image[self.image < 0] = 0
self.image_path = image_path
self.label_path = label_path
if not label_path.exists():
@@ -280,6 +283,7 @@ def main(args):
data = ImageSplitter(
image_path=image_path,
label_path=(args.input / "labels" / image_path.stem).with_suffix(".txt"),
subtract_bg=args.subtract_bg,
)
if args.split_around_label:
@@ -373,6 +377,7 @@ if __name__ == "__main__":
default=67,
help="Padding around the label when splitting around the label.",
)
parser.add_argument("-bg", "--subtract-bg", action="store_true", help="Subtract background from the image.")
args = parser.parse_args()
main(args)
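
Background subtraction can now be requested with -bg / --subtract-bg on the command line, or directly when constructing the splitter. A hedged example with placeholder paths:

# Hypothetical direct usage of the new flag; paths are placeholders.
from pathlib import Path

splitter = ImageSplitter(
    image_path=Path("dataset/images/field_0001.tif"),
    label_path=Path("dataset/labels/field_0001.txt"),
    subtract_bg=True,  # rolling-ball background subtraction (radius=12) before splitting
)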