3 Commits

11 changed files with 1266 additions and 97 deletions

View File

@@ -40,8 +40,15 @@ class DatabaseManager:
conn = self.get_connection() conn = self.get_connection()
try: try:
# Check if annotations table needs migration # Pre-schema migrations.
# These must run BEFORE executing schema.sql because schema.sql may
# contain CREATE INDEX statements referencing newly added columns.
#
# 1) Check if annotations table needs migration (may drop an old table)
self._migrate_annotations_table(conn) self._migrate_annotations_table(conn)
# 2) Ensure images table has the required columns (e.g. 'source')
self._migrate_images_table(conn)
conn.commit()
# Read schema file and execute # Read schema file and execute
schema_path = Path(__file__).parent / "schema.sql" schema_path = Path(__file__).parent / "schema.sql"
@@ -53,6 +60,19 @@ class DatabaseManager:
finally: finally:
conn.close() conn.close()
def _migrate_images_table(self, conn: sqlite3.Connection) -> None:
"""Migrate images table to include the 'source' column if missing."""
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='images'")
if not cursor.fetchone():
return
cursor.execute("PRAGMA table_info(images)")
columns = {row[1] for row in cursor.fetchall()}
if "source" not in columns:
cursor.execute("ALTER TABLE images ADD COLUMN source TEXT")
def _migrate_annotations_table(self, conn: sqlite3.Connection) -> None: def _migrate_annotations_table(self, conn: sqlite3.Connection) -> None:
""" """
Migrate annotations table from old schema (class_name) to new schema (class_id). Migrate annotations table from old schema (class_name) to new schema (class_id).
@@ -85,6 +105,103 @@ class DatabaseManager:
conn.execute("PRAGMA foreign_keys = ON") # Enable foreign keys conn.execute("PRAGMA foreign_keys = ON") # Enable foreign keys
return conn return conn
# ==================== Detection Run Operations ====================
def upsert_detection_run(
self,
image_id: int,
model_id: int,
count: int,
metadata: Optional[Dict] = None,
) -> bool:
"""Insert/update a per-image per-model detection run summary.
This enables the UI to show runs even when zero detections were produced.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO detection_runs (image_id, model_id, detected_at, count, metadata)
VALUES (?, ?, CURRENT_TIMESTAMP, ?, ?)
ON CONFLICT(image_id, model_id) DO UPDATE SET
detected_at = CURRENT_TIMESTAMP,
count = excluded.count,
metadata = excluded.metadata
""",
(
int(image_id),
int(model_id),
int(count),
json.dumps(metadata) if metadata else None,
),
)
conn.commit()
return True
finally:
conn.close()
def get_detection_run_summaries(self, limit: int = 500, offset: int = 0) -> List[Dict]:
"""Return latest detection run summaries grouped by image+model.
Includes runs with 0 detections.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"""
SELECT
dr.image_id,
dr.model_id,
dr.detected_at,
dr.count,
dr.metadata,
i.relative_path AS image_path,
i.filename AS image_filename,
m.model_name,
m.model_version,
GROUP_CONCAT(DISTINCT d.class_name) AS classes
FROM detection_runs dr
JOIN images i ON dr.image_id = i.id
JOIN models m ON dr.model_id = m.id
LEFT JOIN detections d
ON d.image_id = dr.image_id AND d.model_id = dr.model_id
GROUP BY dr.image_id, dr.model_id
ORDER BY dr.detected_at DESC
LIMIT ? OFFSET ?
""",
(int(limit), int(offset)),
)
rows: List[Dict] = []
for row in cursor.fetchall():
item = dict(row)
if item.get("metadata"):
try:
item["metadata"] = json.loads(item["metadata"])
except Exception:
item["metadata"] = None
rows.append(item)
return rows
finally:
conn.close()
def get_detection_run_total(self) -> int:
"""Return total number of detection_runs rows."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) AS cnt FROM detection_runs")
row = cursor.fetchone()
return int(row["cnt"] if row and row["cnt"] is not None else 0)
finally:
conn.close()
# ==================== Model Operations ==================== # ==================== Model Operations ====================
def add_model( def add_model(
@@ -233,6 +350,7 @@ class DatabaseManager:
height: int, height: int,
captured_at: Optional[datetime] = None, captured_at: Optional[datetime] = None,
checksum: Optional[str] = None, checksum: Optional[str] = None,
source: Optional[str] = None,
) -> int: ) -> int:
""" """
Add a new image to the database. Add a new image to the database.
@@ -253,10 +371,10 @@ class DatabaseManager:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
""" """
INSERT INTO images (relative_path, filename, width, height, captured_at, checksum) INSERT INTO images (relative_path, filename, width, height, captured_at, checksum, source)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?)
""", """,
(relative_path, filename, width, height, captured_at, checksum), (relative_path, filename, width, height, captured_at, checksum, source),
) )
conn.commit() conn.commit()
return cursor.lastrowid return cursor.lastrowid
@@ -286,6 +404,18 @@ class DatabaseManager:
return existing["id"] return existing["id"]
return self.add_image(relative_path, filename, width, height) return self.add_image(relative_path, filename, width, height)
def set_image_source(self, image_id: int, source: Optional[str]) -> bool:
"""Set/update the source marker for an image row."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("UPDATE images SET source = ? WHERE id = ?", (source, int(image_id)))
conn.commit()
return cursor.rowcount > 0
finally:
conn.close()
# ==================== Detection Operations ==================== # ==================== Detection Operations ====================
def add_detection( def add_detection(
@@ -494,6 +624,14 @@ class DatabaseManager:
conn = self.get_connection() conn = self.get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
# Also clear detection run summaries so the Results tab does not continue
# to show historical runs after detections have been wiped.
try:
cursor.execute("DELETE FROM detection_runs")
except sqlite3.OperationalError:
# Backwards-compatible: table may not exist on older DB files.
pass
cursor.execute("DELETE FROM detections") cursor.execute("DELETE FROM detections")
conn.commit() conn.commit()
return cursor.rowcount return cursor.rowcount
@@ -658,6 +796,153 @@ class DatabaseManager:
# ==================== Annotation Operations ==================== # ==================== Annotation Operations ====================
def get_images_summary(
self,
name_filter: Optional[str] = None,
source_filter: Optional[str] = None,
order_by: str = "filename",
order_dir: str = "ASC",
limit: Optional[int] = None,
offset: int = 0,
) -> List[Dict]:
"""Return all images with annotation counts (including zero).
This is used by the Annotation tab to populate the image list even when
no annotations exist yet.
Args:
name_filter: Optional substring filter applied to filename/relative_path.
order_by: One of: 'filename', 'relative_path', 'annotation_count', 'added_at'.
order_dir: 'ASC' or 'DESC'.
limit: Optional max number of rows.
offset: Pagination offset.
Returns:
List of dicts: {id, relative_path, filename, added_at, annotation_count}
"""
allowed_order_by = {
"filename": "i.filename",
"relative_path": "i.relative_path",
"annotation_count": "annotation_count",
"added_at": "i.added_at",
}
order_expr = allowed_order_by.get(order_by, "i.filename")
dir_norm = str(order_dir).upper().strip()
if dir_norm not in {"ASC", "DESC"}:
dir_norm = "ASC"
conn = self.get_connection()
try:
params: List[Any] = []
where_clauses: List[str] = []
if name_filter:
token = f"%{name_filter}%"
where_clauses.append("(i.filename LIKE ? OR i.relative_path LIKE ?)")
params.extend([token, token])
if source_filter:
where_clauses.append("i.source = ?")
params.append(source_filter)
where_sql = ""
if where_clauses:
where_sql = "WHERE " + " AND ".join(where_clauses)
limit_sql = ""
if limit is not None:
limit_sql = " LIMIT ? OFFSET ?"
params.extend([int(limit), int(offset)])
query = f"""
SELECT
i.id,
i.relative_path,
i.filename,
i.added_at,
COUNT(a.id) AS annotation_count
FROM images i
LEFT JOIN annotations a ON a.image_id = i.id
{where_sql}
GROUP BY i.id
ORDER BY {order_expr} {dir_norm}
{limit_sql}
"""
cursor = conn.cursor()
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
finally:
conn.close()
def get_annotated_images_summary(
self,
name_filter: Optional[str] = None,
order_by: str = "filename",
order_dir: str = "ASC",
limit: Optional[int] = None,
offset: int = 0,
) -> List[Dict]:
"""Return images that have at least one manual annotation.
Args:
name_filter: Optional substring filter applied to filename/relative_path.
order_by: One of: 'filename', 'relative_path', 'annotation_count', 'added_at'.
order_dir: 'ASC' or 'DESC'.
limit: Optional max number of rows.
offset: Pagination offset.
Returns:
List of dicts: {id, relative_path, filename, added_at, annotation_count}
"""
allowed_order_by = {
"filename": "i.filename",
"relative_path": "i.relative_path",
"annotation_count": "annotation_count",
"added_at": "i.added_at",
}
order_expr = allowed_order_by.get(order_by, "i.filename")
dir_norm = str(order_dir).upper().strip()
if dir_norm not in {"ASC", "DESC"}:
dir_norm = "ASC"
conn = self.get_connection()
try:
params: List[Any] = []
where_sql = ""
if name_filter:
# Case-insensitive substring search.
token = f"%{name_filter}%"
where_sql = "WHERE (i.filename LIKE ? OR i.relative_path LIKE ?)"
params.extend([token, token])
limit_sql = ""
if limit is not None:
limit_sql = " LIMIT ? OFFSET ?"
params.extend([int(limit), int(offset)])
query = f"""
SELECT
i.id,
i.relative_path,
i.filename,
i.added_at,
COUNT(a.id) AS annotation_count
FROM images i
JOIN annotations a ON a.image_id = i.id
{where_sql}
GROUP BY i.id
HAVING annotation_count > 0
ORDER BY {order_expr} {dir_norm}
{limit_sql}
"""
cursor = conn.cursor()
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
finally:
conn.close()
def add_annotation( def add_annotation(
self, self,
image_id: int, image_id: int,
@@ -763,6 +1048,27 @@ class DatabaseManager:
finally: finally:
conn.close() conn.close()
def delete_annotations_for_image(self, image_id: int) -> int:
"""Delete ALL annotations for a specific image.
This is primarily used for import/overwrite workflows.
Args:
image_id: ID of the image whose annotations should be deleted.
Returns:
Number of rows deleted.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM annotations WHERE image_id = ?", (int(image_id),))
conn.commit()
return int(cursor.rowcount or 0)
finally:
conn.close()
# ==================== Object Class Operations ==================== # ==================== Object Class Operations ====================
def get_object_classes(self) -> List[Dict]: def get_object_classes(self) -> List[Dict]:

View File

@@ -7,7 +7,7 @@ CREATE TABLE IF NOT EXISTS models (
model_name TEXT NOT NULL, model_name TEXT NOT NULL,
model_version TEXT NOT NULL, model_version TEXT NOT NULL,
model_path TEXT NOT NULL, model_path TEXT NOT NULL,
base_model TEXT NOT NULL DEFAULT 'yolov8s.pt', base_model TEXT NOT NULL DEFAULT 'yolo11s.pt',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
training_params TEXT, -- JSON string of training parameters training_params TEXT, -- JSON string of training parameters
metrics TEXT, -- JSON string of validation metrics metrics TEXT, -- JSON string of validation metrics
@@ -23,7 +23,8 @@ CREATE TABLE IF NOT EXISTS images (
height INTEGER, height INTEGER,
captured_at TIMESTAMP, captured_at TIMESTAMP,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
checksum TEXT checksum TEXT,
source TEXT
); );
-- Detections table: stores detection results -- Detections table: stores detection results
@@ -44,6 +45,19 @@ CREATE TABLE IF NOT EXISTS detections (
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
); );
-- Detection runs table: stores per-image per-model run summaries (including 0 detections)
CREATE TABLE IF NOT EXISTS detection_runs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
image_id INTEGER NOT NULL,
model_id INTEGER NOT NULL,
detected_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
count INTEGER NOT NULL DEFAULT 0,
metadata TEXT,
UNIQUE(image_id, model_id),
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
);
-- Object classes table: stores annotation class definitions with colors -- Object classes table: stores annotation class definitions with colors
CREATE TABLE IF NOT EXISTS object_classes ( CREATE TABLE IF NOT EXISTS object_classes (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -80,9 +94,13 @@ CREATE INDEX IF NOT EXISTS idx_detections_model_id ON detections(model_id);
CREATE INDEX IF NOT EXISTS idx_detections_class_name ON detections(class_name); CREATE INDEX IF NOT EXISTS idx_detections_class_name ON detections(class_name);
CREATE INDEX IF NOT EXISTS idx_detections_detected_at ON detections(detected_at); CREATE INDEX IF NOT EXISTS idx_detections_detected_at ON detections(detected_at);
CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence); CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence);
CREATE INDEX IF NOT EXISTS idx_detection_runs_image_id ON detection_runs(image_id);
CREATE INDEX IF NOT EXISTS idx_detection_runs_model_id ON detection_runs(model_id);
CREATE INDEX IF NOT EXISTS idx_detection_runs_detected_at ON detection_runs(detected_at);
CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path); CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path);
CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at); CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at);
CREATE INDEX IF NOT EXISTS idx_images_source ON images(source);
CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id); CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id);
CREATE INDEX IF NOT EXISTS idx_annotations_class_id ON annotations(class_id); CREATE INDEX IF NOT EXISTS idx_annotations_class_id ON annotations(class_id);
CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at); CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at);
CREATE INDEX IF NOT EXISTS idx_object_classes_class_name ON object_classes(class_name); CREATE INDEX IF NOT EXISTS idx_object_classes_class_name ON object_classes(class_name);

View File

@@ -199,6 +199,10 @@ class MainWindow(QMainWindow):
self.training_tab.refresh() self.training_tab.refresh()
if hasattr(self, "results_tab"): if hasattr(self, "results_tab"):
self.results_tab.refresh() self.results_tab.refresh()
if hasattr(self, "annotation_tab"):
self.annotation_tab.refresh()
if hasattr(self, "validation_tab"):
self.validation_tab.refresh()
except Exception as e: except Exception as e:
logger.error(f"Error applying settings: {e}") logger.error(f"Error applying settings: {e}")
@@ -215,6 +219,14 @@ class MainWindow(QMainWindow):
logger.debug(f"Switched to tab: {tab_name}") logger.debug(f"Switched to tab: {tab_name}")
self._update_status(f"Viewing: {tab_name}") self._update_status(f"Viewing: {tab_name}")
# Ensure the Annotation tab always shows up-to-date DB-backed lists.
try:
current_widget = self.tab_widget.widget(index)
if hasattr(self, "annotation_tab") and current_widget is self.annotation_tab:
self.annotation_tab.refresh()
except Exception as exc:
logger.debug(f"Failed to refresh annotation tab on selection: {exc}")
def _show_database_stats(self): def _show_database_stats(self):
"""Show database statistics dialog.""" """Show database statistics dialog."""
try: try:
@@ -527,6 +539,11 @@ class MainWindow(QMainWindow):
if hasattr(self, "training_tab"): if hasattr(self, "training_tab"):
self.training_tab.shutdown() self.training_tab.shutdown()
if hasattr(self, "annotation_tab"): if hasattr(self, "annotation_tab"):
# Best-effort refresh so DB-backed UI state is consistent at shutdown.
try:
self.annotation_tab.refresh()
except Exception:
pass
self.annotation_tab.save_state() self.annotation_tab.save_state()
logger.info("Application closing") logger.info("Application closing")

View File

@@ -13,6 +13,11 @@ from PySide6.QtWidgets import (
QFileDialog, QFileDialog,
QMessageBox, QMessageBox,
QSplitter, QSplitter,
QLineEdit,
QTableWidget,
QTableWidgetItem,
QHeaderView,
QAbstractItemView,
) )
from PySide6.QtCore import Qt, QSettings from PySide6.QtCore import Qt, QSettings
from pathlib import Path from pathlib import Path
@@ -50,6 +55,32 @@ class AnnotationTab(QWidget):
self.main_splitter = QSplitter(Qt.Horizontal) self.main_splitter = QSplitter(Qt.Horizontal)
self.main_splitter.setHandleWidth(10) self.main_splitter.setHandleWidth(10)
# { Left-most pane: annotated images list
annotated_group = QGroupBox("Annotated Images")
annotated_layout = QVBoxLayout()
filter_row = QHBoxLayout()
filter_row.addWidget(QLabel("Filter:"))
self.annotated_filter_edit = QLineEdit()
self.annotated_filter_edit.setPlaceholderText("Type to filter by image name…")
self.annotated_filter_edit.textChanged.connect(self._refresh_annotated_images_list)
filter_row.addWidget(self.annotated_filter_edit, 1)
annotated_layout.addLayout(filter_row)
self.annotated_images_table = QTableWidget(0, 2)
self.annotated_images_table.setHorizontalHeaderLabels(["Image", "Annotations"])
self.annotated_images_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
self.annotated_images_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeToContents)
self.annotated_images_table.setSelectionBehavior(QAbstractItemView.SelectRows)
self.annotated_images_table.setSelectionMode(QAbstractItemView.SingleSelection)
self.annotated_images_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
self.annotated_images_table.setSortingEnabled(True)
self.annotated_images_table.itemSelectionChanged.connect(self._on_annotated_image_selected)
annotated_layout.addWidget(self.annotated_images_table, 1)
annotated_group.setLayout(annotated_layout)
# }
# { Left splitter for image display and zoom info # { Left splitter for image display and zoom info
self.left_splitter = QSplitter(Qt.Vertical) self.left_splitter = QSplitter(Qt.Vertical)
self.left_splitter.setHandleWidth(10) self.left_splitter.setHandleWidth(10)
@@ -104,11 +135,27 @@ class AnnotationTab(QWidget):
load_group = QGroupBox("Image Loading") load_group = QGroupBox("Image Loading")
load_layout = QVBoxLayout() load_layout = QVBoxLayout()
# Load image button # Buttons row
button_layout = QHBoxLayout() button_layout = QHBoxLayout()
self.load_image_btn = QPushButton("Load Image") self.load_image_btn = QPushButton("Load Image")
self.load_image_btn.clicked.connect(self._load_image) self.load_image_btn.clicked.connect(self._load_image)
button_layout.addWidget(self.load_image_btn) button_layout.addWidget(self.load_image_btn)
self.import_images_btn = QPushButton("Import Images")
self.import_images_btn.setToolTip(
"Import one or more images into the database.\n" "Images already present in the DB are skipped."
)
self.import_images_btn.clicked.connect(self._import_images)
button_layout.addWidget(self.import_images_btn)
self.import_annotations_btn = QPushButton("Import Annotations")
self.import_annotations_btn.setToolTip(
"Import YOLO .txt annotation files and register them with their corresponding images.\n"
"Existing annotations for those images will be overwritten."
)
self.import_annotations_btn.clicked.connect(self._import_annotations)
button_layout.addWidget(self.import_annotations_btn)
button_layout.addStretch() button_layout.addStretch()
load_layout.addLayout(button_layout) load_layout.addLayout(button_layout)
@@ -120,12 +167,13 @@ class AnnotationTab(QWidget):
self.right_splitter.addWidget(load_group) self.right_splitter.addWidget(load_group)
# } # }
# Add both splitters to the main horizontal splitter # Add list + both splitters to the main horizontal splitter
self.main_splitter.addWidget(annotated_group)
self.main_splitter.addWidget(self.left_splitter) self.main_splitter.addWidget(self.left_splitter)
self.main_splitter.addWidget(self.right_splitter) self.main_splitter.addWidget(self.right_splitter)
# Set initial sizes: 75% for left (image), 25% for right (controls) # Set initial sizes: list (left), canvas (middle), controls (right)
self.main_splitter.setSizes([750, 250]) self.main_splitter.setSizes([320, 650, 280])
layout.addWidget(self.main_splitter) layout.addWidget(self.main_splitter)
self.setLayout(layout) self.setLayout(layout)
@@ -133,6 +181,375 @@ class AnnotationTab(QWidget):
# Restore splitter positions from settings # Restore splitter positions from settings
self._restore_state() self._restore_state()
# Populate list on startup.
self._refresh_annotated_images_list()
def _import_images(self) -> None:
"""Import one or more images into the database and refresh the list."""
settings = QSettings("microscopy_app", "object_detection")
last_dir = settings.value("annotation_tab/last_image_import_directory", None)
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
if last_dir and Path(str(last_dir)).exists():
start_dir = str(last_dir)
elif repo_root and Path(repo_root).exists():
start_dir = repo_root
else:
start_dir = str(Path.home())
# Build filter string for supported extensions
patterns = " ".join(f"*{ext}" for ext in Image.SUPPORTED_EXTENSIONS)
file_paths, _ = QFileDialog.getOpenFileNames(
self,
"Select Image File(s)",
start_dir,
f"Images ({patterns})",
)
if not file_paths:
return
try:
settings.setValue("annotation_tab/last_image_import_directory", str(Path(file_paths[0]).parent))
# Keep compatibility with the existing image resolver fallback (it checks last_directory).
settings.setValue("annotation_tab/last_directory", str(Path(file_paths[0]).parent))
except Exception:
pass
imported = 0
tagged_existing = 0
skipped = 0
errors: list[str] = []
for fp in file_paths:
try:
img_path = Path(fp)
img = Image(str(img_path))
relative_path = self._compute_relative_path_for_repo(img_path)
# Skip if already present
existing = self.db_manager.get_image_by_path(relative_path)
if existing:
# If the image already exists (e.g. created earlier by other workflows),
# tag it as being managed by the Annotation tab so it becomes visible
# in the left list.
try:
self.db_manager.set_image_source(int(existing["id"]), "annotation_tab")
tagged_existing += 1
except Exception:
# If tagging fails, fall back to treating as skipped.
skipped += 1
continue
image_id = self.db_manager.add_image(
relative_path,
img_path.name,
img.width,
img.height,
source="annotation_tab",
)
try:
# In case the DB row was created by an older schema/migration path.
self.db_manager.set_image_source(image_id, "annotation_tab")
except Exception:
pass
imported += 1
except ImageLoadError as exc:
skipped += 1
errors.append(f"Failed to load image {fp}: {exc}")
except Exception as exc:
skipped += 1
errors.append(f"Failed to import image {fp}: {exc}")
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
msg = (
f"Imported: {imported}\n"
f"Already in DB (tagged for Annotation tab): {tagged_existing}\n"
f"Skipped (errors): {skipped}"
)
if errors:
details = "\n".join(errors[:25])
if len(errors) > 25:
details += f"\n... and {len(errors) - 25} more"
msg += "\n\nDetails:\n" + details
QMessageBox.information(self, "Import Images", msg)
# ==================== Import annotations (YOLO .txt) ====================
def _import_annotations(self) -> None:
"""Import YOLO segmentation/bbox annotations from one or more .txt files."""
settings = QSettings("microscopy_app", "object_detection")
last_dir = settings.value("annotation_tab/last_annotation_directory", None)
# Default start dir: repo root if set, otherwise last used, otherwise home.
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
if last_dir and Path(str(last_dir)).exists():
start_dir = str(last_dir)
elif repo_root and Path(repo_root).exists():
start_dir = repo_root
else:
start_dir = str(Path.home())
file_paths, _ = QFileDialog.getOpenFileNames(
self,
"Select YOLO Annotation File(s)",
start_dir,
"YOLO annotations (*.txt)",
)
if not file_paths:
return
# Persist last annotation directory for the next import.
try:
settings.setValue("annotation_tab/last_annotation_directory", str(Path(file_paths[0]).parent))
except Exception:
pass
imported_images = 0
imported_annotations = 0
overwritten_images = 0
skipped = 0
errors: list[str] = []
for label_file in file_paths:
label_path = Path(label_file)
try:
image_path = self._infer_corresponding_image_path(label_path)
if not image_path:
skipped += 1
errors.append(f"Image not found for label file: {label_path}")
continue
# Load image to obtain width/height for DB entry.
img = Image(str(image_path))
# Store in DB using a repo-relative path if possible.
relative_path = self._compute_relative_path_for_repo(image_path)
image_id = self.db_manager.get_or_create_image(relative_path, image_path.name, img.width, img.height)
try:
self.db_manager.set_image_source(image_id, "annotation_tab")
except Exception:
pass
# Overwrite existing annotations for this image.
try:
deleted = self.db_manager.delete_annotations_for_image(image_id)
except AttributeError:
# Safety fallback if older DBManager is used.
deleted = 0
if deleted > 0:
overwritten_images += 1
# Parse YOLO lines and insert as annotations.
parsed = self._parse_yolo_annotation_file(label_path)
if not parsed:
# Empty/invalid label file: treat as "clear" operation (already deleted above)
imported_images += 1
continue
db_classes = self.db_manager.get_object_classes() or []
classes_by_index = {idx: row for idx, row in enumerate(db_classes)}
for class_index, bbox, poly in parsed:
class_row = classes_by_index.get(int(class_index))
if not class_row:
skipped += 1
errors.append(
f"Unknown class index {class_index} in {label_path.name}. "
"Create object classes in the UI first (class index is based on DB ordering)."
)
continue
ann_id = self.db_manager.add_annotation(
image_id=image_id,
class_id=int(class_row["id"]),
bbox=bbox,
annotator="import",
segmentation_mask=poly,
verified=False,
)
if ann_id:
imported_annotations += 1
imported_images += 1
# If we imported for the currently open image, reload.
if self.current_image_id and int(self.current_image_id) == int(image_id):
self._load_annotations_for_current_image()
except ImageLoadError as exc:
skipped += 1
errors.append(f"Failed to load image for {label_path.name}: {exc}")
except Exception as exc:
skipped += 1
errors.append(f"Import failed for {label_path.name}: {exc}")
# Refresh annotated images list.
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
summary = (
f"Imported files: {len(file_paths)}\n"
f"Images processed: {imported_images}\n"
f"Annotations inserted: {imported_annotations}\n"
f"Images overwritten (had existing annotations): {overwritten_images}\n"
f"Skipped: {skipped}"
)
if errors:
# Cap error details to avoid huge dialogs.
details = "\n".join(errors[:25])
if len(errors) > 25:
details += f"\n... and {len(errors) - 25} more"
summary += "\n\nDetails:\n" + details
QMessageBox.information(self, "Import Annotations", summary)
def _infer_corresponding_image_path(self, label_path: Path) -> Path | None:
"""Infer image path from YOLO label file path.
Requirement: image(s) live in an `images/` folder located in the label file's parent directory.
Example:
/dataset/train/labels/img123.txt -> /dataset/train/images/img123.(any supported ext)
"""
parent = label_path.parent
images_dir = parent.parent / "images"
stem = label_path.stem
# 1) Direct stem match in images dir (any supported extension)
for ext in Image.SUPPORTED_EXTENSIONS:
candidate = images_dir / f"{stem}{ext}"
if candidate.exists() and candidate.is_file():
return candidate
# 2) Fallback: repository-root search by filename
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
if repo_root:
root = Path(repo_root).expanduser()
try:
if root.exists():
for ext in Image.SUPPORTED_EXTENSIONS:
filename = f"{stem}{ext}"
for match in root.rglob(filename):
if match.is_file():
return match.resolve()
except Exception:
pass
return None
def _compute_relative_path_for_repo(self, image_path: Path) -> str:
"""Compute a stable `relative_path` suitable for DB storage.
Policy:
- If an image repository root is configured and the image is under it, store a repo-relative path.
- Otherwise, store an absolute resolved path so the image can be reopened later.
"""
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
try:
if repo_root:
repo_root_path = Path(repo_root).expanduser().resolve()
img_resolved = image_path.expanduser().resolve()
if img_resolved.is_relative_to(repo_root_path):
return img_resolved.relative_to(repo_root_path).as_posix()
except Exception:
pass
try:
return str(image_path.expanduser().resolve())
except Exception:
return str(image_path)
def _parse_yolo_annotation_file(
self, label_path: Path
) -> list[tuple[int, tuple[float, float, float, float], list[list[float]] | None]]:
"""Parse a YOLO .txt label file.
Supports:
- YOLO segmentation polygon format: "class x1 y1 x2 y2 ..." (normalized)
- YOLO bbox format: "class x_center y_center width height" (normalized)
Returns:
List of (class_index, bbox_xyxy_norm, segmentation_mask_db)
Where segmentation_mask_db is [[y_norm, x_norm], ...] or None.
"""
out: list[tuple[int, tuple[float, float, float, float], list[list[float]] | None]] = []
try:
raw = label_path.read_text(encoding="utf-8").splitlines()
except OSError as exc:
logger.error(f"Failed to read label file {label_path}: {exc}")
return out
for line in raw:
stripped = line.strip()
if not stripped:
continue
parts = stripped.split()
if len(parts) < 5:
# not enough for bbox
continue
try:
class_idx = int(float(parts[0]))
coords = [float(x) for x in parts[1:]]
except Exception:
continue
# Segmentation polygon format (>= 6 values)
if len(coords) >= 6:
# bbox is not explicitly present in this format in our importer; compute from polygon.
xs = coords[0::2]
ys = coords[1::2]
if not xs or not ys:
continue
x_min, x_max = min(xs), max(xs)
y_min, y_max = min(ys), max(ys)
bbox = (
self._clamp01(x_min),
self._clamp01(y_min),
self._clamp01(x_max),
self._clamp01(y_max),
)
# Convert to DB polyline convention: [[y_norm, x_norm], ...]
poly: list[list[float]] = []
for x, y in zip(xs, ys):
poly.append([self._clamp01(float(y)), self._clamp01(float(x))])
# Ensure closure for consistency (optional)
if poly and poly[0] != poly[-1]:
poly.append(list(poly[0]))
out.append((class_idx, bbox, poly))
continue
# bbox format: xc yc w h
if len(coords) >= 4:
xc, yc, w, h = coords[:4]
x_min = xc - w / 2.0
y_min = yc - h / 2.0
x_max = xc + w / 2.0
y_max = yc + h / 2.0
bbox = (
self._clamp01(float(x_min)),
self._clamp01(float(y_min)),
self._clamp01(float(x_max)),
self._clamp01(float(y_max)),
)
out.append((class_idx, bbox, None))
return out
@staticmethod
def _clamp01(value: float) -> float:
if value < 0.0:
return 0.0
if value > 1.0:
return 1.0
return float(value)
def _load_image(self): def _load_image(self):
"""Load and display an image file.""" """Load and display an image file."""
# Get last opened directory from QSettings # Get last opened directory from QSettings
@@ -166,13 +583,32 @@ class AnnotationTab(QWidget):
settings.setValue("annotation_tab/last_directory", str(Path(file_path).parent)) settings.setValue("annotation_tab/last_directory", str(Path(file_path).parent))
# Get or create image in database # Get or create image in database
relative_path = str(Path(file_path).name) # Simplified for now repo_root = self.config_manager.get_image_repository_path()
relative_path: str
try:
if repo_root:
repo_root_path = Path(repo_root).expanduser().resolve()
file_resolved = Path(file_path).expanduser().resolve()
if file_resolved.is_relative_to(repo_root_path):
relative_path = file_resolved.relative_to(repo_root_path).as_posix()
else:
# Fallback: store filename only to avoid leaking absolute paths.
relative_path = file_resolved.name
else:
relative_path = str(Path(file_path).name)
except Exception:
relative_path = str(Path(file_path).name)
self.current_image_id = self.db_manager.get_or_create_image( self.current_image_id = self.db_manager.get_or_create_image(
relative_path, relative_path,
Path(file_path).name, Path(file_path).name,
self.current_image.width, self.current_image.width,
self.current_image.height, self.current_image.height,
) )
# Mark as managed by Annotation tab so it appears in the left list.
try:
self.db_manager.set_image_source(int(self.current_image_id), "annotation_tab")
except Exception:
pass
# Display image using the AnnotationCanvasWidget # Display image using the AnnotationCanvasWidget
self.annotation_canvas.load_image(self.current_image) self.annotation_canvas.load_image(self.current_image)
@@ -180,6 +616,9 @@ class AnnotationTab(QWidget):
# Load and display any existing annotations for this image # Load and display any existing annotations for this image
self._load_annotations_for_current_image() self._load_annotations_for_current_image()
# Update annotated images list (newly annotated image added/selected).
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
# Update info label # Update info label
self._update_image_info() self._update_image_info()
@@ -275,6 +714,9 @@ class AnnotationTab(QWidget):
# Reload annotations from DB and redraw (respecting current class filter) # Reload annotations from DB and redraw (respecting current class filter)
self._load_annotations_for_current_image() self._load_annotations_for_current_image()
# Update list counts.
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
except Exception as e: except Exception as e:
logger.error(f"Failed to save annotation: {e}") logger.error(f"Failed to save annotation: {e}")
QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}") QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}")
@@ -405,6 +847,9 @@ class AnnotationTab(QWidget):
self.annotation_tools.set_has_selected_annotation(False) self.annotation_tools.set_has_selected_annotation(False)
self._load_annotations_for_current_image() self._load_annotations_for_current_image()
# Update list counts.
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
except Exception as e: except Exception as e:
logger.error(f"Failed to delete annotations: {e}") logger.error(f"Failed to delete annotations: {e}")
QMessageBox.critical( QMessageBox.critical(
@@ -521,4 +966,173 @@ class AnnotationTab(QWidget):
def refresh(self): def refresh(self):
"""Refresh the tab.""" """Refresh the tab."""
pass self._refresh_annotated_images_list(select_image_id=self.current_image_id)
# ==================== Annotated images list ====================
def _refresh_annotated_images_list(self, select_image_id: int | None = None) -> None:
"""Reload annotated-images list from the database."""
if not hasattr(self, "annotated_images_table"):
return
# Preserve selection if possible
desired_id = select_image_id if select_image_id is not None else self.current_image_id
name_filter = ""
if hasattr(self, "annotated_filter_edit"):
name_filter = self.annotated_filter_edit.text().strip()
try:
rows = self.db_manager.get_images_summary(name_filter=name_filter, source_filter="annotation_tab")
except Exception as exc:
logger.error(f"Failed to load images summary: {exc}")
rows = []
sorting_enabled = self.annotated_images_table.isSortingEnabled()
self.annotated_images_table.setSortingEnabled(False)
self.annotated_images_table.blockSignals(True)
try:
self.annotated_images_table.setRowCount(len(rows))
for r, entry in enumerate(rows):
image_name = str(entry.get("filename") or "")
count = int(entry.get("annotation_count") or 0)
rel_path = str(entry.get("relative_path") or "")
name_item = QTableWidgetItem(image_name)
# Tooltip shows full path of the image (best-effort: repository_root + relative_path)
full_path = rel_path
repo_root = self.config_manager.get_image_repository_path()
if repo_root and rel_path and not Path(rel_path).is_absolute():
try:
full_path = str((Path(repo_root) / rel_path).resolve())
except Exception:
full_path = str(Path(repo_root) / rel_path)
name_item.setToolTip(full_path)
name_item.setData(Qt.UserRole, int(entry.get("id")))
name_item.setData(Qt.UserRole + 1, rel_path)
count_item = QTableWidgetItem()
# Use EditRole to ensure numeric sorting.
count_item.setData(Qt.EditRole, count)
count_item.setData(Qt.UserRole, int(entry.get("id")))
count_item.setData(Qt.UserRole + 1, rel_path)
self.annotated_images_table.setItem(r, 0, name_item)
self.annotated_images_table.setItem(r, 1, count_item)
# Re-select desired row
if desired_id is not None:
for r in range(self.annotated_images_table.rowCount()):
item = self.annotated_images_table.item(r, 0)
if item and item.data(Qt.UserRole) == desired_id:
self.annotated_images_table.selectRow(r)
break
finally:
self.annotated_images_table.blockSignals(False)
self.annotated_images_table.setSortingEnabled(sorting_enabled)
def _on_annotated_image_selected(self) -> None:
"""When user clicks an item in the list, load that image in the annotation canvas."""
selected = self.annotated_images_table.selectedItems()
if not selected:
return
# Row selection -> take the first column item
row = self.annotated_images_table.currentRow()
item = self.annotated_images_table.item(row, 0)
if not item:
return
image_id = item.data(Qt.UserRole)
rel_path = item.data(Qt.UserRole + 1) or ""
if not image_id:
return
image_path = self._resolve_image_path_for_relative_path(rel_path)
if not image_path:
QMessageBox.warning(
self,
"Image Not Found",
"Unable to locate image on disk for:\n"
f"{rel_path}\n\n"
"Tip: set Settings → Image repository path to the folder containing your images.",
)
return
try:
self.current_image = Image(image_path)
self.current_image_path = image_path
self.current_image_id = int(image_id)
self.annotation_canvas.load_image(self.current_image)
self._load_annotations_for_current_image()
self._update_image_info()
except ImageLoadError as exc:
logger.error(f"Failed to load image '{image_path}': {exc}")
QMessageBox.critical(self, "Error Loading Image", f"Failed to load image:\n{exc}")
except Exception as exc:
logger.error(f"Unexpected error loading image '{image_path}': {exc}")
QMessageBox.critical(self, "Error", f"Unexpected error:\n{exc}")
def _resolve_image_path_for_relative_path(self, relative_path: str) -> str | None:
"""Best-effort conversion from a DB relative_path to an on-disk file path."""
rel = (relative_path or "").strip()
if not rel:
return None
candidates: list[Path] = []
# 1) Repository root + relative
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
if repo_root:
candidates.append(Path(repo_root) / rel)
# 2) If the DB path is absolute, try it directly.
candidates.append(Path(rel))
# 3) Try the directory of the currently loaded image (helps when DB stores only filenames)
if self.current_image_path:
try:
candidates.append(Path(self.current_image_path).expanduser().resolve().parent / Path(rel).name)
except Exception:
pass
# 4) Try the last directory used by the annotation file picker
try:
settings = QSettings("microscopy_app", "object_detection")
last_dir = settings.value("annotation_tab/last_directory", None)
if last_dir:
candidates.append(Path(str(last_dir)) / Path(rel).name)
except Exception:
pass
# 4b) Try the last directory used by the image import picker
try:
settings = QSettings("microscopy_app", "object_detection")
last_import_dir = settings.value("annotation_tab/last_image_import_directory", None)
if last_import_dir:
candidates.append(Path(str(last_import_dir)) / Path(rel).name)
except Exception:
pass
for p in candidates:
try:
expanded = p.expanduser()
if expanded.exists() and expanded.is_file():
return str(expanded.resolve())
except Exception:
continue
# 5) Fallback: search by filename within repository root.
filename = Path(rel).name
if repo_root and filename:
root = Path(repo_root).expanduser()
try:
if root.exists():
for match in root.rglob(filename):
if match.is_file():
return str(match.resolve())
except Exception as exc:
logger.debug(f"Search for {filename} under {root} failed: {exc}")
return None

View File

@@ -40,6 +40,12 @@ class ResultsTab(QWidget):
self.db_manager = db_manager self.db_manager = db_manager
self.config_manager = config_manager self.config_manager = config_manager
# Pagination
self.page_size = 200
self.current_page = 0 # 0-based
self.total_runs = 0
self.total_pages = 0
self.detection_summary: List[Dict] = [] self.detection_summary: List[Dict] = []
self.current_selection: Optional[Dict] = None self.current_selection: Optional[Dict] = None
self.current_image: Optional[Image] = None self.current_image: Optional[Image] = None
@@ -66,6 +72,20 @@ class ResultsTab(QWidget):
self.refresh_btn.clicked.connect(self.refresh) self.refresh_btn.clicked.connect(self.refresh)
controls_layout.addWidget(self.refresh_btn) controls_layout.addWidget(self.refresh_btn)
self.prev_page_btn = QPushButton("◀ Prev")
self.prev_page_btn.setToolTip("Previous page")
self.prev_page_btn.clicked.connect(self._prev_page)
controls_layout.addWidget(self.prev_page_btn)
self.next_page_btn = QPushButton("Next ▶")
self.next_page_btn.setToolTip("Next page")
self.next_page_btn.clicked.connect(self._next_page)
controls_layout.addWidget(self.next_page_btn)
self.page_label = QLabel("Page 1/1")
self.page_label.setMinimumWidth(100)
controls_layout.addWidget(self.page_label)
self.delete_all_btn = QPushButton("Delete All Detections") self.delete_all_btn = QPushButton("Delete All Detections")
self.delete_all_btn.setToolTip( self.delete_all_btn.setToolTip(
"Permanently delete ALL detections from the database.\n" "This cannot be undone." "Permanently delete ALL detections from the database.\n" "This cannot be undone."
@@ -183,6 +203,8 @@ class ResultsTab(QWidget):
def refresh(self): def refresh(self):
"""Refresh the detection list and preview.""" """Refresh the detection list and preview."""
# Reset to first page on refresh.
self.current_page = 0
self._load_detection_summary() self._load_detection_summary()
self._populate_results_table() self._populate_results_table()
self.current_selection = None self.current_selection = None
@@ -193,57 +215,135 @@ class ResultsTab(QWidget):
if hasattr(self, "export_labels_btn"): if hasattr(self, "export_labels_btn"):
self.export_labels_btn.setEnabled(False) self.export_labels_btn.setEnabled(False)
self._update_pagination_controls()
def _prev_page(self):
"""Go to previous results page."""
if self.current_page <= 0:
return
self.current_page -= 1
self._load_detection_summary()
self._populate_results_table()
self._update_pagination_controls()
def _next_page(self):
"""Go to next results page."""
if self.total_pages and self.current_page >= (self.total_pages - 1):
return
self.current_page += 1
self._load_detection_summary()
self._populate_results_table()
self._update_pagination_controls()
def _update_pagination_controls(self):
"""Update pagination label/button enabled state."""
# Default state (safe)
total_pages = max(int(self.total_pages or 0), 1)
current_page = min(max(int(self.current_page or 0), 0), total_pages - 1)
self.current_page = current_page
self.page_label.setText(f"Page {current_page + 1}/{total_pages}")
self.prev_page_btn.setEnabled(current_page > 0)
self.next_page_btn.setEnabled(current_page < (total_pages - 1))
def _load_detection_summary(self): def _load_detection_summary(self):
"""Load latest detection summaries grouped by image + model.""" """Load latest detection summaries grouped by image + model."""
try: try:
detections = self.db_manager.get_detections(limit=500) # Prefer run summaries (supports zero-detection runs). Fall back to legacy
summary_map: Dict[tuple, Dict] = {} # detection aggregation if the DB/table isn't available.
try:
self.total_runs = int(self.db_manager.get_detection_run_total())
except Exception:
self.total_runs = 0
for det in detections: self.total_pages = (self.total_runs + self.page_size - 1) // self.page_size if self.total_runs > 0 else 1
key = (det["image_id"], det["model_id"])
metadata = det.get("metadata") or {} offset = int(self.current_page) * int(self.page_size)
entry = summary_map.setdefault( runs = self.db_manager.get_detection_run_summaries(
key, limit=int(self.page_size),
offset=offset,
)
summary: List[Dict] = []
for run in runs:
meta = run.get("metadata") or {}
classes_raw = run.get("classes")
classes = set()
if isinstance(classes_raw, str) and classes_raw.strip():
classes = {c.strip() for c in classes_raw.split(",") if c.strip()}
summary.append(
{ {
"image_id": det["image_id"], "image_id": run.get("image_id"),
"model_id": det["model_id"], "model_id": run.get("model_id"),
"image_path": det.get("image_path"), "image_path": run.get("image_path"),
"image_filename": det.get("image_filename") or det.get("image_path"), "image_filename": run.get("image_filename") or run.get("image_path"),
"model_name": det.get("model_name", ""), "model_name": run.get("model_name", ""),
"model_version": det.get("model_version", ""), "model_version": run.get("model_version", ""),
"last_detected": det.get("detected_at"), "last_detected": run.get("detected_at"),
"count": 0, "count": int(run.get("count") or 0),
"classes": set(), "classes": classes,
"source_path": metadata.get("source_path"), "source_path": meta.get("source_path"),
"repository_root": metadata.get("repository_root"), "repository_root": meta.get("repository_root"),
}, }
) )
entry["count"] += 1 self.detection_summary = summary
if det.get("detected_at") and (
not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
):
entry["last_detected"] = det.get("detected_at")
if det.get("class_name"):
entry["classes"].add(det["class_name"])
if metadata.get("source_path") and not entry.get("source_path"):
entry["source_path"] = metadata.get("source_path")
if metadata.get("repository_root") and not entry.get("repository_root"):
entry["repository_root"] = metadata.get("repository_root")
self.detection_summary = sorted(
summary_map.values(),
key=lambda x: str(x.get("last_detected") or ""),
reverse=True,
)
except Exception as e: except Exception as e:
logger.error(f"Failed to load detection summary: {e}") logger.error(f"Failed to load detection run summaries, falling back: {e}")
QMessageBox.critical( # Disable pagination if we can't page via detection_runs.
self, self.total_runs = 0
"Error", self.total_pages = 1
f"Failed to load detection results:\n{str(e)}", self.current_page = 0
) try:
self.detection_summary = [] detections = self.db_manager.get_detections(limit=int(self.page_size))
summary_map: Dict[tuple, Dict] = {}
for det in detections:
key = (det["image_id"], det["model_id"])
metadata = det.get("metadata") or {}
entry = summary_map.setdefault(
key,
{
"image_id": det["image_id"],
"model_id": det["model_id"],
"image_path": det.get("image_path"),
"image_filename": det.get("image_filename") or det.get("image_path"),
"model_name": det.get("model_name", ""),
"model_version": det.get("model_version", ""),
"last_detected": det.get("detected_at"),
"count": 0,
"classes": set(),
"source_path": metadata.get("source_path"),
"repository_root": metadata.get("repository_root"),
},
)
entry["count"] += 1
if det.get("detected_at") and (
not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
):
entry["last_detected"] = det.get("detected_at")
if det.get("class_name"):
entry["classes"].add(det["class_name"])
if metadata.get("source_path") and not entry.get("source_path"):
entry["source_path"] = metadata.get("source_path")
if metadata.get("repository_root") and not entry.get("repository_root"):
entry["repository_root"] = metadata.get("repository_root")
self.detection_summary = sorted(
summary_map.values(),
key=lambda x: str(x.get("last_detected") or ""),
reverse=True,
)
except Exception as inner:
logger.error(f"Failed to load detection summary: {inner}")
QMessageBox.critical(
self,
"Error",
f"Failed to load detection results:\n{str(inner)}",
)
self.detection_summary = []
def _populate_results_table(self): def _populate_results_table(self):
"""Populate the table widget with detection summaries.""" """Populate the table widget with detection summaries."""

View File

@@ -905,9 +905,14 @@ class TrainingTab(QWidget):
if stats["registered_images"]: if stats["registered_images"]:
message += f" {stats['registered_images']} image(s) had database-backed annotations." message += f" {stats['registered_images']} image(s) had database-backed annotations."
if stats["missing_records"]: if stats["missing_records"]:
message += ( preserved = stats.get("preserved_existing_labels", 0)
f" {stats['missing_records']} image(s) had no database entry; empty label files were written." if preserved:
) message += (
f" {stats['missing_records']} image(s) had no database annotations; "
f"preserved {preserved} existing label file(s) (no overwrite)."
)
else:
message += f" {stats['missing_records']} image(s) had no database annotations; empty label files were written."
split_messages.append(message) split_messages.append(message)
for msg in split_messages: for msg in split_messages:
@@ -929,6 +934,7 @@ class TrainingTab(QWidget):
processed_images = 0 processed_images = 0
registered_images = 0 registered_images = 0
missing_records = 0 missing_records = 0
preserved_existing_labels = 0
total_annotations = 0 total_annotations = 0
for image_file in images_dir.rglob("*"): for image_file in images_dir.rglob("*"):
@@ -950,6 +956,12 @@ class TrainingTab(QWidget):
else: else:
missing_records += 1 missing_records += 1
# If the database has no entry for this image, do not overwrite an existing label file
# with an empty one (preserve any manually created labels on disk).
if not found and label_path.exists():
preserved_existing_labels += 1
continue
annotations_written = 0 annotations_written = 0
with open(label_path, "w", encoding="utf-8") as handle: with open(label_path, "w", encoding="utf-8") as handle:
for entry in annotation_entries: for entry in annotation_entries:
@@ -979,6 +991,7 @@ class TrainingTab(QWidget):
"processed_images": processed_images, "processed_images": processed_images,
"registered_images": registered_images, "registered_images": registered_images,
"missing_records": missing_records, "missing_records": missing_records,
"preserved_existing_labels": preserved_existing_labels,
"total_annotations": total_annotations, "total_annotations": total_annotations,
} }
@@ -1008,6 +1021,10 @@ class TrainingTab(QWidget):
resolved_image = image_path.resolve() resolved_image = image_path.resolve()
candidates: List[str] = [] candidates: List[str] = []
# Some DBs store absolute paths in `images.relative_path`.
# Include the absolute resolved path as a lookup candidate.
candidates.append(resolved_image.as_posix())
for base in (dataset_root, images_dir): for base in (dataset_root, images_dir):
try: try:
relative = resolved_image.relative_to(base.resolve()).as_posix() relative = resolved_image.relative_to(base.resolve()).as_posix()
@@ -1032,6 +1049,13 @@ class TrainingTab(QWidget):
return False, [] return False, []
annotations = self.db_manager.get_annotations_for_image(image_row["id"]) or [] annotations = self.db_manager.get_annotations_for_image(image_row["id"]) or []
# Treat "found" as "has database-backed annotations".
# If the image exists in DB but has no annotations yet, we don't want to overwrite
# an existing label file on disk with an empty one.
if not annotations:
return False, []
yolo_entries: List[Dict[str, Any]] = [] yolo_entries: List[Dict[str, Any]] = []
for ann in annotations: for ann in annotations:

View File

@@ -84,25 +84,19 @@ class InferenceEngine:
# Save detections to database, replacing any previous results for this image/model # Save detections to database, replacing any previous results for this image/model
if save_to_db: if save_to_db:
deleted_count = self.db_manager.delete_detections_for_image( deleted_count = self.db_manager.delete_detections_for_image(image_id, self.model_id)
image_id, self.model_id
)
if detections: if detections:
detection_records = [] detection_records = []
for det in detections: for det in detections:
# Use normalized bbox from detection # Use normalized bbox from detection
bbox_normalized = det[ bbox_normalized = det["bbox_normalized"] # [x_min, y_min, x_max, y_max]
"bbox_normalized"
] # [x_min, y_min, x_max, y_max]
metadata = { metadata = {
"class_id": det["class_id"], "class_id": det["class_id"],
"source_path": str(Path(image_path).resolve()), "source_path": str(Path(image_path).resolve()),
} }
if repository_root: if repository_root:
metadata["repository_root"] = str( metadata["repository_root"] = str(Path(repository_root).resolve())
Path(repository_root).resolve()
)
record = { record = {
"image_id": image_id, "image_id": image_id,
@@ -115,16 +109,27 @@ class InferenceEngine:
} }
detection_records.append(record) detection_records.append(record)
inserted_count = self.db_manager.add_detections_batch( inserted_count = self.db_manager.add_detections_batch(detection_records)
detection_records logger.info(f"Saved {inserted_count} detections to database (replaced {deleted_count})")
)
logger.info(
f"Saved {inserted_count} detections to database (replaced {deleted_count})"
)
else: else:
logger.info( logger.info(f"Detection run removed {deleted_count} stale entries but produced no new detections")
f"Detection run removed {deleted_count} stale entries but produced no new detections"
# Always store a run summary so the Results tab can show zero-detection runs.
try:
run_metadata = {
"source_path": str(Path(image_path).resolve()),
}
if repository_root:
run_metadata["repository_root"] = str(Path(repository_root).resolve())
self.db_manager.upsert_detection_run(
image_id=image_id,
model_id=self.model_id,
count=len(detections),
metadata=run_metadata,
) )
except Exception as exc:
# Non-fatal: detection records may still be present.
logger.warning(f"Failed to store detection run summary: {exc}")
return { return {
"success": True, "success": True,
@@ -232,9 +237,7 @@ class InferenceEngine:
for det in detections: for det in detections:
# Get color for this class # Get color for this class
class_name = det["class_name"] class_name = det["class_name"]
color_hex = bbox_colors.get( color_hex = bbox_colors.get(class_name, bbox_colors.get("default", "#00FF00"))
class_name, bbox_colors.get("default", "#00FF00")
)
color = self._hex_to_bgr(color_hex) color = self._hex_to_bgr(color_hex)
# Draw segmentation mask if available and requested # Draw segmentation mask if available and requested
@@ -243,10 +246,7 @@ class InferenceEngine:
if mask_normalized and len(mask_normalized) > 0: if mask_normalized and len(mask_normalized) > 0:
# Convert normalized coordinates to absolute pixels # Convert normalized coordinates to absolute pixels
mask_points = np.array( mask_points = np.array(
[ [[int(pt[0] * width), int(pt[1] * height)] for pt in mask_normalized],
[int(pt[0] * width), int(pt[1] * height)]
for pt in mask_normalized
],
dtype=np.int32, dtype=np.int32,
) )
@@ -270,9 +270,7 @@ class InferenceEngine:
label = f"{class_name} {det['confidence']:.2f}" label = f"{class_name} {det['confidence']:.2f}"
# Draw label background # Draw label background
(label_w, label_h), baseline = cv2.getTextSize( (label_w, label_h), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
)
cv2.rectangle( cv2.rectangle(
img, img,
(x1, y1 - label_h - baseline - 5), (x1, y1 - label_h - baseline - 5),

View File

@@ -186,7 +186,7 @@ class YOLOWrapper:
raise RuntimeError(f"Failed to load model from {self.model_path}") raise RuntimeError(f"Failed to load model from {self.model_path}")
prepared_source, cleanup_path = self._prepare_source(source) prepared_source, cleanup_path = self._prepare_source(source)
imgsz = 1088 imgsz = 640 # 1088
try: try:
logger.info(f"Running inference on {source} -> prepared_source {prepared_source}") logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
results = self.model.predict( results = self.model.predict(

View File

@@ -6,6 +6,9 @@ import cv2
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
from skimage.restoration import rolling_ball
from skimage.filters import threshold_otsu
from scipy.ndimage import median_filter, gaussian_filter
from src.utils.logger import get_logger from src.utils.logger import get_logger
from src.utils.file_utils import validate_file_path, is_image_file from src.utils.file_utils import validate_file_path, is_image_file
@@ -37,6 +40,8 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
a1[a1 > p999] = p999 a1[a1 > p999] = p999
a1 /= a1.max() a1 /= a1.max()
# print("Using get_pseudo_rgb")
if 1: if 1:
a2 = a1.copy() a2 = a1.copy()
a2 = a2**gamma a2 = a2**gamma
@@ -46,14 +51,75 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
p9999 = np.percentile(a3, 99.99) p9999 = np.percentile(a3, 99.99)
a3[a3 > p9999] = p9999 a3[a3 > p9999] = p9999
a3 /= a3.max() a3 /= a3.max()
out = np.stack([a1, a2, a3], axis=0)
else:
out = np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
return out
# return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0) # return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0) # return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
out = np.stack([a1, a2, a3], axis=0)
# print(any(np.isnan(out).flatten())) # print(any(np.isnan(out).flatten()))
def _get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.3) -> np.ndarray:
"""
Convert a grayscale image to a pseudo-RGB image using a gamma correction.
Args:
arr: Input grayscale image as numpy array
Returns:
Pseudo-RGB image as numpy array
"""
if arr.ndim != 2:
raise ValueError("Input array must be a grayscale image with shape (H, W)")
radius = 80
# bg = rolling_ball(arr, radius=radius)
a1 = arr.copy().astype(np.float32)
a1 = a1.astype(np.float32)
# a1 -= bg
# a1[a1 < 0] = 0
# a1 -= np.percentile(a1, 2)
# a1[a1 < 0] = 0
p999 = np.percentile(a1, 99.99)
a1[a1 > p999] = p999
a1 /= a1.max()
print("Using get_pseudo_rgb")
if 1:
a2 = a1.copy()
_a2 = a2**gamma
thr = threshold_otsu(_a2)
mask = gaussian_filter((_a2 > thr).astype(np.float32), sigma=5)
mask[mask > 0.0001] = 1
mask[mask <= 0.0001] = 0
a2 *= mask
# bg2 = rolling_ball(a2, radius=radius)
# a2 -= bg2
a2 -= np.percentile(a2, 2)
a2[a2 < 0] = 0
a2 /= a2.max()
a3 = a1.copy()
p9999 = np.percentile(a3, 99.99)
a3[a3 > p9999] = p9999
a3 /= a3.max()
out = np.stack([a1, a2, _a2], axis=0)
else:
out = np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
return out return out
# return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# print(any(np.isnan(out).flatten()))
class ImageLoadError(Exception): class ImageLoadError(Exception):
"""Exception raised when an image cannot be loaded.""" """Exception raised when an image cannot be loaded."""
@@ -330,17 +396,20 @@ class Image:
return self._channels >= 3 return self._channels >= 3
def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None: def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None:
# print("Image.save", self.data.shape)
if self.channels == 1: if self.channels == 1:
# print("Image.save grayscale")
if pseudo_rgb: if pseudo_rgb:
img = get_pseudo_rgb(self.data) img = get_pseudo_rgb(self.data)
print("Image.save", img.shape) # print("Image.save", img.shape)
else: else:
img = np.repeat(self.data, 3, axis=2) img = np.repeat(self.data, 3, axis=2)
# print("Image.save no pseudo", img.shape)
else: else:
raise NotImplementedError("Only grayscale images are supported for now.") raise NotImplementedError("Only grayscale images are supported for now.")
# print("Image.save imwrite", img.shape)
imwrite(path, data=img) imwrite(path, data=img)
def __repr__(self) -> str: def __repr__(self) -> str:

View File

@@ -97,7 +97,7 @@ class UT:
class_index: int = 0, class_index: int = 0,
): ):
"""Export rois to a file""" """Export rois to a file"""
with open(path / subfolder / f"{self.stem}.txt", "w") as f: with open(path / subfolder / f"{PREFIX}-{self.stem}.txt", "w") as f:
for i, roi in enumerate(self.rois): for i, roi in enumerate(self.rois):
rc = roi.subpixel_coordinates rc = roi.subpixel_coordinates
if rc is None: if rc is None:
@@ -129,8 +129,8 @@ class UT:
self.image = np.max(self.image[channel], axis=0) self.image = np.max(self.image[channel], axis=0)
print(self.image.shape) print(self.image.shape)
print(path / subfolder / f"{self.stem}.tif") print(path / subfolder / f"{PREFIX}_{self.stem}.tif")
with TiffWriter(path / subfolder / f"{self.stem}.tif") as tif: with TiffWriter(path / subfolder / f"{PREFIX}-{self.stem}.tif") as tif:
tif.write(self.image) tif.write(self.image)
@@ -145,11 +145,19 @@ if __name__ == "__main__":
action="store_false", action="store_false",
help="Source does not have labels, export only images", help="Source does not have labels, export only images",
) )
parser.add_argument("--prefix", help="Prefix for output files")
args = parser.parse_args() args = parser.parse_args()
PREFIX = args.prefix
# print(args) # print(args)
# aa # aa
# for path in args.input:
# print(path)
# ut = UT(path, args.no_labels)
# ut.export_image(args.output, plane_mode="max projection", channel=0)
# ut.export_rois(args.output, class_index=0)
for path in args.input: for path in args.input:
print("Path:", path) print("Path:", path)
if not args.no_labels: if not args.no_labels:

View File

@@ -5,7 +5,7 @@ from tifffile import imread, imwrite
from shapely.geometry import LineString from shapely.geometry import LineString
from copy import deepcopy from copy import deepcopy
from scipy.ndimage import zoom from scipy.ndimage import zoom
from skimage.restoration import rolling_ball
# debug # debug
from src.utils.image import Image from src.utils.image import Image
@@ -160,8 +160,11 @@ class YoloLabelReader:
class ImageSplitter: class ImageSplitter:
def __init__(self, image_path: Path, label_path: Path): def __init__(self, image_path: Path, label_path: Path, subtract_bg: bool = False):
self.image = imread(image_path) self.image = imread(image_path)
if subtract_bg:
self.image = self.image - rolling_ball(self.image, radius=12)
self.image[self.image < 0] = 0
self.image_path = image_path self.image_path = image_path
self.label_path = label_path self.label_path = label_path
if not label_path.exists(): if not label_path.exists():
@@ -273,13 +276,14 @@ def main(args):
if args.output: if args.output:
args.output.mkdir(exist_ok=True, parents=True) args.output.mkdir(exist_ok=True, parents=True)
(args.output / "images").mkdir(exist_ok=True) (args.output / "images").mkdir(exist_ok=True)
(args.output / "images-zoomed").mkdir(exist_ok=True) # (args.output / "images-zoomed").mkdir(exist_ok=True)
(args.output / "labels").mkdir(exist_ok=True) (args.output / "labels").mkdir(exist_ok=True)
for image_path in (args.input / "images").glob("*.tif"): for image_path in (args.input / "images").glob("*.tif"):
data = ImageSplitter( data = ImageSplitter(
image_path=image_path, image_path=image_path,
label_path=(args.input / "labels" / image_path.stem).with_suffix(".txt"), label_path=(args.input / "labels" / image_path.stem).with_suffix(".txt"),
subtract_bg=args.subtract_bg,
) )
if args.split_around_label: if args.split_around_label:
@@ -332,10 +336,20 @@ def main(args):
if labels is not None: if labels is not None:
with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f: with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f:
print(
f"Writing {len(labels)} labels to {args.output / 'labels' / f'{image_path.stem}_{tile_reference}.txt'}"
)
for label in labels: for label in labels:
# label.offset_label(tile.shape[1], tile.shape[0]) # label.offset_label(tile.shape[1], tile.shape[0])
f.write(label.to_string() + "\n") f.write(label.to_string() + "\n")
# { debug
if debug:
print(label.to_string())
# } debug
# break
# break
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
@@ -363,6 +377,7 @@ if __name__ == "__main__":
default=67, default=67,
help="Padding around the label when splitting around the label.", help="Padding around the label when splitting around the label.",
) )
parser.add_argument("-bg", "--subtract-bg", action="store_true", help="Subtract background from the image.")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)