Major changes to annotation and to displaying/storing results

This commit is contained in:
2026-01-23 08:39:33 +02:00
parent 3c8247b3bc
commit 98bc89691b
7 changed files with 370 additions and 77 deletions

View File

@@ -105,6 +105,103 @@ class DatabaseManager:
conn.execute("PRAGMA foreign_keys = ON") # Enable foreign keys conn.execute("PRAGMA foreign_keys = ON") # Enable foreign keys
return conn return conn
# ==================== Detection Run Operations ====================
def upsert_detection_run(
    self,
    image_id: int,
    model_id: int,
    count: int,
    metadata: Optional[Dict] = None,
) -> bool:
    """Insert or refresh the per-image/per-model detection run summary.

    Keeping exactly one row per (image, model) pair lets the UI list runs
    even when a run produced zero detections.

    Args:
        image_id: Database id of the processed image.
        model_id: Database id of the model that ran.
        count: Number of detections the run produced (0 is meaningful).
        metadata: Optional JSON-serializable extras, stored as TEXT
            (falsy values are stored as NULL).

    Returns:
        True on success; database errors propagate to the caller.
    """
    payload = (
        int(image_id),
        int(model_id),
        int(count),
        json.dumps(metadata) if metadata else None,
    )
    conn = self.get_connection()
    try:
        # Relies on the UNIQUE(image_id, model_id) constraint for the upsert.
        conn.cursor().execute(
            """
            INSERT INTO detection_runs (image_id, model_id, detected_at, count, metadata)
            VALUES (?, ?, CURRENT_TIMESTAMP, ?, ?)
            ON CONFLICT(image_id, model_id) DO UPDATE SET
                detected_at = CURRENT_TIMESTAMP,
                count = excluded.count,
                metadata = excluded.metadata
            """,
            payload,
        )
        conn.commit()
        return True
    finally:
        conn.close()
def get_detection_run_summaries(self, limit: int = 500, offset: int = 0) -> List[Dict]:
    """Return the latest detection run summaries, one per image+model pair.

    Runs with 0 detections are included. ``classes`` is a comma-joined
    string of distinct detected class names (NULL when nothing was
    detected) and ``metadata`` is decoded from its JSON text form when
    present; undecodable metadata is returned as None.

    Args:
        limit: Maximum number of rows to return (page size).
        offset: Number of rows to skip (for pagination).

    Returns:
        List of dicts, newest runs first.
    """
    query = """
        SELECT
            dr.image_id,
            dr.model_id,
            dr.detected_at,
            dr.count,
            dr.metadata,
            i.relative_path AS image_path,
            i.filename AS image_filename,
            m.model_name,
            m.model_version,
            GROUP_CONCAT(DISTINCT d.class_name) AS classes
        FROM detection_runs dr
        JOIN images i ON dr.image_id = i.id
        JOIN models m ON dr.model_id = m.id
        LEFT JOIN detections d
            ON d.image_id = dr.image_id AND d.model_id = dr.model_id
        GROUP BY dr.image_id, dr.model_id
        ORDER BY dr.detected_at DESC
        LIMIT ? OFFSET ?
    """
    conn = self.get_connection()
    try:
        cur = conn.cursor()
        cur.execute(query, (int(limit), int(offset)))
        summaries: List[Dict] = []
        for raw in cur.fetchall():
            entry = dict(raw)
            blob = entry.get("metadata")
            if blob:
                try:
                    entry["metadata"] = json.loads(blob)
                except Exception:
                    # Corrupt/legacy metadata must not break the listing.
                    entry["metadata"] = None
            summaries.append(entry)
        return summaries
    finally:
        conn.close()
def get_detection_run_total(self) -> int:
    """Return the total number of rows in ``detection_runs``.

    Used by the Results tab to compute the page count.
    """
    conn = self.get_connection()
    try:
        row = conn.cursor().execute(
            "SELECT COUNT(*) AS cnt FROM detection_runs"
        ).fetchone()
        # Defensive: COUNT(*) always yields a row, but fall back to 0.
        if row is None or row["cnt"] is None:
            return 0
        return int(row["cnt"])
    finally:
        conn.close()
# ==================== Model Operations ==================== # ==================== Model Operations ====================
def add_model( def add_model(
@@ -527,6 +624,14 @@ class DatabaseManager:
conn = self.get_connection() conn = self.get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
# Also clear detection run summaries so the Results tab does not continue
# to show historical runs after detections have been wiped.
try:
cursor.execute("DELETE FROM detection_runs")
except sqlite3.OperationalError:
# Backwards-compatible: table may not exist on older DB files.
pass
cursor.execute("DELETE FROM detections") cursor.execute("DELETE FROM detections")
conn.commit() conn.commit()
return cursor.rowcount return cursor.rowcount

View File

@@ -45,6 +45,19 @@ CREATE TABLE IF NOT EXISTS detections (
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
); );
-- Detection runs table: stores per-image per-model run summaries (including 0 detections)
-- Exactly one row per (image_id, model_id) pair; re-running a model on an
-- image upserts this row (refreshing detected_at/count/metadata) rather
-- than inserting a new one.
CREATE TABLE IF NOT EXISTS detection_runs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
image_id INTEGER NOT NULL,
model_id INTEGER NOT NULL,
-- Timestamp of the most recent run for this image+model pair.
detected_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-- Number of detections the run produced; 0 is meaningful and kept.
count INTEGER NOT NULL DEFAULT 0,
-- Optional JSON blob (e.g. source_path / repository_root of the run).
metadata TEXT,
-- Enables ON CONFLICT(image_id, model_id) upserts.
UNIQUE(image_id, model_id),
-- Cascade so summaries disappear with their image or model.
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
);
-- Object classes table: stores annotation class definitions with colors -- Object classes table: stores annotation class definitions with colors
CREATE TABLE IF NOT EXISTS object_classes ( CREATE TABLE IF NOT EXISTS object_classes (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -81,6 +94,9 @@ CREATE INDEX IF NOT EXISTS idx_detections_model_id ON detections(model_id);
CREATE INDEX IF NOT EXISTS idx_detections_class_name ON detections(class_name); CREATE INDEX IF NOT EXISTS idx_detections_class_name ON detections(class_name);
CREATE INDEX IF NOT EXISTS idx_detections_detected_at ON detections(detected_at); CREATE INDEX IF NOT EXISTS idx_detections_detected_at ON detections(detected_at);
CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence); CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence);
-- Indexes for detection_runs: speed up per-image / per-model lookups and
-- the recency-ordered (detected_at DESC) listing used by the Results tab.
CREATE INDEX IF NOT EXISTS idx_detection_runs_image_id ON detection_runs(image_id);
CREATE INDEX IF NOT EXISTS idx_detection_runs_model_id ON detection_runs(model_id);
CREATE INDEX IF NOT EXISTS idx_detection_runs_detected_at ON detection_runs(detected_at);
CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path); CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path);
CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at); CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at);
CREATE INDEX IF NOT EXISTS idx_images_source ON images(source); CREATE INDEX IF NOT EXISTS idx_images_source ON images(source);

View File

@@ -40,6 +40,12 @@ class ResultsTab(QWidget):
self.db_manager = db_manager self.db_manager = db_manager
self.config_manager = config_manager self.config_manager = config_manager
# Pagination
self.page_size = 200
self.current_page = 0 # 0-based
self.total_runs = 0
self.total_pages = 0
self.detection_summary: List[Dict] = [] self.detection_summary: List[Dict] = []
self.current_selection: Optional[Dict] = None self.current_selection: Optional[Dict] = None
self.current_image: Optional[Image] = None self.current_image: Optional[Image] = None
@@ -66,6 +72,20 @@ class ResultsTab(QWidget):
self.refresh_btn.clicked.connect(self.refresh) self.refresh_btn.clicked.connect(self.refresh)
controls_layout.addWidget(self.refresh_btn) controls_layout.addWidget(self.refresh_btn)
self.prev_page_btn = QPushButton("◀ Prev")
self.prev_page_btn.setToolTip("Previous page")
self.prev_page_btn.clicked.connect(self._prev_page)
controls_layout.addWidget(self.prev_page_btn)
self.next_page_btn = QPushButton("Next ▶")
self.next_page_btn.setToolTip("Next page")
self.next_page_btn.clicked.connect(self._next_page)
controls_layout.addWidget(self.next_page_btn)
self.page_label = QLabel("Page 1/1")
self.page_label.setMinimumWidth(100)
controls_layout.addWidget(self.page_label)
self.delete_all_btn = QPushButton("Delete All Detections") self.delete_all_btn = QPushButton("Delete All Detections")
self.delete_all_btn.setToolTip( self.delete_all_btn.setToolTip(
"Permanently delete ALL detections from the database.\n" "This cannot be undone." "Permanently delete ALL detections from the database.\n" "This cannot be undone."
@@ -183,6 +203,8 @@ class ResultsTab(QWidget):
def refresh(self): def refresh(self):
"""Refresh the detection list and preview.""" """Refresh the detection list and preview."""
# Reset to first page on refresh.
self.current_page = 0
self._load_detection_summary() self._load_detection_summary()
self._populate_results_table() self._populate_results_table()
self.current_selection = None self.current_selection = None
@@ -193,10 +215,88 @@ class ResultsTab(QWidget):
if hasattr(self, "export_labels_btn"): if hasattr(self, "export_labels_btn"):
self.export_labels_btn.setEnabled(False) self.export_labels_btn.setEnabled(False)
self._update_pagination_controls()
def _prev_page(self):
"""Go to previous results page."""
if self.current_page <= 0:
return
self.current_page -= 1
self._load_detection_summary()
self._populate_results_table()
self._update_pagination_controls()
def _next_page(self):
"""Go to next results page."""
if self.total_pages and self.current_page >= (self.total_pages - 1):
return
self.current_page += 1
self._load_detection_summary()
self._populate_results_table()
self._update_pagination_controls()
def _update_pagination_controls(self):
"""Update pagination label/button enabled state."""
# Default state (safe)
total_pages = max(int(self.total_pages or 0), 1)
current_page = min(max(int(self.current_page or 0), 0), total_pages - 1)
self.current_page = current_page
self.page_label.setText(f"Page {current_page + 1}/{total_pages}")
self.prev_page_btn.setEnabled(current_page > 0)
self.next_page_btn.setEnabled(current_page < (total_pages - 1))
def _load_detection_summary(self): def _load_detection_summary(self):
"""Load latest detection summaries grouped by image + model.""" """Load latest detection summaries grouped by image + model."""
try: try:
detections = self.db_manager.get_detections(limit=500) # Prefer run summaries (supports zero-detection runs). Fall back to legacy
# detection aggregation if the DB/table isn't available.
try:
self.total_runs = int(self.db_manager.get_detection_run_total())
except Exception:
self.total_runs = 0
self.total_pages = (self.total_runs + self.page_size - 1) // self.page_size if self.total_runs > 0 else 1
offset = int(self.current_page) * int(self.page_size)
runs = self.db_manager.get_detection_run_summaries(
limit=int(self.page_size),
offset=offset,
)
summary: List[Dict] = []
for run in runs:
meta = run.get("metadata") or {}
classes_raw = run.get("classes")
classes = set()
if isinstance(classes_raw, str) and classes_raw.strip():
classes = {c.strip() for c in classes_raw.split(",") if c.strip()}
summary.append(
{
"image_id": run.get("image_id"),
"model_id": run.get("model_id"),
"image_path": run.get("image_path"),
"image_filename": run.get("image_filename") or run.get("image_path"),
"model_name": run.get("model_name", ""),
"model_version": run.get("model_version", ""),
"last_detected": run.get("detected_at"),
"count": int(run.get("count") or 0),
"classes": classes,
"source_path": meta.get("source_path"),
"repository_root": meta.get("repository_root"),
}
)
self.detection_summary = summary
except Exception as e:
logger.error(f"Failed to load detection run summaries, falling back: {e}")
# Disable pagination if we can't page via detection_runs.
self.total_runs = 0
self.total_pages = 1
self.current_page = 0
try:
detections = self.db_manager.get_detections(limit=int(self.page_size))
summary_map: Dict[tuple, Dict] = {} summary_map: Dict[tuple, Dict] = {}
for det in detections: for det in detections:
@@ -236,12 +336,12 @@ class ResultsTab(QWidget):
key=lambda x: str(x.get("last_detected") or ""), key=lambda x: str(x.get("last_detected") or ""),
reverse=True, reverse=True,
) )
except Exception as e: except Exception as inner:
logger.error(f"Failed to load detection summary: {e}") logger.error(f"Failed to load detection summary: {inner}")
QMessageBox.critical( QMessageBox.critical(
self, self,
"Error", "Error",
f"Failed to load detection results:\n{str(e)}", f"Failed to load detection results:\n{str(inner)}",
) )
self.detection_summary = [] self.detection_summary = []

View File

@@ -84,25 +84,19 @@ class InferenceEngine:
# Save detections to database, replacing any previous results for this image/model # Save detections to database, replacing any previous results for this image/model
if save_to_db: if save_to_db:
deleted_count = self.db_manager.delete_detections_for_image( deleted_count = self.db_manager.delete_detections_for_image(image_id, self.model_id)
image_id, self.model_id
)
if detections: if detections:
detection_records = [] detection_records = []
for det in detections: for det in detections:
# Use normalized bbox from detection # Use normalized bbox from detection
bbox_normalized = det[ bbox_normalized = det["bbox_normalized"] # [x_min, y_min, x_max, y_max]
"bbox_normalized"
] # [x_min, y_min, x_max, y_max]
metadata = { metadata = {
"class_id": det["class_id"], "class_id": det["class_id"],
"source_path": str(Path(image_path).resolve()), "source_path": str(Path(image_path).resolve()),
} }
if repository_root: if repository_root:
metadata["repository_root"] = str( metadata["repository_root"] = str(Path(repository_root).resolve())
Path(repository_root).resolve()
)
record = { record = {
"image_id": image_id, "image_id": image_id,
@@ -115,16 +109,27 @@ class InferenceEngine:
} }
detection_records.append(record) detection_records.append(record)
inserted_count = self.db_manager.add_detections_batch( inserted_count = self.db_manager.add_detections_batch(detection_records)
detection_records logger.info(f"Saved {inserted_count} detections to database (replaced {deleted_count})")
)
logger.info(
f"Saved {inserted_count} detections to database (replaced {deleted_count})"
)
else: else:
logger.info( logger.info(f"Detection run removed {deleted_count} stale entries but produced no new detections")
f"Detection run removed {deleted_count} stale entries but produced no new detections"
# Always store a run summary so the Results tab can show zero-detection runs.
try:
run_metadata = {
"source_path": str(Path(image_path).resolve()),
}
if repository_root:
run_metadata["repository_root"] = str(Path(repository_root).resolve())
self.db_manager.upsert_detection_run(
image_id=image_id,
model_id=self.model_id,
count=len(detections),
metadata=run_metadata,
) )
except Exception as exc:
# Non-fatal: detection records may still be present.
logger.warning(f"Failed to store detection run summary: {exc}")
return { return {
"success": True, "success": True,
@@ -232,9 +237,7 @@ class InferenceEngine:
for det in detections: for det in detections:
# Get color for this class # Get color for this class
class_name = det["class_name"] class_name = det["class_name"]
color_hex = bbox_colors.get( color_hex = bbox_colors.get(class_name, bbox_colors.get("default", "#00FF00"))
class_name, bbox_colors.get("default", "#00FF00")
)
color = self._hex_to_bgr(color_hex) color = self._hex_to_bgr(color_hex)
# Draw segmentation mask if available and requested # Draw segmentation mask if available and requested
@@ -243,10 +246,7 @@ class InferenceEngine:
if mask_normalized and len(mask_normalized) > 0: if mask_normalized and len(mask_normalized) > 0:
# Convert normalized coordinates to absolute pixels # Convert normalized coordinates to absolute pixels
mask_points = np.array( mask_points = np.array(
[ [[int(pt[0] * width), int(pt[1] * height)] for pt in mask_normalized],
[int(pt[0] * width), int(pt[1] * height)]
for pt in mask_normalized
],
dtype=np.int32, dtype=np.int32,
) )
@@ -270,9 +270,7 @@ class InferenceEngine:
label = f"{class_name} {det['confidence']:.2f}" label = f"{class_name} {det['confidence']:.2f}"
# Draw label background # Draw label background
(label_w, label_h), baseline = cv2.getTextSize( (label_w, label_h), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
)
cv2.rectangle( cv2.rectangle(
img, img,
(x1, y1 - label_h - baseline - 5), (x1, y1 - label_h - baseline - 5),

View File

@@ -186,7 +186,7 @@ class YOLOWrapper:
raise RuntimeError(f"Failed to load model from {self.model_path}") raise RuntimeError(f"Failed to load model from {self.model_path}")
prepared_source, cleanup_path = self._prepare_source(source) prepared_source, cleanup_path = self._prepare_source(source)
imgsz = 1088 imgsz = 640 # 1088
try: try:
logger.info(f"Running inference on {source} -> prepared_source {prepared_source}") logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
results = self.model.predict( results = self.model.predict(

View File

@@ -6,6 +6,9 @@ import cv2
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
from skimage.restoration import rolling_ball
from skimage.filters import threshold_otsu
from scipy.ndimage import median_filter, gaussian_filter
from src.utils.logger import get_logger from src.utils.logger import get_logger
from src.utils.file_utils import validate_file_path, is_image_file from src.utils.file_utils import validate_file_path, is_image_file
@@ -37,6 +40,8 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
a1[a1 > p999] = p999 a1[a1 > p999] = p999
a1 /= a1.max() a1 /= a1.max()
# print("Using get_pseudo_rgb")
if 1: if 1:
a2 = a1.copy() a2 = a1.copy()
a2 = a2**gamma a2 = a2**gamma
@@ -46,14 +51,75 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
p9999 = np.percentile(a3, 99.99) p9999 = np.percentile(a3, 99.99)
a3[a3 > p9999] = p9999 a3[a3 > p9999] = p9999
a3 /= a3.max() a3 /= a3.max()
out = np.stack([a1, a2, a3], axis=0)
else:
out = np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
return out
# return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0) # return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0) # return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
out = np.stack([a1, a2, a3], axis=0)
# print(any(np.isnan(out).flatten())) # print(any(np.isnan(out).flatten()))
def _get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.3) -> np.ndarray:
    """Convert a grayscale image to a 3-channel pseudo-RGB stack.

    Channel 0 is the percentile-clipped, max-normalized input; channel 1
    is that image masked to its Otsu foreground (blurred then re-binarized
    so the mask is effectively dilated) and renormalized; channel 2 is the
    gamma-compressed image the mask was derived from.

    Fixes over the previous version: removed the leftover debug ``print``,
    the dead ``a3``/``radius`` computations (``a3`` was never part of the
    output), the redundant double ``astype(np.float32)``, the unreachable
    ``if 1: ... else`` branch, and the misleading ``p999`` name for a
    99.99th-percentile value.

    Args:
        arr: Grayscale image with shape (H, W).
        gamma: Exponent applied before Otsu thresholding.

    Returns:
        float32 array of shape (3, H, W) with values in [0, 1].

    Raises:
        ValueError: If ``arr`` is not 2-D.
    """
    if arr.ndim != 2:
        raise ValueError("Input array must be a grayscale image with shape (H, W)")

    # Base channel: clip extreme highlights at the 99.99th percentile,
    # then normalize to [0, 1].
    base = arr.astype(np.float32, copy=True)
    clip_hi = np.percentile(base, 99.99)
    base[base > clip_hi] = clip_hi
    base /= base.max()

    # Gamma-compressed copy used both for Otsu thresholding and as channel 2.
    gamma_img = base ** gamma
    thr = threshold_otsu(gamma_img)
    # Blur the binary foreground and re-binarize with a tiny threshold:
    # this dilates the mask so object borders are retained.
    mask = gaussian_filter((gamma_img > thr).astype(np.float32), sigma=5)
    mask = (mask > 0.0001).astype(np.float32)

    # Masked channel: suppress background, subtract the 2nd percentile,
    # clip negatives, renormalize.
    masked = base * mask
    masked -= np.percentile(masked, 2)
    masked[masked < 0] = 0
    masked /= masked.max()

    return np.stack([base, masked, gamma_img], axis=0)
# return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# print(any(np.isnan(out).flatten()))
class ImageLoadError(Exception): class ImageLoadError(Exception):
"""Exception raised when an image cannot be loaded.""" """Exception raised when an image cannot be loaded."""
@@ -330,17 +396,20 @@ class Image:
return self._channels >= 3 return self._channels >= 3
def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None: def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None:
# print("Image.save", self.data.shape)
if self.channels == 1: if self.channels == 1:
# print("Image.save grayscale")
if pseudo_rgb: if pseudo_rgb:
img = get_pseudo_rgb(self.data) img = get_pseudo_rgb(self.data)
print("Image.save", img.shape) # print("Image.save", img.shape)
else: else:
img = np.repeat(self.data, 3, axis=2) img = np.repeat(self.data, 3, axis=2)
# print("Image.save no pseudo", img.shape)
else: else:
raise NotImplementedError("Only grayscale images are supported for now.") raise NotImplementedError("Only grayscale images are supported for now.")
# print("Image.save imwrite", img.shape)
imwrite(path, data=img) imwrite(path, data=img)
def __repr__(self) -> str: def __repr__(self) -> str:

View File

@@ -5,7 +5,7 @@ from tifffile import imread, imwrite
from shapely.geometry import LineString from shapely.geometry import LineString
from copy import deepcopy from copy import deepcopy
from scipy.ndimage import zoom from scipy.ndimage import zoom
from skimage.restoration import rolling_ball
# debug # debug
from src.utils.image import Image from src.utils.image import Image
@@ -160,8 +160,11 @@ class YoloLabelReader:
class ImageSplitter: class ImageSplitter:
def __init__(self, image_path: Path, label_path: Path): def __init__(self, image_path: Path, label_path: Path, subtract_bg: bool = False):
self.image = imread(image_path) self.image = imread(image_path)
if subtract_bg:
self.image = self.image - rolling_ball(self.image, radius=12)
self.image[self.image < 0] = 0
self.image_path = image_path self.image_path = image_path
self.label_path = label_path self.label_path = label_path
if not label_path.exists(): if not label_path.exists():
@@ -280,6 +283,7 @@ def main(args):
data = ImageSplitter( data = ImageSplitter(
image_path=image_path, image_path=image_path,
label_path=(args.input / "labels" / image_path.stem).with_suffix(".txt"), label_path=(args.input / "labels" / image_path.stem).with_suffix(".txt"),
subtract_bg=args.subtract_bg,
) )
if args.split_around_label: if args.split_around_label:
@@ -373,6 +377,7 @@ if __name__ == "__main__":
default=67, default=67,
help="Padding around the label when splitting around the label.", help="Padding around the label when splitting around the label.",
) )
parser.add_argument("-bg", "--subtract-bg", action="store_true", help="Subtract background from the image.")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)