From 98bc89691b2943fd6c1b11eb40946d085d7c1a3f Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Fri, 23 Jan 2026 08:39:33 +0200 Subject: [PATCH] Major changes in annotation and results showing/storing --- src/database/db_manager.py | 105 ++++++++++++++++++++ src/database/schema.sql | 16 ++++ src/gui/tabs/results_tab.py | 186 +++++++++++++++++++++++++++--------- src/model/inference.py | 52 +++++----- src/model/yolo_wrapper.py | 2 +- src/utils/image.py | 77 ++++++++++++++- src/utils/image_splitter.py | 9 +- 7 files changed, 370 insertions(+), 77 deletions(-) diff --git a/src/database/db_manager.py b/src/database/db_manager.py index 9bf2f04..0f6ce1a 100644 --- a/src/database/db_manager.py +++ b/src/database/db_manager.py @@ -105,6 +105,103 @@ class DatabaseManager: conn.execute("PRAGMA foreign_keys = ON") # Enable foreign keys return conn + # ==================== Detection Run Operations ==================== + + def upsert_detection_run( + self, + image_id: int, + model_id: int, + count: int, + metadata: Optional[Dict] = None, + ) -> bool: + """Insert/update a per-image per-model detection run summary. + + This enables the UI to show runs even when zero detections were produced. + """ + + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO detection_runs (image_id, model_id, detected_at, count, metadata) + VALUES (?, ?, CURRENT_TIMESTAMP, ?, ?) + ON CONFLICT(image_id, model_id) DO UPDATE SET + detected_at = CURRENT_TIMESTAMP, + count = excluded.count, + metadata = excluded.metadata + """, + ( + int(image_id), + int(model_id), + int(count), + json.dumps(metadata) if metadata else None, + ), + ) + conn.commit() + return True + finally: + conn.close() + + def get_detection_run_summaries(self, limit: int = 500, offset: int = 0) -> List[Dict]: + """Return latest detection run summaries grouped by image+model. + + Includes runs with 0 detections. + """ + + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute( + """ + SELECT + dr.image_id, + dr.model_id, + dr.detected_at, + dr.count, + dr.metadata, + i.relative_path AS image_path, + i.filename AS image_filename, + m.model_name, + m.model_version, + GROUP_CONCAT(DISTINCT d.class_name) AS classes + FROM detection_runs dr + JOIN images i ON dr.image_id = i.id + JOIN models m ON dr.model_id = m.id + LEFT JOIN detections d + ON d.image_id = dr.image_id AND d.model_id = dr.model_id + GROUP BY dr.image_id, dr.model_id + ORDER BY dr.detected_at DESC + LIMIT ? OFFSET ? + """, + (int(limit), int(offset)), + ) + + rows: List[Dict] = [] + for row in cursor.fetchall(): + item = dict(row) + if item.get("metadata"): + try: + item["metadata"] = json.loads(item["metadata"]) + except Exception: + item["metadata"] = None + rows.append(item) + return rows + finally: + conn.close() + + def get_detection_run_total(self) -> int: + """Return total number of detection_runs rows.""" + + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) AS cnt FROM detection_runs") + row = cursor.fetchone() + return int(row["cnt"] if row and row["cnt"] is not None else 0) + finally: + conn.close() + # ==================== Model Operations ==================== def add_model( @@ -527,6 +624,14 @@ class DatabaseManager: conn = self.get_connection() try: cursor = conn.cursor() + # Also clear detection run summaries so the Results tab does not continue + # to show historical runs after detections have been wiped. 
+ try: + cursor.execute("DELETE FROM detection_runs") + except sqlite3.OperationalError: + # Backwards-compatible: table may not exist on older DB files. + pass + cursor.execute("DELETE FROM detections") conn.commit() return cursor.rowcount diff --git a/src/database/schema.sql b/src/database/schema.sql index abef287..04fab0d 100644 --- a/src/database/schema.sql +++ b/src/database/schema.sql @@ -45,6 +45,19 @@ CREATE TABLE IF NOT EXISTS detections ( FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE ); +-- Detection runs table: stores per-image per-model run summaries (including 0 detections) +CREATE TABLE IF NOT EXISTS detection_runs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + image_id INTEGER NOT NULL, + model_id INTEGER NOT NULL, + detected_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + count INTEGER NOT NULL DEFAULT 0, + metadata TEXT, + UNIQUE(image_id, model_id), + FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE, + FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE +); + -- Object classes table: stores annotation class definitions with colors CREATE TABLE IF NOT EXISTS object_classes ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -81,6 +94,9 @@ CREATE INDEX IF NOT EXISTS idx_detections_model_id ON detections(model_id); CREATE INDEX IF NOT EXISTS idx_detections_class_name ON detections(class_name); CREATE INDEX IF NOT EXISTS idx_detections_detected_at ON detections(detected_at); CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence); +CREATE INDEX IF NOT EXISTS idx_detection_runs_image_id ON detection_runs(image_id); +CREATE INDEX IF NOT EXISTS idx_detection_runs_model_id ON detection_runs(model_id); +CREATE INDEX IF NOT EXISTS idx_detection_runs_detected_at ON detection_runs(detected_at); CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path); CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at); CREATE INDEX IF NOT EXISTS idx_images_source ON images(source); diff --git a/src/gui/tabs/results_tab.py b/src/gui/tabs/results_tab.py index 29858a0..b19da54 100644 --- a/src/gui/tabs/results_tab.py +++ b/src/gui/tabs/results_tab.py @@ -40,6 +40,12 @@ class ResultsTab(QWidget): self.db_manager = db_manager self.config_manager = config_manager + # Pagination + self.page_size = 200 + self.current_page = 0 # 0-based + self.total_runs = 0 + self.total_pages = 0 + self.detection_summary: List[Dict] = [] self.current_selection: Optional[Dict] = None self.current_image: Optional[Image] = None @@ -66,6 +72,20 @@ class ResultsTab(QWidget): self.refresh_btn.clicked.connect(self.refresh) controls_layout.addWidget(self.refresh_btn) + self.prev_page_btn = QPushButton("◀ Prev") + self.prev_page_btn.setToolTip("Previous page") + self.prev_page_btn.clicked.connect(self._prev_page) + controls_layout.addWidget(self.prev_page_btn) + + self.next_page_btn = QPushButton("Next ▶") + self.next_page_btn.setToolTip("Next page") + self.next_page_btn.clicked.connect(self._next_page) + controls_layout.addWidget(self.next_page_btn) + + self.page_label = QLabel("Page 1/1") + self.page_label.setMinimumWidth(100) + controls_layout.addWidget(self.page_label) + self.delete_all_btn = QPushButton("Delete All Detections") self.delete_all_btn.setToolTip( "Permanently delete ALL detections from the database.\n" "This cannot be undone." @@ -183,6 +203,8 @@ class ResultsTab(QWidget): def refresh(self): """Refresh the detection list and preview.""" + # Reset to first page on refresh. 
+ self.current_page = 0 self._load_detection_summary() self._populate_results_table() self.current_selection = None @@ -193,57 +215,135 @@ class ResultsTab(QWidget): if hasattr(self, "export_labels_btn"): self.export_labels_btn.setEnabled(False) + self._update_pagination_controls() + + def _prev_page(self): + """Go to previous results page.""" + if self.current_page <= 0: + return + self.current_page -= 1 + self._load_detection_summary() + self._populate_results_table() + self._update_pagination_controls() + + def _next_page(self): + """Go to next results page.""" + if self.total_pages and self.current_page >= (self.total_pages - 1): + return + self.current_page += 1 + self._load_detection_summary() + self._populate_results_table() + self._update_pagination_controls() + + def _update_pagination_controls(self): + """Update pagination label/button enabled state.""" + + # Default state (safe) + total_pages = max(int(self.total_pages or 0), 1) + current_page = min(max(int(self.current_page or 0), 0), total_pages - 1) + self.current_page = current_page + + self.page_label.setText(f"Page {current_page + 1}/{total_pages}") + self.prev_page_btn.setEnabled(current_page > 0) + self.next_page_btn.setEnabled(current_page < (total_pages - 1)) + def _load_detection_summary(self): """Load latest detection summaries grouped by image + model.""" try: - detections = self.db_manager.get_detections(limit=500) - summary_map: Dict[tuple, Dict] = {} + # Prefer run summaries (supports zero-detection runs). Fall back to legacy + # detection aggregation if the DB/table isn't available. + try: + self.total_runs = int(self.db_manager.get_detection_run_total()) + except Exception: + self.total_runs = 0 - for det in detections: - key = (det["image_id"], det["model_id"]) - metadata = det.get("metadata") or {} - entry = summary_map.setdefault( - key, + self.total_pages = (self.total_runs + self.page_size - 1) // self.page_size if self.total_runs > 0 else 1 + + offset = int(self.current_page) * int(self.page_size) + runs = self.db_manager.get_detection_run_summaries( + limit=int(self.page_size), + offset=offset, + ) + summary: List[Dict] = [] + for run in runs: + meta = run.get("metadata") or {} + classes_raw = run.get("classes") + classes = set() + if isinstance(classes_raw, str) and classes_raw.strip(): + classes = {c.strip() for c in classes_raw.split(",") if c.strip()} + + summary.append( { - "image_id": det["image_id"], - "model_id": det["model_id"], - "image_path": det.get("image_path"), - "image_filename": det.get("image_filename") or det.get("image_path"), - "model_name": det.get("model_name", ""), - "model_version": det.get("model_version", ""), - "last_detected": det.get("detected_at"), - "count": 0, - "classes": set(), - "source_path": metadata.get("source_path"), - "repository_root": metadata.get("repository_root"), - }, + "image_id": run.get("image_id"), + "model_id": run.get("model_id"), + "image_path": run.get("image_path"), + "image_filename": run.get("image_filename") or run.get("image_path"), + "model_name": run.get("model_name", ""), + "model_version": run.get("model_version", ""), + "last_detected": run.get("detected_at"), + "count": int(run.get("count") or 0), + "classes": classes, + "source_path": meta.get("source_path"), + "repository_root": meta.get("repository_root"), + } ) - entry["count"] += 1 - if det.get("detected_at") and ( - not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected")) - ): - entry["last_detected"] = det.get("detected_at") - if 
det.get("class_name"): - entry["classes"].add(det["class_name"]) - if metadata.get("source_path") and not entry.get("source_path"): - entry["source_path"] = metadata.get("source_path") - if metadata.get("repository_root") and not entry.get("repository_root"): - entry["repository_root"] = metadata.get("repository_root") - - self.detection_summary = sorted( - summary_map.values(), - key=lambda x: str(x.get("last_detected") or ""), - reverse=True, - ) + self.detection_summary = summary except Exception as e: - logger.error(f"Failed to load detection summary: {e}") - QMessageBox.critical( - self, - "Error", - f"Failed to load detection results:\n{str(e)}", - ) - self.detection_summary = [] + logger.error(f"Failed to load detection run summaries, falling back: {e}") + # Disable pagination if we can't page via detection_runs. + self.total_runs = 0 + self.total_pages = 1 + self.current_page = 0 + try: + detections = self.db_manager.get_detections(limit=int(self.page_size)) + summary_map: Dict[tuple, Dict] = {} + + for det in detections: + key = (det["image_id"], det["model_id"]) + metadata = det.get("metadata") or {} + entry = summary_map.setdefault( + key, + { + "image_id": det["image_id"], + "model_id": det["model_id"], + "image_path": det.get("image_path"), + "image_filename": det.get("image_filename") or det.get("image_path"), + "model_name": det.get("model_name", ""), + "model_version": det.get("model_version", ""), + "last_detected": det.get("detected_at"), + "count": 0, + "classes": set(), + "source_path": metadata.get("source_path"), + "repository_root": metadata.get("repository_root"), + }, + ) + + entry["count"] += 1 + if det.get("detected_at") and ( + not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected")) + ): + entry["last_detected"] = det.get("detected_at") + if det.get("class_name"): + entry["classes"].add(det["class_name"]) + if metadata.get("source_path") and not entry.get("source_path"): + entry["source_path"] = metadata.get("source_path") + if metadata.get("repository_root") and not entry.get("repository_root"): + entry["repository_root"] = metadata.get("repository_root") + + self.detection_summary = sorted( + summary_map.values(), + key=lambda x: str(x.get("last_detected") or ""), + reverse=True, + ) + except Exception as inner: + logger.error(f"Failed to load detection summary: {inner}") + QMessageBox.critical( + self, + "Error", + f"Failed to load detection results:\n{str(inner)}", + ) + self.detection_summary = [] def _populate_results_table(self): """Populate the table widget with detection summaries.""" diff --git a/src/model/inference.py b/src/model/inference.py index 0e4aea3..fb2e13c 100644 --- a/src/model/inference.py +++ b/src/model/inference.py @@ -84,25 +84,19 @@ class InferenceEngine: # Save detections to database, replacing any previous results for this image/model if save_to_db: - deleted_count = self.db_manager.delete_detections_for_image( - image_id, self.model_id - ) + deleted_count = self.db_manager.delete_detections_for_image(image_id, self.model_id) if detections: detection_records = [] for det in detections: # Use normalized bbox from detection - bbox_normalized = det[ - "bbox_normalized" - ] # [x_min, y_min, x_max, y_max] + bbox_normalized = det["bbox_normalized"] # [x_min, y_min, x_max, y_max] metadata = { "class_id": det["class_id"], "source_path": str(Path(image_path).resolve()), } if repository_root: - metadata["repository_root"] = str( - Path(repository_root).resolve() - ) + metadata["repository_root"] = 
str(Path(repository_root).resolve()) record = { "image_id": image_id, @@ -115,16 +109,27 @@ class InferenceEngine: } detection_records.append(record) - inserted_count = self.db_manager.add_detections_batch( - detection_records - ) - logger.info( - f"Saved {inserted_count} detections to database (replaced {deleted_count})" - ) + inserted_count = self.db_manager.add_detections_batch(detection_records) + logger.info(f"Saved {inserted_count} detections to database (replaced {deleted_count})") else: - logger.info( - f"Detection run removed {deleted_count} stale entries but produced no new detections" + logger.info(f"Detection run removed {deleted_count} stale entries but produced no new detections") + + # Always store a run summary so the Results tab can show zero-detection runs. + try: + run_metadata = { + "source_path": str(Path(image_path).resolve()), + } + if repository_root: + run_metadata["repository_root"] = str(Path(repository_root).resolve()) + self.db_manager.upsert_detection_run( + image_id=image_id, + model_id=self.model_id, + count=len(detections), + metadata=run_metadata, ) + except Exception as exc: + # Non-fatal: detection records may still be present. + logger.warning(f"Failed to store detection run summary: {exc}") return { "success": True, @@ -232,9 +237,7 @@ class InferenceEngine: for det in detections: # Get color for this class class_name = det["class_name"] - color_hex = bbox_colors.get( - class_name, bbox_colors.get("default", "#00FF00") - ) + color_hex = bbox_colors.get(class_name, bbox_colors.get("default", "#00FF00")) color = self._hex_to_bgr(color_hex) # Draw segmentation mask if available and requested @@ -243,10 +246,7 @@ class InferenceEngine: if mask_normalized and len(mask_normalized) > 0: # Convert normalized coordinates to absolute pixels mask_points = np.array( - [ - [int(pt[0] * width), int(pt[1] * height)] - for pt in mask_normalized - ], + [[int(pt[0] * width), int(pt[1] * height)] for pt in mask_normalized], dtype=np.int32, ) @@ -270,9 +270,7 @@ class InferenceEngine: label = f"{class_name} {det['confidence']:.2f}" # Draw label background - (label_w, label_h), baseline = cv2.getTextSize( - label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1 - ) + (label_w, label_h), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) cv2.rectangle( img, (x1, y1 - label_h - baseline - 5), diff --git a/src/model/yolo_wrapper.py b/src/model/yolo_wrapper.py index a15a7bd..1a3ca55 100644 --- a/src/model/yolo_wrapper.py +++ b/src/model/yolo_wrapper.py @@ -186,7 +186,7 @@ class YOLOWrapper: raise RuntimeError(f"Failed to load model from {self.model_path}") prepared_source, cleanup_path = self._prepare_source(source) - imgsz = 1088 + imgsz = 640 # 1088 try: logger.info(f"Running inference on {source} -> prepared_source {prepared_source}") results = self.model.predict( diff --git a/src/utils/image.py b/src/utils/image.py index 286ffdf..804eef6 100644 --- a/src/utils/image.py +++ b/src/utils/image.py @@ -6,6 +6,9 @@ import cv2 import numpy as np from pathlib import Path from typing import Optional, Tuple, Union +from skimage.restoration import rolling_ball +from skimage.filters import threshold_otsu +from scipy.ndimage import median_filter, gaussian_filter from src.utils.logger import get_logger from src.utils.file_utils import validate_file_path, is_image_file @@ -37,6 +40,8 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray: a1[a1 > p999] = p999 a1 /= a1.max() + # print("Using get_pseudo_rgb") + if 1: a2 = a1.copy() a2 = a2**gamma @@ -46,14 +51,75 @@ def 
get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray: p9999 = np.percentile(a3, 99.99) a3[a3 > p9999] = p9999 a3 /= a3.max() + out = np.stack([a1, a2, a3], axis=0) + + else: + out = np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0) + + return out # return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0) # return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0) - out = np.stack([a1, a2, a3], axis=0) # print(any(np.isnan(out).flatten())) + +def _get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.3) -> np.ndarray: + """ + Convert a grayscale image to a pseudo-RGB image using a gamma correction. + + Args: + arr: Input grayscale image as numpy array + + Returns: + Pseudo-RGB image as numpy array + """ + if arr.ndim != 2: + raise ValueError("Input array must be a grayscale image with shape (H, W)") + + radius = 80 + # bg = rolling_ball(arr, radius=radius) + a1 = arr.copy().astype(np.float32) + a1 = a1.astype(np.float32) + # a1 -= bg + # a1[a1 < 0] = 0 + # a1 -= np.percentile(a1, 2) + # a1[a1 < 0] = 0 + p999 = np.percentile(a1, 99.99) + a1[a1 > p999] = p999 + a1 /= a1.max() + + print("Using get_pseudo_rgb") + + if 1: + a2 = a1.copy() + _a2 = a2**gamma + thr = threshold_otsu(_a2) + mask = gaussian_filter((_a2 > thr).astype(np.float32), sigma=5) + mask[mask > 0.0001] = 1 + mask[mask <= 0.0001] = 0 + a2 *= mask + + # bg2 = rolling_ball(a2, radius=radius) + # a2 -= bg2 + a2 -= np.percentile(a2, 2) + a2[a2 < 0] = 0 + a2 /= a2.max() + + a3 = a1.copy() + p9999 = np.percentile(a3, 99.99) + a3[a3 > p9999] = p9999 + a3 /= a3.max() + out = np.stack([a1, a2, _a2], axis=0) + + else: + out = np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0) + return out + # return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0) + # return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0) + # print(any(np.isnan(out).flatten())) + class ImageLoadError(Exception): """Exception raised when an image cannot be loaded.""" @@ -330,17 +396,20 @@ class Image: return self._channels >= 3 def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None: - + # print("Image.save", self.data.shape) if self.channels == 1: + # print("Image.save grayscale") if pseudo_rgb: img = get_pseudo_rgb(self.data) - print("Image.save", img.shape) + # print("Image.save", img.shape) + else: img = np.repeat(self.data, 3, axis=2) - + # print("Image.save no pseudo", img.shape) else: raise NotImplementedError("Only grayscale images are supported for now.") + # print("Image.save imwrite", img.shape) imwrite(path, data=img) def __repr__(self) -> str: diff --git a/src/utils/image_splitter.py b/src/utils/image_splitter.py index bef3941..c24ffb7 100644 --- a/src/utils/image_splitter.py +++ b/src/utils/image_splitter.py @@ -5,7 +5,7 @@ from tifffile import imread, imwrite from shapely.geometry import LineString from copy import deepcopy from scipy.ndimage import zoom - +from skimage.restoration import rolling_ball # debug from src.utils.image import Image @@ -160,8 +160,11 @@ class YoloLabelReader: class ImageSplitter: - def __init__(self, image_path: Path, label_path: Path): + def __init__(self, image_path: Path, label_path: Path, subtract_bg: bool = False): self.image = imread(image_path) + if subtract_bg: + self.image = self.image - rolling_ball(self.image, radius=12) + self.image[self.image < 0] = 0 self.image_path = image_path self.label_path = label_path if not label_path.exists(): @@ -280,6 +283,7 @@ def main(args): data = ImageSplitter( image_path=image_path, 
label_path=(args.input / "labels" / image_path.stem).with_suffix(".txt"), + subtract_bg=args.subtract_bg, ) if args.split_around_label: @@ -373,6 +377,7 @@ if __name__ == "__main__": default=67, help="Padding around the label when splitting around the label.", ) + parser.add_argument("-bg", "--subtract-bg", action="store_true", help="Subtract background from the image.") args = parser.parse_args() main(args)
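
A minimal usage sketch of the new detection-run API in DatabaseManager. The database path and the image/model IDs below are placeholders, and the single-path constructor argument is an assumption about the existing DatabaseManager signature; in the application these values come from the images and models tables. The upsert relies on the UNIQUE(image_id, model_id) constraint added in schema.sql and therefore needs SQLite 3.24 or newer for ON CONFLICT ... DO UPDATE.

from src.database.db_manager import DatabaseManager

# Placeholder path and IDs, for illustration only.
db = DatabaseManager("detections.db")

# Record a run that produced no detections. Re-running the same image/model
# pair updates the existing row via ON CONFLICT(image_id, model_id).
db.upsert_detection_run(
    image_id=42,
    model_id=3,
    count=0,
    metadata={"source_path": "/data/raw/sample_0001.tif"},
)

# First page of run summaries, newest first, zero-detection runs included.
# "classes" is NULL (None) when a run produced no detections.
for run in db.get_detection_run_summaries(limit=200, offset=0):
    print(run["image_filename"], run["model_name"], run["count"], run["classes"])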
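
The Results tab pages through these summaries with a fixed page size of 200. A small stand-alone sketch of the same arithmetic, useful for sanity-checking the LIMIT/OFFSET values the tab sends to the query:

def page_window(total_runs: int, page_size: int = 200, page: int = 0):
    # Mirrors the tab's logic: an empty table still yields one (empty) page,
    # and the requested page index is clamped into the valid range.
    total_pages = (total_runs + page_size - 1) // page_size if total_runs > 0 else 1
    page = min(max(page, 0), total_pages - 1)
    return total_pages, page_size, page * page_size

# 450 stored runs -> 3 pages; page index 2 queries LIMIT 200 OFFSET 400.
print(page_window(450))            # (3, 200, 0)
print(page_window(450, page=2))    # (3, 200, 400)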
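
ImageSplitter can now subtract a rolling-ball background, enabled by the new -bg/--subtract-bg flag. A minimal stand-alone sketch of that step, assuming the source TIFFs are unsigned integer images; the cast to float32 is an addition in this sketch so that background-dominated pixels go negative and can be clipped rather than wrapping around:

import numpy as np
from skimage.restoration import rolling_ball
from tifffile import imread

def subtract_background(path, radius: int = 12) -> np.ndarray:
    # Work in a signed/float dtype before subtracting the estimated background,
    # then clip negatives to zero, matching the intent of the subtract_bg branch.
    image = imread(path).astype(np.float32)
    corrected = image - rolling_ball(image, radius=radius)
    corrected[corrected < 0] = 0
    return corrected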