5 Commits

11 changed files with 1551 additions and 176 deletions

View File

@@ -40,8 +40,15 @@ class DatabaseManager:
conn = self.get_connection() conn = self.get_connection()
try: try:
# Check if annotations table needs migration # Pre-schema migrations.
# These must run BEFORE executing schema.sql because schema.sql may
# contain CREATE INDEX statements referencing newly added columns.
#
# 1) Check if annotations table needs migration (may drop an old table)
self._migrate_annotations_table(conn) self._migrate_annotations_table(conn)
# 2) Ensure images table has the required columns (e.g. 'source')
self._migrate_images_table(conn)
conn.commit()
# Read schema file and execute # Read schema file and execute
schema_path = Path(__file__).parent / "schema.sql" schema_path = Path(__file__).parent / "schema.sql"
@@ -53,6 +60,19 @@ class DatabaseManager:
finally: finally:
conn.close() conn.close()
def _migrate_images_table(self, conn: sqlite3.Connection) -> None:
"""Migrate images table to include the 'source' column if missing."""
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='images'")
if not cursor.fetchone():
return
cursor.execute("PRAGMA table_info(images)")
columns = {row[1] for row in cursor.fetchall()}
if "source" not in columns:
cursor.execute("ALTER TABLE images ADD COLUMN source TEXT")
def _migrate_annotations_table(self, conn: sqlite3.Connection) -> None: def _migrate_annotations_table(self, conn: sqlite3.Connection) -> None:
""" """
Migrate annotations table from old schema (class_name) to new schema (class_id). Migrate annotations table from old schema (class_name) to new schema (class_id).
@@ -85,6 +105,103 @@ class DatabaseManager:
conn.execute("PRAGMA foreign_keys = ON") # Enable foreign keys conn.execute("PRAGMA foreign_keys = ON") # Enable foreign keys
return conn return conn
# ==================== Detection Run Operations ====================
def upsert_detection_run(
self,
image_id: int,
model_id: int,
count: int,
metadata: Optional[Dict] = None,
) -> bool:
"""Insert/update a per-image per-model detection run summary.
This enables the UI to show runs even when zero detections were produced.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO detection_runs (image_id, model_id, detected_at, count, metadata)
VALUES (?, ?, CURRENT_TIMESTAMP, ?, ?)
ON CONFLICT(image_id, model_id) DO UPDATE SET
detected_at = CURRENT_TIMESTAMP,
count = excluded.count,
metadata = excluded.metadata
""",
(
int(image_id),
int(model_id),
int(count),
json.dumps(metadata) if metadata else None,
),
)
conn.commit()
return True
finally:
conn.close()
def get_detection_run_summaries(self, limit: int = 500, offset: int = 0) -> List[Dict]:
"""Return latest detection run summaries grouped by image+model.
Includes runs with 0 detections.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"""
SELECT
dr.image_id,
dr.model_id,
dr.detected_at,
dr.count,
dr.metadata,
i.relative_path AS image_path,
i.filename AS image_filename,
m.model_name,
m.model_version,
GROUP_CONCAT(DISTINCT d.class_name) AS classes
FROM detection_runs dr
JOIN images i ON dr.image_id = i.id
JOIN models m ON dr.model_id = m.id
LEFT JOIN detections d
ON d.image_id = dr.image_id AND d.model_id = dr.model_id
GROUP BY dr.image_id, dr.model_id
ORDER BY dr.detected_at DESC
LIMIT ? OFFSET ?
""",
(int(limit), int(offset)),
)
rows: List[Dict] = []
for row in cursor.fetchall():
item = dict(row)
if item.get("metadata"):
try:
item["metadata"] = json.loads(item["metadata"])
except Exception:
item["metadata"] = None
rows.append(item)
return rows
finally:
conn.close()
def get_detection_run_total(self) -> int:
"""Return total number of detection_runs rows."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) AS cnt FROM detection_runs")
row = cursor.fetchone()
return int(row["cnt"] if row and row["cnt"] is not None else 0)
finally:
conn.close()
# ==================== Model Operations ==================== # ==================== Model Operations ====================
def add_model( def add_model(
@@ -201,6 +318,28 @@ class DatabaseManager:
finally: finally:
conn.close() conn.close()
def delete_model(self, model_id: int) -> bool:
"""Delete a model from the database.
Note: detections referencing this model are deleted automatically via
the `detections.model_id` foreign key (ON DELETE CASCADE).
Args:
model_id: ID of the model to delete.
Returns:
True if a model row was deleted, False otherwise.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM models WHERE id = ?", (model_id,))
conn.commit()
return cursor.rowcount > 0
finally:
conn.close()
# ==================== Image Operations ==================== # ==================== Image Operations ====================
def add_image( def add_image(
@@ -211,6 +350,7 @@ class DatabaseManager:
height: int, height: int,
captured_at: Optional[datetime] = None, captured_at: Optional[datetime] = None,
checksum: Optional[str] = None, checksum: Optional[str] = None,
source: Optional[str] = None,
) -> int: ) -> int:
""" """
Add a new image to the database. Add a new image to the database.
@@ -231,10 +371,10 @@ class DatabaseManager:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
""" """
INSERT INTO images (relative_path, filename, width, height, captured_at, checksum) INSERT INTO images (relative_path, filename, width, height, captured_at, checksum, source)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?)
""", """,
(relative_path, filename, width, height, captured_at, checksum), (relative_path, filename, width, height, captured_at, checksum, source),
) )
conn.commit() conn.commit()
return cursor.lastrowid return cursor.lastrowid
@@ -264,6 +404,18 @@ class DatabaseManager:
return existing["id"] return existing["id"]
return self.add_image(relative_path, filename, width, height) return self.add_image(relative_path, filename, width, height)
def set_image_source(self, image_id: int, source: Optional[str]) -> bool:
"""Set/update the source marker for an image row."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("UPDATE images SET source = ? WHERE id = ?", (source, int(image_id)))
conn.commit()
return cursor.rowcount > 0
finally:
conn.close()
# ==================== Detection Operations ==================== # ==================== Detection Operations ====================
def add_detection( def add_detection(
@@ -472,6 +624,14 @@ class DatabaseManager:
conn = self.get_connection() conn = self.get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
# Also clear detection run summaries so the Results tab does not continue
# to show historical runs after detections have been wiped.
try:
cursor.execute("DELETE FROM detection_runs")
except sqlite3.OperationalError:
# Backwards-compatible: table may not exist on older DB files.
pass
cursor.execute("DELETE FROM detections") cursor.execute("DELETE FROM detections")
conn.commit() conn.commit()
return cursor.rowcount return cursor.rowcount
@@ -636,6 +796,153 @@ class DatabaseManager:
# ==================== Annotation Operations ==================== # ==================== Annotation Operations ====================
def get_images_summary(
self,
name_filter: Optional[str] = None,
source_filter: Optional[str] = None,
order_by: str = "filename",
order_dir: str = "ASC",
limit: Optional[int] = None,
offset: int = 0,
) -> List[Dict]:
"""Return all images with annotation counts (including zero).
This is used by the Annotation tab to populate the image list even when
no annotations exist yet.
Args:
name_filter: Optional substring filter applied to filename/relative_path.
order_by: One of: 'filename', 'relative_path', 'annotation_count', 'added_at'.
order_dir: 'ASC' or 'DESC'.
limit: Optional max number of rows.
offset: Pagination offset.
Returns:
List of dicts: {id, relative_path, filename, added_at, annotation_count}
"""
allowed_order_by = {
"filename": "i.filename",
"relative_path": "i.relative_path",
"annotation_count": "annotation_count",
"added_at": "i.added_at",
}
order_expr = allowed_order_by.get(order_by, "i.filename")
dir_norm = str(order_dir).upper().strip()
if dir_norm not in {"ASC", "DESC"}:
dir_norm = "ASC"
conn = self.get_connection()
try:
params: List[Any] = []
where_clauses: List[str] = []
if name_filter:
token = f"%{name_filter}%"
where_clauses.append("(i.filename LIKE ? OR i.relative_path LIKE ?)")
params.extend([token, token])
if source_filter:
where_clauses.append("i.source = ?")
params.append(source_filter)
where_sql = ""
if where_clauses:
where_sql = "WHERE " + " AND ".join(where_clauses)
limit_sql = ""
if limit is not None:
limit_sql = " LIMIT ? OFFSET ?"
params.extend([int(limit), int(offset)])
query = f"""
SELECT
i.id,
i.relative_path,
i.filename,
i.added_at,
COUNT(a.id) AS annotation_count
FROM images i
LEFT JOIN annotations a ON a.image_id = i.id
{where_sql}
GROUP BY i.id
ORDER BY {order_expr} {dir_norm}
{limit_sql}
"""
cursor = conn.cursor()
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
finally:
conn.close()
def get_annotated_images_summary(
self,
name_filter: Optional[str] = None,
order_by: str = "filename",
order_dir: str = "ASC",
limit: Optional[int] = None,
offset: int = 0,
) -> List[Dict]:
"""Return images that have at least one manual annotation.
Args:
name_filter: Optional substring filter applied to filename/relative_path.
order_by: One of: 'filename', 'relative_path', 'annotation_count', 'added_at'.
order_dir: 'ASC' or 'DESC'.
limit: Optional max number of rows.
offset: Pagination offset.
Returns:
List of dicts: {id, relative_path, filename, added_at, annotation_count}
"""
allowed_order_by = {
"filename": "i.filename",
"relative_path": "i.relative_path",
"annotation_count": "annotation_count",
"added_at": "i.added_at",
}
order_expr = allowed_order_by.get(order_by, "i.filename")
dir_norm = str(order_dir).upper().strip()
if dir_norm not in {"ASC", "DESC"}:
dir_norm = "ASC"
conn = self.get_connection()
try:
params: List[Any] = []
where_sql = ""
if name_filter:
# Case-insensitive substring search.
token = f"%{name_filter}%"
where_sql = "WHERE (i.filename LIKE ? OR i.relative_path LIKE ?)"
params.extend([token, token])
limit_sql = ""
if limit is not None:
limit_sql = " LIMIT ? OFFSET ?"
params.extend([int(limit), int(offset)])
query = f"""
SELECT
i.id,
i.relative_path,
i.filename,
i.added_at,
COUNT(a.id) AS annotation_count
FROM images i
JOIN annotations a ON a.image_id = i.id
{where_sql}
GROUP BY i.id
HAVING annotation_count > 0
ORDER BY {order_expr} {dir_norm}
{limit_sql}
"""
cursor = conn.cursor()
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
finally:
conn.close()
def add_annotation( def add_annotation(
self, self,
image_id: int, image_id: int,
@@ -741,6 +1048,27 @@ class DatabaseManager:
finally: finally:
conn.close() conn.close()
def delete_annotations_for_image(self, image_id: int) -> int:
"""Delete ALL annotations for a specific image.
This is primarily used for import/overwrite workflows.
Args:
image_id: ID of the image whose annotations should be deleted.
Returns:
Number of rows deleted.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM annotations WHERE image_id = ?", (int(image_id),))
conn.commit()
return int(cursor.rowcount or 0)
finally:
conn.close()
# ==================== Object Class Operations ==================== # ==================== Object Class Operations ====================
def get_object_classes(self) -> List[Dict]: def get_object_classes(self) -> List[Dict]:

View File

@@ -7,7 +7,7 @@ CREATE TABLE IF NOT EXISTS models (
model_name TEXT NOT NULL, model_name TEXT NOT NULL,
model_version TEXT NOT NULL, model_version TEXT NOT NULL,
model_path TEXT NOT NULL, model_path TEXT NOT NULL,
base_model TEXT NOT NULL DEFAULT 'yolov8s.pt', base_model TEXT NOT NULL DEFAULT 'yolo11s.pt',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
training_params TEXT, -- JSON string of training parameters training_params TEXT, -- JSON string of training parameters
metrics TEXT, -- JSON string of validation metrics metrics TEXT, -- JSON string of validation metrics
@@ -23,7 +23,8 @@ CREATE TABLE IF NOT EXISTS images (
height INTEGER, height INTEGER,
captured_at TIMESTAMP, captured_at TIMESTAMP,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
checksum TEXT checksum TEXT,
source TEXT
); );
-- Detections table: stores detection results -- Detections table: stores detection results
@@ -44,6 +45,19 @@ CREATE TABLE IF NOT EXISTS detections (
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
); );
-- Detection runs table: stores per-image per-model run summaries (including 0 detections)
CREATE TABLE IF NOT EXISTS detection_runs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
image_id INTEGER NOT NULL,
model_id INTEGER NOT NULL,
detected_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
count INTEGER NOT NULL DEFAULT 0,
metadata TEXT,
UNIQUE(image_id, model_id),
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
);
-- Object classes table: stores annotation class definitions with colors -- Object classes table: stores annotation class definitions with colors
CREATE TABLE IF NOT EXISTS object_classes ( CREATE TABLE IF NOT EXISTS object_classes (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -80,8 +94,12 @@ CREATE INDEX IF NOT EXISTS idx_detections_model_id ON detections(model_id);
CREATE INDEX IF NOT EXISTS idx_detections_class_name ON detections(class_name); CREATE INDEX IF NOT EXISTS idx_detections_class_name ON detections(class_name);
CREATE INDEX IF NOT EXISTS idx_detections_detected_at ON detections(detected_at); CREATE INDEX IF NOT EXISTS idx_detections_detected_at ON detections(detected_at);
CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence); CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence);
CREATE INDEX IF NOT EXISTS idx_detection_runs_image_id ON detection_runs(image_id);
CREATE INDEX IF NOT EXISTS idx_detection_runs_model_id ON detection_runs(model_id);
CREATE INDEX IF NOT EXISTS idx_detection_runs_detected_at ON detection_runs(detected_at);
CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path); CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path);
CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at); CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at);
CREATE INDEX IF NOT EXISTS idx_images_source ON images(source);
CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id); CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id);
CREATE INDEX IF NOT EXISTS idx_annotations_class_id ON annotations(class_id); CREATE INDEX IF NOT EXISTS idx_annotations_class_id ON annotations(class_id);
CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at); CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at);

View File

@@ -1,6 +1,7 @@
""" """Main window for the microscopy object detection application."""
Main window for the microscopy object detection application.
""" import shutil
from pathlib import Path
from PySide6.QtWidgets import ( from PySide6.QtWidgets import (
QMainWindow, QMainWindow,
@@ -20,6 +21,7 @@ from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger from src.utils.logger import get_logger
from src.gui.dialogs.config_dialog import ConfigDialog from src.gui.dialogs.config_dialog import ConfigDialog
from src.gui.dialogs.delete_model_dialog import DeleteModelDialog
from src.gui.tabs.detection_tab import DetectionTab from src.gui.tabs.detection_tab import DetectionTab
from src.gui.tabs.training_tab import TrainingTab from src.gui.tabs.training_tab import TrainingTab
from src.gui.tabs.validation_tab import ValidationTab from src.gui.tabs.validation_tab import ValidationTab
@@ -91,6 +93,12 @@ class MainWindow(QMainWindow):
db_stats_action.triggered.connect(self._show_database_stats) db_stats_action.triggered.connect(self._show_database_stats)
tools_menu.addAction(db_stats_action) tools_menu.addAction(db_stats_action)
tools_menu.addSeparator()
delete_model_action = QAction("Delete &Model…", self)
delete_model_action.triggered.connect(self._show_delete_model_dialog)
tools_menu.addAction(delete_model_action)
# Help menu # Help menu
help_menu = menubar.addMenu("&Help") help_menu = menubar.addMenu("&Help")
@@ -117,10 +125,10 @@ class MainWindow(QMainWindow):
# Add tabs to widget # Add tabs to widget
self.tab_widget.addTab(self.detection_tab, "Detection") self.tab_widget.addTab(self.detection_tab, "Detection")
self.tab_widget.addTab(self.results_tab, "Results")
self.tab_widget.addTab(self.annotation_tab, "Annotation")
self.tab_widget.addTab(self.training_tab, "Training") self.tab_widget.addTab(self.training_tab, "Training")
self.tab_widget.addTab(self.validation_tab, "Validation") self.tab_widget.addTab(self.validation_tab, "Validation")
self.tab_widget.addTab(self.results_tab, "Results")
self.tab_widget.addTab(self.annotation_tab, "Annotation (Future)")
# Connect tab change signal # Connect tab change signal
self.tab_widget.currentChanged.connect(self._on_tab_changed) self.tab_widget.currentChanged.connect(self._on_tab_changed)
@@ -152,9 +160,7 @@ class MainWindow(QMainWindow):
"""Center window on screen.""" """Center window on screen."""
screen = self.screen().geometry() screen = self.screen().geometry()
size = self.geometry() size = self.geometry()
self.move( self.move((screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2)
(screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2
)
def _restore_window_state(self): def _restore_window_state(self):
"""Restore window geometry from settings or center window.""" """Restore window geometry from settings or center window."""
@@ -193,6 +199,10 @@ class MainWindow(QMainWindow):
self.training_tab.refresh() self.training_tab.refresh()
if hasattr(self, "results_tab"): if hasattr(self, "results_tab"):
self.results_tab.refresh() self.results_tab.refresh()
if hasattr(self, "annotation_tab"):
self.annotation_tab.refresh()
if hasattr(self, "validation_tab"):
self.validation_tab.refresh()
except Exception as e: except Exception as e:
logger.error(f"Error applying settings: {e}") logger.error(f"Error applying settings: {e}")
@@ -209,6 +219,14 @@ class MainWindow(QMainWindow):
logger.debug(f"Switched to tab: {tab_name}") logger.debug(f"Switched to tab: {tab_name}")
self._update_status(f"Viewing: {tab_name}") self._update_status(f"Viewing: {tab_name}")
# Ensure the Annotation tab always shows up-to-date DB-backed lists.
try:
current_widget = self.tab_widget.widget(index)
if hasattr(self, "annotation_tab") and current_widget is self.annotation_tab:
self.annotation_tab.refresh()
except Exception as exc:
logger.debug(f"Failed to refresh annotation tab on selection: {exc}")
def _show_database_stats(self): def _show_database_stats(self):
"""Show database statistics dialog.""" """Show database statistics dialog."""
try: try:
@@ -231,9 +249,229 @@ class MainWindow(QMainWindow):
except Exception as e: except Exception as e:
logger.error(f"Error getting database stats: {e}") logger.error(f"Error getting database stats: {e}")
QMessageBox.warning( QMessageBox.warning(self, "Error", f"Failed to get database statistics:\n{str(e)}")
self, "Error", f"Failed to get database statistics:\n{str(e)}"
) def _show_delete_model_dialog(self) -> None:
"""Open the model deletion dialog."""
dialog = DeleteModelDialog(self.db_manager, self)
if not dialog.exec():
return
model_ids = dialog.selected_model_ids
if not model_ids:
return
self._delete_models(model_ids)
def _delete_models(self, model_ids: list[int]) -> None:
"""Delete one or more models from the database and remove artifacts from disk."""
deleted_count = 0
removed_paths: list[str] = []
remove_errors: list[str] = []
for model_id in model_ids:
model = None
try:
model = self.db_manager.get_model_by_id(int(model_id))
except Exception as exc:
logger.error(f"Failed to load model {model_id} before deletion: {exc}")
if not model:
remove_errors.append(f"Model id {model_id} not found in database.")
continue
try:
deleted = self.db_manager.delete_model(int(model_id))
except Exception as exc:
logger.error(f"Failed to delete model {model_id}: {exc}")
remove_errors.append(f"Failed to delete model id {model_id} from DB: {exc}")
continue
if not deleted:
remove_errors.append(f"Model id {model_id} was not deleted (already removed?).")
continue
deleted_count += 1
removed, errors = self._delete_model_artifacts_from_disk(model)
removed_paths.extend(removed)
remove_errors.extend(errors)
# Refresh tabs to reflect the deletion(s).
try:
if hasattr(self, "detection_tab"):
self.detection_tab.refresh()
if hasattr(self, "results_tab"):
self.results_tab.refresh()
if hasattr(self, "validation_tab"):
self.validation_tab.refresh()
if hasattr(self, "training_tab"):
self.training_tab.refresh()
except Exception as exc:
logger.warning(f"Failed to refresh tabs after model deletion: {exc}")
details: list[str] = []
if removed_paths:
details.append("Removed from disk:\n" + "\n".join(removed_paths))
if remove_errors:
details.append("\nDisk cleanup warnings:\n" + "\n".join(remove_errors))
QMessageBox.information(
self,
"Delete Model",
f"Deleted {deleted_count} model(s) from database." + ("\n\n" + "\n".join(details) if details else ""),
)
def _delete_model(self, model_id: int) -> None:
"""Delete a model from the database and remove its artifacts from disk."""
model = None
try:
model = self.db_manager.get_model_by_id(model_id)
except Exception as exc:
logger.error(f"Failed to load model {model_id} before deletion: {exc}")
if not model:
QMessageBox.warning(self, "Delete Model", "Selected model was not found in the database.")
return
model_path = str(model.get("model_path") or "")
try:
deleted = self.db_manager.delete_model(model_id)
except Exception as exc:
logger.error(f"Failed to delete model {model_id}: {exc}")
QMessageBox.critical(self, "Delete Model", f"Failed to delete model from database:\n{exc}")
return
if not deleted:
QMessageBox.warning(self, "Delete Model", "No model was deleted (it may have already been removed).")
return
removed_paths, remove_errors = self._delete_model_artifacts_from_disk(model)
# Refresh tabs to reflect the deletion.
try:
if hasattr(self, "detection_tab"):
self.detection_tab.refresh()
if hasattr(self, "results_tab"):
self.results_tab.refresh()
if hasattr(self, "validation_tab"):
self.validation_tab.refresh()
if hasattr(self, "training_tab"):
self.training_tab.refresh()
except Exception as exc:
logger.warning(f"Failed to refresh tabs after model deletion: {exc}")
details = []
if model_path:
details.append(f"Deleted model record for: {model_path}")
if removed_paths:
details.append("\nRemoved from disk:\n" + "\n".join(removed_paths))
if remove_errors:
details.append("\nDisk cleanup warnings:\n" + "\n".join(remove_errors))
QMessageBox.information(
self,
"Delete Model",
"Model deleted from database." + ("\n\n" + "\n".join(details) if details else ""),
)
def _delete_model_artifacts_from_disk(self, model: dict) -> tuple[list[str], list[str]]:
"""Best-effort removal of model artifacts on disk.
Strategy:
- Remove run directories inferred from:
- model.model_path (…/<run>/weights/*.pt => <run>)
- training_params.stage_results[].results.save_dir
but only if they are under the configured models directory.
- If the weights file itself exists and is outside the models directory, delete only the file.
Returns:
(removed_paths, errors)
"""
removed: list[str] = []
errors: list[str] = []
models_root = Path(self.config_manager.get_models_directory() or "data/models").expanduser()
try:
models_root_resolved = models_root.resolve()
except Exception:
models_root_resolved = models_root
inferred_dirs: list[Path] = []
# 1) From model_path
model_path_value = model.get("model_path")
if model_path_value:
try:
p = Path(str(model_path_value)).expanduser()
p_resolved = p.resolve() if p.exists() else p
if p_resolved.is_file():
if p_resolved.parent.name == "weights" and p_resolved.parent.parent.exists():
inferred_dirs.append(p_resolved.parent.parent)
elif p_resolved.parent.exists():
inferred_dirs.append(p_resolved.parent)
except Exception:
pass
# 2) From training_params.stage_results[].results.save_dir
training_params = model.get("training_params") or {}
if isinstance(training_params, dict):
stage_results = training_params.get("stage_results")
if isinstance(stage_results, list):
for stage in stage_results:
results = (stage or {}).get("results")
save_dir = (results or {}).get("save_dir") if isinstance(results, dict) else None
if not save_dir:
continue
try:
d = Path(str(save_dir)).expanduser()
if d.exists() and d.is_dir():
inferred_dirs.append(d)
except Exception:
continue
# Deduplicate inferred_dirs
unique_dirs: list[Path] = []
seen: set[str] = set()
for d in inferred_dirs:
try:
key = str(d.resolve())
except Exception:
key = str(d)
if key in seen:
continue
seen.add(key)
unique_dirs.append(d)
# Delete directories under models_root
for d in unique_dirs:
try:
d_resolved = d.resolve()
except Exception:
d_resolved = d
try:
if d_resolved.exists() and d_resolved.is_dir() and d_resolved.is_relative_to(models_root_resolved):
shutil.rmtree(d_resolved)
removed.append(str(d_resolved))
except Exception as exc:
errors.append(f"Failed to remove directory {d_resolved}: {exc}")
# If nothing matched (e.g., model_path outside models_root), delete just the file.
if model_path_value:
try:
p = Path(str(model_path_value)).expanduser()
if p.exists() and p.is_file():
p_resolved = p.resolve()
if not p_resolved.is_relative_to(models_root_resolved):
p_resolved.unlink()
removed.append(str(p_resolved))
except Exception as exc:
errors.append(f"Failed to remove model file {model_path_value}: {exc}")
return removed, errors
def _show_about(self): def _show_about(self):
"""Show about dialog.""" """Show about dialog."""
@@ -301,6 +539,11 @@ class MainWindow(QMainWindow):
if hasattr(self, "training_tab"): if hasattr(self, "training_tab"):
self.training_tab.shutdown() self.training_tab.shutdown()
if hasattr(self, "annotation_tab"): if hasattr(self, "annotation_tab"):
# Best-effort refresh so DB-backed UI state is consistent at shutdown.
try:
self.annotation_tab.refresh()
except Exception:
pass
self.annotation_tab.save_state() self.annotation_tab.save_state()
logger.info("Application closing") logger.info("Application closing")

View File

@@ -13,6 +13,11 @@ from PySide6.QtWidgets import (
QFileDialog, QFileDialog,
QMessageBox, QMessageBox,
QSplitter, QSplitter,
QLineEdit,
QTableWidget,
QTableWidgetItem,
QHeaderView,
QAbstractItemView,
) )
from PySide6.QtCore import Qt, QSettings from PySide6.QtCore import Qt, QSettings
from pathlib import Path from pathlib import Path
@@ -29,9 +34,7 @@ logger = get_logger(__name__)
class AnnotationTab(QWidget): class AnnotationTab(QWidget):
"""Annotation tab for manual image annotation.""" """Annotation tab for manual image annotation."""
def __init__( def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
super().__init__(parent) super().__init__(parent)
self.db_manager = db_manager self.db_manager = db_manager
self.config_manager = config_manager self.config_manager = config_manager
@@ -52,6 +55,32 @@ class AnnotationTab(QWidget):
self.main_splitter = QSplitter(Qt.Horizontal) self.main_splitter = QSplitter(Qt.Horizontal)
self.main_splitter.setHandleWidth(10) self.main_splitter.setHandleWidth(10)
# { Left-most pane: annotated images list
annotated_group = QGroupBox("Annotated Images")
annotated_layout = QVBoxLayout()
filter_row = QHBoxLayout()
filter_row.addWidget(QLabel("Filter:"))
self.annotated_filter_edit = QLineEdit()
self.annotated_filter_edit.setPlaceholderText("Type to filter by image name…")
self.annotated_filter_edit.textChanged.connect(self._refresh_annotated_images_list)
filter_row.addWidget(self.annotated_filter_edit, 1)
annotated_layout.addLayout(filter_row)
self.annotated_images_table = QTableWidget(0, 2)
self.annotated_images_table.setHorizontalHeaderLabels(["Image", "Annotations"])
self.annotated_images_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
self.annotated_images_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeToContents)
self.annotated_images_table.setSelectionBehavior(QAbstractItemView.SelectRows)
self.annotated_images_table.setSelectionMode(QAbstractItemView.SingleSelection)
self.annotated_images_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
self.annotated_images_table.setSortingEnabled(True)
self.annotated_images_table.itemSelectionChanged.connect(self._on_annotated_image_selected)
annotated_layout.addWidget(self.annotated_images_table, 1)
annotated_group.setLayout(annotated_layout)
# }
# { Left splitter for image display and zoom info # { Left splitter for image display and zoom info
self.left_splitter = QSplitter(Qt.Vertical) self.left_splitter = QSplitter(Qt.Vertical)
self.left_splitter.setHandleWidth(10) self.left_splitter.setHandleWidth(10)
@@ -62,6 +91,9 @@ class AnnotationTab(QWidget):
# Use the AnnotationCanvasWidget # Use the AnnotationCanvasWidget
self.annotation_canvas = AnnotationCanvasWidget() self.annotation_canvas = AnnotationCanvasWidget()
# Auto-zoom so newly loaded images fill the available canvas viewport.
# (Matches the behavior used in ResultsTab.)
self.annotation_canvas.set_auto_fit_to_view(True)
self.annotation_canvas.zoom_changed.connect(self._on_zoom_changed) self.annotation_canvas.zoom_changed.connect(self._on_zoom_changed)
self.annotation_canvas.annotation_drawn.connect(self._on_annotation_drawn) self.annotation_canvas.annotation_drawn.connect(self._on_annotation_drawn)
# Selection of existing polylines (when tool is not in drawing mode) # Selection of existing polylines (when tool is not in drawing mode)
@@ -72,9 +104,7 @@ class AnnotationTab(QWidget):
self.left_splitter.addWidget(canvas_group) self.left_splitter.addWidget(canvas_group)
# Controls info # Controls info
controls_info = QLabel( controls_info = QLabel("Zoom: Mouse wheel or +/- keys | Drawing: Enable pen and drag mouse")
"Zoom: Mouse wheel or +/- keys | Drawing: Enable pen and drag mouse"
)
controls_info.setStyleSheet("QLabel { color: #888; font-style: italic; }") controls_info.setStyleSheet("QLabel { color: #888; font-style: italic; }")
self.left_splitter.addWidget(controls_info) self.left_splitter.addWidget(controls_info)
# } # }
@@ -85,47 +115,47 @@ class AnnotationTab(QWidget):
# Annotation tools section # Annotation tools section
self.annotation_tools = AnnotationToolsWidget(self.db_manager) self.annotation_tools = AnnotationToolsWidget(self.db_manager)
self.annotation_tools.polyline_enabled_changed.connect( self.annotation_tools.polyline_enabled_changed.connect(self.annotation_canvas.set_polyline_enabled)
self.annotation_canvas.set_polyline_enabled self.annotation_tools.polyline_pen_color_changed.connect(self.annotation_canvas.set_polyline_pen_color)
) self.annotation_tools.polyline_pen_width_changed.connect(self.annotation_canvas.set_polyline_pen_width)
self.annotation_tools.polyline_pen_color_changed.connect(
self.annotation_canvas.set_polyline_pen_color
)
self.annotation_tools.polyline_pen_width_changed.connect(
self.annotation_canvas.set_polyline_pen_width
)
# Show / hide bounding boxes # Show / hide bounding boxes
self.annotation_tools.show_bboxes_changed.connect( self.annotation_tools.show_bboxes_changed.connect(self.annotation_canvas.set_show_bboxes)
self.annotation_canvas.set_show_bboxes
)
# RDP simplification controls # RDP simplification controls
self.annotation_tools.simplify_on_finish_changed.connect( self.annotation_tools.simplify_on_finish_changed.connect(self._on_simplify_on_finish_changed)
self._on_simplify_on_finish_changed self.annotation_tools.simplify_epsilon_changed.connect(self._on_simplify_epsilon_changed)
)
self.annotation_tools.simplify_epsilon_changed.connect(
self._on_simplify_epsilon_changed
)
# Class selection and class-color changes # Class selection and class-color changes
self.annotation_tools.class_selected.connect(self._on_class_selected) self.annotation_tools.class_selected.connect(self._on_class_selected)
self.annotation_tools.class_color_changed.connect(self._on_class_color_changed) self.annotation_tools.class_color_changed.connect(self._on_class_color_changed)
self.annotation_tools.clear_annotations_requested.connect( self.annotation_tools.clear_annotations_requested.connect(self._on_clear_annotations)
self._on_clear_annotations
)
# Delete selected annotation on canvas # Delete selected annotation on canvas
self.annotation_tools.delete_selected_annotation_requested.connect( self.annotation_tools.delete_selected_annotation_requested.connect(self._on_delete_selected_annotation)
self._on_delete_selected_annotation
)
self.right_splitter.addWidget(self.annotation_tools) self.right_splitter.addWidget(self.annotation_tools)
# Image loading section # Image loading section
load_group = QGroupBox("Image Loading") load_group = QGroupBox("Image Loading")
load_layout = QVBoxLayout() load_layout = QVBoxLayout()
# Load image button # Buttons row
button_layout = QHBoxLayout() button_layout = QHBoxLayout()
self.load_image_btn = QPushButton("Load Image") self.load_image_btn = QPushButton("Load Image")
self.load_image_btn.clicked.connect(self._load_image) self.load_image_btn.clicked.connect(self._load_image)
button_layout.addWidget(self.load_image_btn) button_layout.addWidget(self.load_image_btn)
self.import_images_btn = QPushButton("Import Images")
self.import_images_btn.setToolTip(
"Import one or more images into the database.\n" "Images already present in the DB are skipped."
)
self.import_images_btn.clicked.connect(self._import_images)
button_layout.addWidget(self.import_images_btn)
self.import_annotations_btn = QPushButton("Import Annotations")
self.import_annotations_btn.setToolTip(
"Import YOLO .txt annotation files and register them with their corresponding images.\n"
"Existing annotations for those images will be overwritten."
)
self.import_annotations_btn.clicked.connect(self._import_annotations)
button_layout.addWidget(self.import_annotations_btn)
button_layout.addStretch() button_layout.addStretch()
load_layout.addLayout(button_layout) load_layout.addLayout(button_layout)
@@ -137,12 +167,13 @@ class AnnotationTab(QWidget):
self.right_splitter.addWidget(load_group) self.right_splitter.addWidget(load_group)
# } # }
# Add both splitters to the main horizontal splitter # Add list + both splitters to the main horizontal splitter
self.main_splitter.addWidget(annotated_group)
self.main_splitter.addWidget(self.left_splitter) self.main_splitter.addWidget(self.left_splitter)
self.main_splitter.addWidget(self.right_splitter) self.main_splitter.addWidget(self.right_splitter)
# Set initial sizes: 75% for left (image), 25% for right (controls) # Set initial sizes: list (left), canvas (middle), controls (right)
self.main_splitter.setSizes([750, 250]) self.main_splitter.setSizes([320, 650, 280])
layout.addWidget(self.main_splitter) layout.addWidget(self.main_splitter)
self.setLayout(layout) self.setLayout(layout)
@@ -150,6 +181,375 @@ class AnnotationTab(QWidget):
# Restore splitter positions from settings # Restore splitter positions from settings
self._restore_state() self._restore_state()
# Populate list on startup.
self._refresh_annotated_images_list()
def _import_images(self) -> None:
"""Import one or more images into the database and refresh the list."""
settings = QSettings("microscopy_app", "object_detection")
last_dir = settings.value("annotation_tab/last_image_import_directory", None)
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
if last_dir and Path(str(last_dir)).exists():
start_dir = str(last_dir)
elif repo_root and Path(repo_root).exists():
start_dir = repo_root
else:
start_dir = str(Path.home())
# Build filter string for supported extensions
patterns = " ".join(f"*{ext}" for ext in Image.SUPPORTED_EXTENSIONS)
file_paths, _ = QFileDialog.getOpenFileNames(
self,
"Select Image File(s)",
start_dir,
f"Images ({patterns})",
)
if not file_paths:
return
try:
settings.setValue("annotation_tab/last_image_import_directory", str(Path(file_paths[0]).parent))
# Keep compatibility with the existing image resolver fallback (it checks last_directory).
settings.setValue("annotation_tab/last_directory", str(Path(file_paths[0]).parent))
except Exception:
pass
imported = 0
tagged_existing = 0
skipped = 0
errors: list[str] = []
for fp in file_paths:
try:
img_path = Path(fp)
img = Image(str(img_path))
relative_path = self._compute_relative_path_for_repo(img_path)
# Skip if already present
existing = self.db_manager.get_image_by_path(relative_path)
if existing:
# If the image already exists (e.g. created earlier by other workflows),
# tag it as being managed by the Annotation tab so it becomes visible
# in the left list.
try:
self.db_manager.set_image_source(int(existing["id"]), "annotation_tab")
tagged_existing += 1
except Exception:
# If tagging fails, fall back to treating as skipped.
skipped += 1
continue
image_id = self.db_manager.add_image(
relative_path,
img_path.name,
img.width,
img.height,
source="annotation_tab",
)
try:
# In case the DB row was created by an older schema/migration path.
self.db_manager.set_image_source(image_id, "annotation_tab")
except Exception:
pass
imported += 1
except ImageLoadError as exc:
skipped += 1
errors.append(f"Failed to load image {fp}: {exc}")
except Exception as exc:
skipped += 1
errors.append(f"Failed to import image {fp}: {exc}")
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
msg = (
f"Imported: {imported}\n"
f"Already in DB (tagged for Annotation tab): {tagged_existing}\n"
f"Skipped (errors): {skipped}"
)
if errors:
details = "\n".join(errors[:25])
if len(errors) > 25:
details += f"\n... and {len(errors) - 25} more"
msg += "\n\nDetails:\n" + details
QMessageBox.information(self, "Import Images", msg)
# ==================== Import annotations (YOLO .txt) ====================
def _import_annotations(self) -> None:
"""Import YOLO segmentation/bbox annotations from one or more .txt files."""
settings = QSettings("microscopy_app", "object_detection")
last_dir = settings.value("annotation_tab/last_annotation_directory", None)
# Default start dir: repo root if set, otherwise last used, otherwise home.
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
if last_dir and Path(str(last_dir)).exists():
start_dir = str(last_dir)
elif repo_root and Path(repo_root).exists():
start_dir = repo_root
else:
start_dir = str(Path.home())
file_paths, _ = QFileDialog.getOpenFileNames(
self,
"Select YOLO Annotation File(s)",
start_dir,
"YOLO annotations (*.txt)",
)
if not file_paths:
return
# Persist last annotation directory for the next import.
try:
settings.setValue("annotation_tab/last_annotation_directory", str(Path(file_paths[0]).parent))
except Exception:
pass
imported_images = 0
imported_annotations = 0
overwritten_images = 0
skipped = 0
errors: list[str] = []
for label_file in file_paths:
label_path = Path(label_file)
try:
image_path = self._infer_corresponding_image_path(label_path)
if not image_path:
skipped += 1
errors.append(f"Image not found for label file: {label_path}")
continue
# Load image to obtain width/height for DB entry.
img = Image(str(image_path))
# Store in DB using a repo-relative path if possible.
relative_path = self._compute_relative_path_for_repo(image_path)
image_id = self.db_manager.get_or_create_image(relative_path, image_path.name, img.width, img.height)
try:
self.db_manager.set_image_source(image_id, "annotation_tab")
except Exception:
pass
# Overwrite existing annotations for this image.
try:
deleted = self.db_manager.delete_annotations_for_image(image_id)
except AttributeError:
# Safety fallback if older DBManager is used.
deleted = 0
if deleted > 0:
overwritten_images += 1
# Parse YOLO lines and insert as annotations.
parsed = self._parse_yolo_annotation_file(label_path)
if not parsed:
# Empty/invalid label file: treat as "clear" operation (already deleted above)
imported_images += 1
continue
db_classes = self.db_manager.get_object_classes() or []
classes_by_index = {idx: row for idx, row in enumerate(db_classes)}
for class_index, bbox, poly in parsed:
class_row = classes_by_index.get(int(class_index))
if not class_row:
skipped += 1
errors.append(
f"Unknown class index {class_index} in {label_path.name}. "
"Create object classes in the UI first (class index is based on DB ordering)."
)
continue
ann_id = self.db_manager.add_annotation(
image_id=image_id,
class_id=int(class_row["id"]),
bbox=bbox,
annotator="import",
segmentation_mask=poly,
verified=False,
)
if ann_id:
imported_annotations += 1
imported_images += 1
# If we imported for the currently open image, reload.
if self.current_image_id and int(self.current_image_id) == int(image_id):
self._load_annotations_for_current_image()
except ImageLoadError as exc:
skipped += 1
errors.append(f"Failed to load image for {label_path.name}: {exc}")
except Exception as exc:
skipped += 1
errors.append(f"Import failed for {label_path.name}: {exc}")
# Refresh annotated images list.
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
summary = (
f"Imported files: {len(file_paths)}\n"
f"Images processed: {imported_images}\n"
f"Annotations inserted: {imported_annotations}\n"
f"Images overwritten (had existing annotations): {overwritten_images}\n"
f"Skipped: {skipped}"
)
if errors:
# Cap error details to avoid huge dialogs.
details = "\n".join(errors[:25])
if len(errors) > 25:
details += f"\n... and {len(errors) - 25} more"
summary += "\n\nDetails:\n" + details
QMessageBox.information(self, "Import Annotations", summary)
def _infer_corresponding_image_path(self, label_path: Path) -> Path | None:
"""Infer image path from YOLO label file path.
Requirement: image(s) live in an `images/` folder located in the label file's parent directory.
Example:
/dataset/train/labels/img123.txt -> /dataset/train/images/img123.(any supported ext)
"""
parent = label_path.parent
images_dir = parent.parent / "images"
stem = label_path.stem
# 1) Direct stem match in images dir (any supported extension)
for ext in Image.SUPPORTED_EXTENSIONS:
candidate = images_dir / f"{stem}{ext}"
if candidate.exists() and candidate.is_file():
return candidate
# 2) Fallback: repository-root search by filename
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
if repo_root:
root = Path(repo_root).expanduser()
try:
if root.exists():
for ext in Image.SUPPORTED_EXTENSIONS:
filename = f"{stem}{ext}"
for match in root.rglob(filename):
if match.is_file():
return match.resolve()
except Exception:
pass
return None
def _compute_relative_path_for_repo(self, image_path: Path) -> str:
"""Compute a stable `relative_path` suitable for DB storage.
Policy:
- If an image repository root is configured and the image is under it, store a repo-relative path.
- Otherwise, store an absolute resolved path so the image can be reopened later.
"""
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
try:
if repo_root:
repo_root_path = Path(repo_root).expanduser().resolve()
img_resolved = image_path.expanduser().resolve()
if img_resolved.is_relative_to(repo_root_path):
return img_resolved.relative_to(repo_root_path).as_posix()
except Exception:
pass
try:
return str(image_path.expanduser().resolve())
except Exception:
return str(image_path)
def _parse_yolo_annotation_file(
self, label_path: Path
) -> list[tuple[int, tuple[float, float, float, float], list[list[float]] | None]]:
"""Parse a YOLO .txt label file.
Supports:
- YOLO segmentation polygon format: "class x1 y1 x2 y2 ..." (normalized)
- YOLO bbox format: "class x_center y_center width height" (normalized)
Returns:
List of (class_index, bbox_xyxy_norm, segmentation_mask_db)
Where segmentation_mask_db is [[y_norm, x_norm], ...] or None.
"""
out: list[tuple[int, tuple[float, float, float, float], list[list[float]] | None]] = []
try:
raw = label_path.read_text(encoding="utf-8").splitlines()
except OSError as exc:
logger.error(f"Failed to read label file {label_path}: {exc}")
return out
for line in raw:
stripped = line.strip()
if not stripped:
continue
parts = stripped.split()
if len(parts) < 5:
# not enough for bbox
continue
try:
class_idx = int(float(parts[0]))
coords = [float(x) for x in parts[1:]]
except Exception:
continue
# Segmentation polygon format (>= 6 values)
if len(coords) >= 6:
# bbox is not explicitly present in this format in our importer; compute from polygon.
xs = coords[0::2]
ys = coords[1::2]
if not xs or not ys:
continue
x_min, x_max = min(xs), max(xs)
y_min, y_max = min(ys), max(ys)
bbox = (
self._clamp01(x_min),
self._clamp01(y_min),
self._clamp01(x_max),
self._clamp01(y_max),
)
# Convert to DB polyline convention: [[y_norm, x_norm], ...]
poly: list[list[float]] = []
for x, y in zip(xs, ys):
poly.append([self._clamp01(float(y)), self._clamp01(float(x))])
# Ensure closure for consistency (optional)
if poly and poly[0] != poly[-1]:
poly.append(list(poly[0]))
out.append((class_idx, bbox, poly))
continue
# bbox format: xc yc w h
if len(coords) >= 4:
xc, yc, w, h = coords[:4]
x_min = xc - w / 2.0
y_min = yc - h / 2.0
x_max = xc + w / 2.0
y_max = yc + h / 2.0
bbox = (
self._clamp01(float(x_min)),
self._clamp01(float(y_min)),
self._clamp01(float(x_max)),
self._clamp01(float(y_max)),
)
out.append((class_idx, bbox, None))
return out
@staticmethod
def _clamp01(value: float) -> float:
if value < 0.0:
return 0.0
if value > 1.0:
return 1.0
return float(value)
def _load_image(self): def _load_image(self):
"""Load and display an image file.""" """Load and display an image file."""
# Get last opened directory from QSettings # Get last opened directory from QSettings
@@ -180,18 +580,35 @@ class AnnotationTab(QWidget):
self.current_image_path = file_path self.current_image_path = file_path
# Store the directory for next time # Store the directory for next time
settings.setValue( settings.setValue("annotation_tab/last_directory", str(Path(file_path).parent))
"annotation_tab/last_directory", str(Path(file_path).parent)
)
# Get or create image in database # Get or create image in database
relative_path = str(Path(file_path).name) # Simplified for now repo_root = self.config_manager.get_image_repository_path()
relative_path: str
try:
if repo_root:
repo_root_path = Path(repo_root).expanduser().resolve()
file_resolved = Path(file_path).expanduser().resolve()
if file_resolved.is_relative_to(repo_root_path):
relative_path = file_resolved.relative_to(repo_root_path).as_posix()
else:
# Fallback: store filename only to avoid leaking absolute paths.
relative_path = file_resolved.name
else:
relative_path = str(Path(file_path).name)
except Exception:
relative_path = str(Path(file_path).name)
self.current_image_id = self.db_manager.get_or_create_image( self.current_image_id = self.db_manager.get_or_create_image(
relative_path, relative_path,
Path(file_path).name, Path(file_path).name,
self.current_image.width, self.current_image.width,
self.current_image.height, self.current_image.height,
) )
# Mark as managed by Annotation tab so it appears in the left list.
try:
self.db_manager.set_image_source(int(self.current_image_id), "annotation_tab")
except Exception:
pass
# Display image using the AnnotationCanvasWidget # Display image using the AnnotationCanvasWidget
self.annotation_canvas.load_image(self.current_image) self.annotation_canvas.load_image(self.current_image)
@@ -199,6 +616,9 @@ class AnnotationTab(QWidget):
# Load and display any existing annotations for this image # Load and display any existing annotations for this image
self._load_annotations_for_current_image() self._load_annotations_for_current_image()
# Update annotated images list (newly annotated image added/selected).
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
# Update info label # Update info label
self._update_image_info() self._update_image_info()
@@ -206,9 +626,7 @@ class AnnotationTab(QWidget):
except ImageLoadError as e: except ImageLoadError as e:
logger.error(f"Failed to load image: {e}") logger.error(f"Failed to load image: {e}")
QMessageBox.critical( QMessageBox.critical(self, "Error Loading Image", f"Failed to load image:\n{str(e)}")
self, "Error Loading Image", f"Failed to load image:\n{str(e)}"
)
except Exception as e: except Exception as e:
logger.error(f"Unexpected error loading image: {e}") logger.error(f"Unexpected error loading image: {e}")
QMessageBox.critical(self, "Error", f"Unexpected error:\n{str(e)}") QMessageBox.critical(self, "Error", f"Unexpected error:\n{str(e)}")
@@ -296,6 +714,9 @@ class AnnotationTab(QWidget):
# Reload annotations from DB and redraw (respecting current class filter) # Reload annotations from DB and redraw (respecting current class filter)
self._load_annotations_for_current_image() self._load_annotations_for_current_image()
# Update list counts.
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
except Exception as e: except Exception as e:
logger.error(f"Failed to save annotation: {e}") logger.error(f"Failed to save annotation: {e}")
QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}") QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}")
@@ -340,9 +761,7 @@ class AnnotationTab(QWidget):
if not self.current_image_id: if not self.current_image_id:
return return
logger.debug( logger.debug(f"Class color changed; reloading annotations for image ID {self.current_image_id}")
f"Class color changed; reloading annotations for image ID {self.current_image_id}"
)
self._load_annotations_for_current_image() self._load_annotations_for_current_image()
def _on_class_selected(self, class_data): def _on_class_selected(self, class_data):
@@ -355,9 +774,7 @@ class AnnotationTab(QWidget):
if class_data: if class_data:
logger.debug(f"Object class selected: {class_data['class_name']}") logger.debug(f"Object class selected: {class_data['class_name']}")
else: else:
logger.debug( logger.debug('No class selected ("-- Select Class --"), showing all annotations')
'No class selected ("-- Select Class --"), showing all annotations'
)
# Changing the class filter invalidates any previous selection # Changing the class filter invalidates any previous selection
self.selected_annotation_ids = [] self.selected_annotation_ids = []
@@ -390,9 +807,7 @@ class AnnotationTab(QWidget):
question = "Are you sure you want to delete the selected annotation?" question = "Are you sure you want to delete the selected annotation?"
title = "Delete Annotation" title = "Delete Annotation"
else: else:
question = ( question = f"Are you sure you want to delete the {count} selected annotations?"
f"Are you sure you want to delete the {count} selected annotations?"
)
title = "Delete Annotations" title = "Delete Annotations"
reply = QMessageBox.question( reply = QMessageBox.question(
@@ -420,13 +835,11 @@ class AnnotationTab(QWidget):
QMessageBox.warning( QMessageBox.warning(
self, self,
"Partial Failure", "Partial Failure",
"Some annotations could not be deleted:\n" "Some annotations could not be deleted:\n" + ", ".join(str(a) for a in failed_ids),
+ ", ".join(str(a) for a in failed_ids),
) )
else: else:
logger.info( logger.info(
f"Deleted {count} annotation(s): " f"Deleted {count} annotation(s): " + ", ".join(str(a) for a in self.selected_annotation_ids)
+ ", ".join(str(a) for a in self.selected_annotation_ids)
) )
# Clear selection and reload annotations for the current image from DB # Clear selection and reload annotations for the current image from DB
@@ -434,6 +847,9 @@ class AnnotationTab(QWidget):
self.annotation_tools.set_has_selected_annotation(False) self.annotation_tools.set_has_selected_annotation(False)
self._load_annotations_for_current_image() self._load_annotations_for_current_image()
# Update list counts.
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
except Exception as e: except Exception as e:
logger.error(f"Failed to delete annotations: {e}") logger.error(f"Failed to delete annotations: {e}")
QMessageBox.critical( QMessageBox.critical(
@@ -456,17 +872,13 @@ class AnnotationTab(QWidget):
return return
try: try:
self.current_annotations = self.db_manager.get_annotations_for_image( self.current_annotations = self.db_manager.get_annotations_for_image(self.current_image_id)
self.current_image_id
)
# New annotations loaded; reset any selection # New annotations loaded; reset any selection
self.selected_annotation_ids = [] self.selected_annotation_ids = []
self.annotation_tools.set_has_selected_annotation(False) self.annotation_tools.set_has_selected_annotation(False)
self._redraw_annotations_for_current_filter() self._redraw_annotations_for_current_filter()
except Exception as e: except Exception as e:
logger.error( logger.error(f"Failed to load annotations for image {self.current_image_id}: {e}")
f"Failed to load annotations for image {self.current_image_id}: {e}"
)
QMessageBox.critical( QMessageBox.critical(
self, self,
"Error", "Error",
@@ -490,10 +902,7 @@ class AnnotationTab(QWidget):
drawn_count = 0 drawn_count = 0
for ann in self.current_annotations: for ann in self.current_annotations:
# Filter by class if one is selected # Filter by class if one is selected
if ( if selected_class_id is not None and ann.get("class_id") != selected_class_id:
selected_class_id is not None
and ann.get("class_id") != selected_class_id
):
continue continue
if ann.get("segmentation_mask"): if ann.get("segmentation_mask"):
@@ -545,22 +954,185 @@ class AnnotationTab(QWidget):
settings = QSettings("microscopy_app", "object_detection") settings = QSettings("microscopy_app", "object_detection")
# Save main splitter state # Save main splitter state
settings.setValue( settings.setValue("annotation_tab/main_splitter_state", self.main_splitter.saveState())
"annotation_tab/main_splitter_state", self.main_splitter.saveState()
)
# Save left splitter state # Save left splitter state
settings.setValue( settings.setValue("annotation_tab/left_splitter_state", self.left_splitter.saveState())
"annotation_tab/left_splitter_state", self.left_splitter.saveState()
)
# Save right splitter state # Save right splitter state
settings.setValue( settings.setValue("annotation_tab/right_splitter_state", self.right_splitter.saveState())
"annotation_tab/right_splitter_state", self.right_splitter.saveState()
)
logger.debug("Saved annotation tab splitter states") logger.debug("Saved annotation tab splitter states")
def refresh(self): def refresh(self):
"""Refresh the tab.""" """Refresh the tab."""
pass self._refresh_annotated_images_list(select_image_id=self.current_image_id)
# ==================== Annotated images list ====================
def _refresh_annotated_images_list(self, select_image_id: int | None = None) -> None:
"""Reload annotated-images list from the database."""
if not hasattr(self, "annotated_images_table"):
return
# Preserve selection if possible
desired_id = select_image_id if select_image_id is not None else self.current_image_id
name_filter = ""
if hasattr(self, "annotated_filter_edit"):
name_filter = self.annotated_filter_edit.text().strip()
try:
rows = self.db_manager.get_images_summary(name_filter=name_filter, source_filter="annotation_tab")
except Exception as exc:
logger.error(f"Failed to load images summary: {exc}")
rows = []
sorting_enabled = self.annotated_images_table.isSortingEnabled()
self.annotated_images_table.setSortingEnabled(False)
self.annotated_images_table.blockSignals(True)
try:
self.annotated_images_table.setRowCount(len(rows))
for r, entry in enumerate(rows):
image_name = str(entry.get("filename") or "")
count = int(entry.get("annotation_count") or 0)
rel_path = str(entry.get("relative_path") or "")
name_item = QTableWidgetItem(image_name)
# Tooltip shows full path of the image (best-effort: repository_root + relative_path)
full_path = rel_path
repo_root = self.config_manager.get_image_repository_path()
if repo_root and rel_path and not Path(rel_path).is_absolute():
try:
full_path = str((Path(repo_root) / rel_path).resolve())
except Exception:
full_path = str(Path(repo_root) / rel_path)
name_item.setToolTip(full_path)
name_item.setData(Qt.UserRole, int(entry.get("id")))
name_item.setData(Qt.UserRole + 1, rel_path)
count_item = QTableWidgetItem()
# Use EditRole to ensure numeric sorting.
count_item.setData(Qt.EditRole, count)
count_item.setData(Qt.UserRole, int(entry.get("id")))
count_item.setData(Qt.UserRole + 1, rel_path)
self.annotated_images_table.setItem(r, 0, name_item)
self.annotated_images_table.setItem(r, 1, count_item)
# Re-select desired row
if desired_id is not None:
for r in range(self.annotated_images_table.rowCount()):
item = self.annotated_images_table.item(r, 0)
if item and item.data(Qt.UserRole) == desired_id:
self.annotated_images_table.selectRow(r)
break
finally:
self.annotated_images_table.blockSignals(False)
self.annotated_images_table.setSortingEnabled(sorting_enabled)
def _on_annotated_image_selected(self) -> None:
"""When user clicks an item in the list, load that image in the annotation canvas."""
selected = self.annotated_images_table.selectedItems()
if not selected:
return
# Row selection -> take the first column item
row = self.annotated_images_table.currentRow()
item = self.annotated_images_table.item(row, 0)
if not item:
return
image_id = item.data(Qt.UserRole)
rel_path = item.data(Qt.UserRole + 1) or ""
if not image_id:
return
image_path = self._resolve_image_path_for_relative_path(rel_path)
if not image_path:
QMessageBox.warning(
self,
"Image Not Found",
"Unable to locate image on disk for:\n"
f"{rel_path}\n\n"
"Tip: set Settings → Image repository path to the folder containing your images.",
)
return
try:
self.current_image = Image(image_path)
self.current_image_path = image_path
self.current_image_id = int(image_id)
self.annotation_canvas.load_image(self.current_image)
self._load_annotations_for_current_image()
self._update_image_info()
except ImageLoadError as exc:
logger.error(f"Failed to load image '{image_path}': {exc}")
QMessageBox.critical(self, "Error Loading Image", f"Failed to load image:\n{exc}")
except Exception as exc:
logger.error(f"Unexpected error loading image '{image_path}': {exc}")
QMessageBox.critical(self, "Error", f"Unexpected error:\n{exc}")
def _resolve_image_path_for_relative_path(self, relative_path: str) -> str | None:
"""Best-effort conversion from a DB relative_path to an on-disk file path."""
rel = (relative_path or "").strip()
if not rel:
return None
candidates: list[Path] = []
# 1) Repository root + relative
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
if repo_root:
candidates.append(Path(repo_root) / rel)
# 2) If the DB path is absolute, try it directly.
candidates.append(Path(rel))
# 3) Try the directory of the currently loaded image (helps when DB stores only filenames)
if self.current_image_path:
try:
candidates.append(Path(self.current_image_path).expanduser().resolve().parent / Path(rel).name)
except Exception:
pass
# 4) Try the last directory used by the annotation file picker
try:
settings = QSettings("microscopy_app", "object_detection")
last_dir = settings.value("annotation_tab/last_directory", None)
if last_dir:
candidates.append(Path(str(last_dir)) / Path(rel).name)
except Exception:
pass
# 4b) Try the last directory used by the image import picker
try:
settings = QSettings("microscopy_app", "object_detection")
last_import_dir = settings.value("annotation_tab/last_image_import_directory", None)
if last_import_dir:
candidates.append(Path(str(last_import_dir)) / Path(rel).name)
except Exception:
pass
for p in candidates:
try:
expanded = p.expanduser()
if expanded.exists() and expanded.is_file():
return str(expanded.resolve())
except Exception:
continue
# 5) Fallback: search by filename within repository root.
filename = Path(rel).name
if repo_root and filename:
root = Path(repo_root).expanduser()
try:
if root.exists():
for match in root.rglob(filename):
if match.is_file():
return str(match.resolve())
except Exception as exc:
logger.debug(f"Search for {filename} under {root} failed: {exc}")
return None

View File

@@ -40,6 +40,12 @@ class ResultsTab(QWidget):
self.db_manager = db_manager self.db_manager = db_manager
self.config_manager = config_manager self.config_manager = config_manager
# Pagination
self.page_size = 200
self.current_page = 0 # 0-based
self.total_runs = 0
self.total_pages = 0
self.detection_summary: List[Dict] = [] self.detection_summary: List[Dict] = []
self.current_selection: Optional[Dict] = None self.current_selection: Optional[Dict] = None
self.current_image: Optional[Image] = None self.current_image: Optional[Image] = None
@@ -66,6 +72,20 @@ class ResultsTab(QWidget):
self.refresh_btn.clicked.connect(self.refresh) self.refresh_btn.clicked.connect(self.refresh)
controls_layout.addWidget(self.refresh_btn) controls_layout.addWidget(self.refresh_btn)
self.prev_page_btn = QPushButton("◀ Prev")
self.prev_page_btn.setToolTip("Previous page")
self.prev_page_btn.clicked.connect(self._prev_page)
controls_layout.addWidget(self.prev_page_btn)
self.next_page_btn = QPushButton("Next ▶")
self.next_page_btn.setToolTip("Next page")
self.next_page_btn.clicked.connect(self._next_page)
controls_layout.addWidget(self.next_page_btn)
self.page_label = QLabel("Page 1/1")
self.page_label.setMinimumWidth(100)
controls_layout.addWidget(self.page_label)
self.delete_all_btn = QPushButton("Delete All Detections") self.delete_all_btn = QPushButton("Delete All Detections")
self.delete_all_btn.setToolTip( self.delete_all_btn.setToolTip(
"Permanently delete ALL detections from the database.\n" "This cannot be undone." "Permanently delete ALL detections from the database.\n" "This cannot be undone."
@@ -183,6 +203,8 @@ class ResultsTab(QWidget):
def refresh(self): def refresh(self):
"""Refresh the detection list and preview.""" """Refresh the detection list and preview."""
# Reset to first page on refresh.
self.current_page = 0
self._load_detection_summary() self._load_detection_summary()
self._populate_results_table() self._populate_results_table()
self.current_selection = None self.current_selection = None
@@ -193,57 +215,135 @@ class ResultsTab(QWidget):
if hasattr(self, "export_labels_btn"): if hasattr(self, "export_labels_btn"):
self.export_labels_btn.setEnabled(False) self.export_labels_btn.setEnabled(False)
self._update_pagination_controls()
def _prev_page(self):
"""Go to previous results page."""
if self.current_page <= 0:
return
self.current_page -= 1
self._load_detection_summary()
self._populate_results_table()
self._update_pagination_controls()
def _next_page(self):
"""Go to next results page."""
if self.total_pages and self.current_page >= (self.total_pages - 1):
return
self.current_page += 1
self._load_detection_summary()
self._populate_results_table()
self._update_pagination_controls()
def _update_pagination_controls(self):
"""Update pagination label/button enabled state."""
# Default state (safe)
total_pages = max(int(self.total_pages or 0), 1)
current_page = min(max(int(self.current_page or 0), 0), total_pages - 1)
self.current_page = current_page
self.page_label.setText(f"Page {current_page + 1}/{total_pages}")
self.prev_page_btn.setEnabled(current_page > 0)
self.next_page_btn.setEnabled(current_page < (total_pages - 1))
def _load_detection_summary(self): def _load_detection_summary(self):
"""Load latest detection summaries grouped by image + model.""" """Load latest detection summaries grouped by image + model."""
try: try:
detections = self.db_manager.get_detections(limit=500) # Prefer run summaries (supports zero-detection runs). Fall back to legacy
summary_map: Dict[tuple, Dict] = {} # detection aggregation if the DB/table isn't available.
try:
self.total_runs = int(self.db_manager.get_detection_run_total())
except Exception:
self.total_runs = 0
for det in detections: self.total_pages = (self.total_runs + self.page_size - 1) // self.page_size if self.total_runs > 0 else 1
key = (det["image_id"], det["model_id"])
metadata = det.get("metadata") or {} offset = int(self.current_page) * int(self.page_size)
entry = summary_map.setdefault( runs = self.db_manager.get_detection_run_summaries(
key, limit=int(self.page_size),
offset=offset,
)
summary: List[Dict] = []
for run in runs:
meta = run.get("metadata") or {}
classes_raw = run.get("classes")
classes = set()
if isinstance(classes_raw, str) and classes_raw.strip():
classes = {c.strip() for c in classes_raw.split(",") if c.strip()}
summary.append(
{ {
"image_id": det["image_id"], "image_id": run.get("image_id"),
"model_id": det["model_id"], "model_id": run.get("model_id"),
"image_path": det.get("image_path"), "image_path": run.get("image_path"),
"image_filename": det.get("image_filename") or det.get("image_path"), "image_filename": run.get("image_filename") or run.get("image_path"),
"model_name": det.get("model_name", ""), "model_name": run.get("model_name", ""),
"model_version": det.get("model_version", ""), "model_version": run.get("model_version", ""),
"last_detected": det.get("detected_at"), "last_detected": run.get("detected_at"),
"count": 0, "count": int(run.get("count") or 0),
"classes": set(), "classes": classes,
"source_path": metadata.get("source_path"), "source_path": meta.get("source_path"),
"repository_root": metadata.get("repository_root"), "repository_root": meta.get("repository_root"),
}, }
) )
entry["count"] += 1 self.detection_summary = summary
if det.get("detected_at") and (
not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
):
entry["last_detected"] = det.get("detected_at")
if det.get("class_name"):
entry["classes"].add(det["class_name"])
if metadata.get("source_path") and not entry.get("source_path"):
entry["source_path"] = metadata.get("source_path")
if metadata.get("repository_root") and not entry.get("repository_root"):
entry["repository_root"] = metadata.get("repository_root")
self.detection_summary = sorted(
summary_map.values(),
key=lambda x: str(x.get("last_detected") or ""),
reverse=True,
)
except Exception as e: except Exception as e:
logger.error(f"Failed to load detection summary: {e}") logger.error(f"Failed to load detection run summaries, falling back: {e}")
QMessageBox.critical( # Disable pagination if we can't page via detection_runs.
self, self.total_runs = 0
"Error", self.total_pages = 1
f"Failed to load detection results:\n{str(e)}", self.current_page = 0
) try:
self.detection_summary = [] detections = self.db_manager.get_detections(limit=int(self.page_size))
summary_map: Dict[tuple, Dict] = {}
for det in detections:
key = (det["image_id"], det["model_id"])
metadata = det.get("metadata") or {}
entry = summary_map.setdefault(
key,
{
"image_id": det["image_id"],
"model_id": det["model_id"],
"image_path": det.get("image_path"),
"image_filename": det.get("image_filename") or det.get("image_path"),
"model_name": det.get("model_name", ""),
"model_version": det.get("model_version", ""),
"last_detected": det.get("detected_at"),
"count": 0,
"classes": set(),
"source_path": metadata.get("source_path"),
"repository_root": metadata.get("repository_root"),
},
)
entry["count"] += 1
if det.get("detected_at") and (
not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
):
entry["last_detected"] = det.get("detected_at")
if det.get("class_name"):
entry["classes"].add(det["class_name"])
if metadata.get("source_path") and not entry.get("source_path"):
entry["source_path"] = metadata.get("source_path")
if metadata.get("repository_root") and not entry.get("repository_root"):
entry["repository_root"] = metadata.get("repository_root")
self.detection_summary = sorted(
summary_map.values(),
key=lambda x: str(x.get("last_detected") or ""),
reverse=True,
)
except Exception as inner:
logger.error(f"Failed to load detection summary: {inner}")
QMessageBox.critical(
self,
"Error",
f"Failed to load detection results:\n{str(inner)}",
)
self.detection_summary = []
def _populate_results_table(self): def _populate_results_table(self):
"""Populate the table widget with detection summaries.""" """Populate the table widget with detection summaries."""

View File

@@ -905,9 +905,14 @@ class TrainingTab(QWidget):
if stats["registered_images"]: if stats["registered_images"]:
message += f" {stats['registered_images']} image(s) had database-backed annotations." message += f" {stats['registered_images']} image(s) had database-backed annotations."
if stats["missing_records"]: if stats["missing_records"]:
message += ( preserved = stats.get("preserved_existing_labels", 0)
f" {stats['missing_records']} image(s) had no database entry; empty label files were written." if preserved:
) message += (
f" {stats['missing_records']} image(s) had no database annotations; "
f"preserved {preserved} existing label file(s) (no overwrite)."
)
else:
message += f" {stats['missing_records']} image(s) had no database annotations; empty label files were written."
split_messages.append(message) split_messages.append(message)
for msg in split_messages: for msg in split_messages:
@@ -929,6 +934,7 @@ class TrainingTab(QWidget):
processed_images = 0 processed_images = 0
registered_images = 0 registered_images = 0
missing_records = 0 missing_records = 0
preserved_existing_labels = 0
total_annotations = 0 total_annotations = 0
for image_file in images_dir.rglob("*"): for image_file in images_dir.rglob("*"):
@@ -950,6 +956,12 @@ class TrainingTab(QWidget):
else: else:
missing_records += 1 missing_records += 1
# If the database has no entry for this image, do not overwrite an existing label file
# with an empty one (preserve any manually created labels on disk).
if not found and label_path.exists():
preserved_existing_labels += 1
continue
annotations_written = 0 annotations_written = 0
with open(label_path, "w", encoding="utf-8") as handle: with open(label_path, "w", encoding="utf-8") as handle:
for entry in annotation_entries: for entry in annotation_entries:
@@ -979,6 +991,7 @@ class TrainingTab(QWidget):
"processed_images": processed_images, "processed_images": processed_images,
"registered_images": registered_images, "registered_images": registered_images,
"missing_records": missing_records, "missing_records": missing_records,
"preserved_existing_labels": preserved_existing_labels,
"total_annotations": total_annotations, "total_annotations": total_annotations,
} }
@@ -1008,6 +1021,10 @@ class TrainingTab(QWidget):
resolved_image = image_path.resolve() resolved_image = image_path.resolve()
candidates: List[str] = [] candidates: List[str] = []
# Some DBs store absolute paths in `images.relative_path`.
# Include the absolute resolved path as a lookup candidate.
candidates.append(resolved_image.as_posix())
for base in (dataset_root, images_dir): for base in (dataset_root, images_dir):
try: try:
relative = resolved_image.relative_to(base.resolve()).as_posix() relative = resolved_image.relative_to(base.resolve()).as_posix()
@@ -1032,6 +1049,13 @@ class TrainingTab(QWidget):
return False, [] return False, []
annotations = self.db_manager.get_annotations_for_image(image_row["id"]) or [] annotations = self.db_manager.get_annotations_for_image(image_row["id"]) or []
# Treat "found" as "has database-backed annotations".
# If the image exists in DB but has no annotations yet, we don't want to overwrite
# an existing label file on disk with an empty one.
if not annotations:
return False, []
yolo_entries: List[Dict[str, Any]] = [] yolo_entries: List[Dict[str, Any]] = []
for ann in annotations: for ann in annotations:

View File

@@ -84,25 +84,19 @@ class InferenceEngine:
# Save detections to database, replacing any previous results for this image/model # Save detections to database, replacing any previous results for this image/model
if save_to_db: if save_to_db:
deleted_count = self.db_manager.delete_detections_for_image( deleted_count = self.db_manager.delete_detections_for_image(image_id, self.model_id)
image_id, self.model_id
)
if detections: if detections:
detection_records = [] detection_records = []
for det in detections: for det in detections:
# Use normalized bbox from detection # Use normalized bbox from detection
bbox_normalized = det[ bbox_normalized = det["bbox_normalized"] # [x_min, y_min, x_max, y_max]
"bbox_normalized"
] # [x_min, y_min, x_max, y_max]
metadata = { metadata = {
"class_id": det["class_id"], "class_id": det["class_id"],
"source_path": str(Path(image_path).resolve()), "source_path": str(Path(image_path).resolve()),
} }
if repository_root: if repository_root:
metadata["repository_root"] = str( metadata["repository_root"] = str(Path(repository_root).resolve())
Path(repository_root).resolve()
)
record = { record = {
"image_id": image_id, "image_id": image_id,
@@ -115,16 +109,27 @@ class InferenceEngine:
} }
detection_records.append(record) detection_records.append(record)
inserted_count = self.db_manager.add_detections_batch( inserted_count = self.db_manager.add_detections_batch(detection_records)
detection_records logger.info(f"Saved {inserted_count} detections to database (replaced {deleted_count})")
)
logger.info(
f"Saved {inserted_count} detections to database (replaced {deleted_count})"
)
else: else:
logger.info( logger.info(f"Detection run removed {deleted_count} stale entries but produced no new detections")
f"Detection run removed {deleted_count} stale entries but produced no new detections"
# Always store a run summary so the Results tab can show zero-detection runs.
try:
run_metadata = {
"source_path": str(Path(image_path).resolve()),
}
if repository_root:
run_metadata["repository_root"] = str(Path(repository_root).resolve())
self.db_manager.upsert_detection_run(
image_id=image_id,
model_id=self.model_id,
count=len(detections),
metadata=run_metadata,
) )
except Exception as exc:
# Non-fatal: detection records may still be present.
logger.warning(f"Failed to store detection run summary: {exc}")
return { return {
"success": True, "success": True,
@@ -232,9 +237,7 @@ class InferenceEngine:
for det in detections: for det in detections:
# Get color for this class # Get color for this class
class_name = det["class_name"] class_name = det["class_name"]
color_hex = bbox_colors.get( color_hex = bbox_colors.get(class_name, bbox_colors.get("default", "#00FF00"))
class_name, bbox_colors.get("default", "#00FF00")
)
color = self._hex_to_bgr(color_hex) color = self._hex_to_bgr(color_hex)
# Draw segmentation mask if available and requested # Draw segmentation mask if available and requested
@@ -243,10 +246,7 @@ class InferenceEngine:
if mask_normalized and len(mask_normalized) > 0: if mask_normalized and len(mask_normalized) > 0:
# Convert normalized coordinates to absolute pixels # Convert normalized coordinates to absolute pixels
mask_points = np.array( mask_points = np.array(
[ [[int(pt[0] * width), int(pt[1] * height)] for pt in mask_normalized],
[int(pt[0] * width), int(pt[1] * height)]
for pt in mask_normalized
],
dtype=np.int32, dtype=np.int32,
) )
@@ -270,9 +270,7 @@ class InferenceEngine:
label = f"{class_name} {det['confidence']:.2f}" label = f"{class_name} {det['confidence']:.2f}"
# Draw label background # Draw label background
(label_w, label_h), baseline = cv2.getTextSize( (label_w, label_h), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
)
cv2.rectangle( cv2.rectangle(
img, img,
(x1, y1 - label_h - baseline - 5), (x1, y1 - label_h - baseline - 5),

View File

@@ -186,7 +186,7 @@ class YOLOWrapper:
raise RuntimeError(f"Failed to load model from {self.model_path}") raise RuntimeError(f"Failed to load model from {self.model_path}")
prepared_source, cleanup_path = self._prepare_source(source) prepared_source, cleanup_path = self._prepare_source(source)
imgsz = 1088 imgsz = 640 # 1088
try: try:
logger.info(f"Running inference on {source} -> prepared_source {prepared_source}") logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
results = self.model.predict( results = self.model.predict(

View File

@@ -6,6 +6,9 @@ import cv2
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
from skimage.restoration import rolling_ball
from skimage.filters import threshold_otsu
from scipy.ndimage import median_filter, gaussian_filter
from src.utils.logger import get_logger from src.utils.logger import get_logger
from src.utils.file_utils import validate_file_path, is_image_file from src.utils.file_utils import validate_file_path, is_image_file
@@ -37,6 +40,8 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
a1[a1 > p999] = p999 a1[a1 > p999] = p999
a1 /= a1.max() a1 /= a1.max()
# print("Using get_pseudo_rgb")
if 1: if 1:
a2 = a1.copy() a2 = a1.copy()
a2 = a2**gamma a2 = a2**gamma
@@ -46,14 +51,75 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
p9999 = np.percentile(a3, 99.99) p9999 = np.percentile(a3, 99.99)
a3[a3 > p9999] = p9999 a3[a3 > p9999] = p9999
a3 /= a3.max() a3 /= a3.max()
out = np.stack([a1, a2, a3], axis=0)
else:
out = np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
return out
# return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0) # return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0) # return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
out = np.stack([a1, a2, a3], axis=0)
# print(any(np.isnan(out).flatten())) # print(any(np.isnan(out).flatten()))
def _get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.3) -> np.ndarray:
"""
Convert a grayscale image to a pseudo-RGB image using a gamma correction.
Args:
arr: Input grayscale image as numpy array
Returns:
Pseudo-RGB image as numpy array
"""
if arr.ndim != 2:
raise ValueError("Input array must be a grayscale image with shape (H, W)")
radius = 80
# bg = rolling_ball(arr, radius=radius)
a1 = arr.copy().astype(np.float32)
a1 = a1.astype(np.float32)
# a1 -= bg
# a1[a1 < 0] = 0
# a1 -= np.percentile(a1, 2)
# a1[a1 < 0] = 0
p999 = np.percentile(a1, 99.99)
a1[a1 > p999] = p999
a1 /= a1.max()
print("Using get_pseudo_rgb")
if 1:
a2 = a1.copy()
_a2 = a2**gamma
thr = threshold_otsu(_a2)
mask = gaussian_filter((_a2 > thr).astype(np.float32), sigma=5)
mask[mask > 0.0001] = 1
mask[mask <= 0.0001] = 0
a2 *= mask
# bg2 = rolling_ball(a2, radius=radius)
# a2 -= bg2
a2 -= np.percentile(a2, 2)
a2[a2 < 0] = 0
a2 /= a2.max()
a3 = a1.copy()
p9999 = np.percentile(a3, 99.99)
a3[a3 > p9999] = p9999
a3 /= a3.max()
out = np.stack([a1, a2, _a2], axis=0)
else:
out = np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
return out return out
# return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# print(any(np.isnan(out).flatten()))
class ImageLoadError(Exception): class ImageLoadError(Exception):
"""Exception raised when an image cannot be loaded.""" """Exception raised when an image cannot be loaded."""
@@ -330,17 +396,20 @@ class Image:
return self._channels >= 3 return self._channels >= 3
def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None: def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None:
# print("Image.save", self.data.shape)
if self.channels == 1: if self.channels == 1:
# print("Image.save grayscale")
if pseudo_rgb: if pseudo_rgb:
img = get_pseudo_rgb(self.data) img = get_pseudo_rgb(self.data)
print("Image.save", img.shape) # print("Image.save", img.shape)
else: else:
img = np.repeat(self.data, 3, axis=2) img = np.repeat(self.data, 3, axis=2)
# print("Image.save no pseudo", img.shape)
else: else:
raise NotImplementedError("Only grayscale images are supported for now.") raise NotImplementedError("Only grayscale images are supported for now.")
# print("Image.save imwrite", img.shape)
imwrite(path, data=img) imwrite(path, data=img)
def __repr__(self) -> str: def __repr__(self) -> str:

View File

@@ -97,7 +97,7 @@ class UT:
class_index: int = 0, class_index: int = 0,
): ):
"""Export rois to a file""" """Export rois to a file"""
with open(path / subfolder / f"{self.stem}.txt", "w") as f: with open(path / subfolder / f"{PREFIX}-{self.stem}.txt", "w") as f:
for i, roi in enumerate(self.rois): for i, roi in enumerate(self.rois):
rc = roi.subpixel_coordinates rc = roi.subpixel_coordinates
if rc is None: if rc is None:
@@ -129,8 +129,8 @@ class UT:
self.image = np.max(self.image[channel], axis=0) self.image = np.max(self.image[channel], axis=0)
print(self.image.shape) print(self.image.shape)
print(path / subfolder / f"{self.stem}.tif") print(path / subfolder / f"{PREFIX}_{self.stem}.tif")
with TiffWriter(path / subfolder / f"{self.stem}.tif") as tif: with TiffWriter(path / subfolder / f"{PREFIX}-{self.stem}.tif") as tif:
tif.write(self.image) tif.write(self.image)
@@ -145,11 +145,19 @@ if __name__ == "__main__":
action="store_false", action="store_false",
help="Source does not have labels, export only images", help="Source does not have labels, export only images",
) )
parser.add_argument("--prefix", help="Prefix for output files")
args = parser.parse_args() args = parser.parse_args()
PREFIX = args.prefix
# print(args) # print(args)
# aa # aa
# for path in args.input:
# print(path)
# ut = UT(path, args.no_labels)
# ut.export_image(args.output, plane_mode="max projection", channel=0)
# ut.export_rois(args.output, class_index=0)
for path in args.input: for path in args.input:
print("Path:", path) print("Path:", path)
if not args.no_labels: if not args.no_labels:

View File

@@ -5,7 +5,7 @@ from tifffile import imread, imwrite
from shapely.geometry import LineString from shapely.geometry import LineString
from copy import deepcopy from copy import deepcopy
from scipy.ndimage import zoom from scipy.ndimage import zoom
from skimage.restoration import rolling_ball
# debug # debug
from src.utils.image import Image from src.utils.image import Image
@@ -160,8 +160,11 @@ class YoloLabelReader:
class ImageSplitter: class ImageSplitter:
def __init__(self, image_path: Path, label_path: Path): def __init__(self, image_path: Path, label_path: Path, subtract_bg: bool = False):
self.image = imread(image_path) self.image = imread(image_path)
if subtract_bg:
self.image = self.image - rolling_ball(self.image, radius=12)
self.image[self.image < 0] = 0
self.image_path = image_path self.image_path = image_path
self.label_path = label_path self.label_path = label_path
if not label_path.exists(): if not label_path.exists():
@@ -273,13 +276,14 @@ def main(args):
if args.output: if args.output:
args.output.mkdir(exist_ok=True, parents=True) args.output.mkdir(exist_ok=True, parents=True)
(args.output / "images").mkdir(exist_ok=True) (args.output / "images").mkdir(exist_ok=True)
(args.output / "images-zoomed").mkdir(exist_ok=True) # (args.output / "images-zoomed").mkdir(exist_ok=True)
(args.output / "labels").mkdir(exist_ok=True) (args.output / "labels").mkdir(exist_ok=True)
for image_path in (args.input / "images").glob("*.tif"): for image_path in (args.input / "images").glob("*.tif"):
data = ImageSplitter( data = ImageSplitter(
image_path=image_path, image_path=image_path,
label_path=(args.input / "labels" / image_path.stem).with_suffix(".txt"), label_path=(args.input / "labels" / image_path.stem).with_suffix(".txt"),
subtract_bg=args.subtract_bg,
) )
if args.split_around_label: if args.split_around_label:
@@ -332,10 +336,20 @@ def main(args):
if labels is not None: if labels is not None:
with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f: with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f:
print(
f"Writing {len(labels)} labels to {args.output / 'labels' / f'{image_path.stem}_{tile_reference}.txt'}"
)
for label in labels: for label in labels:
# label.offset_label(tile.shape[1], tile.shape[0]) # label.offset_label(tile.shape[1], tile.shape[0])
f.write(label.to_string() + "\n") f.write(label.to_string() + "\n")
# { debug
if debug:
print(label.to_string())
# } debug
# break
# break
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
@@ -363,6 +377,7 @@ if __name__ == "__main__":
default=67, default=67,
help="Padding around the label when splitting around the label.", help="Padding around the label when splitting around the label.",
) )
parser.add_argument("-bg", "--subtract-bg", action="store_true", help="Subtract background from the image.")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)