3 Commits

11 changed files with 1266 additions and 97 deletions

View File

@@ -40,8 +40,15 @@ class DatabaseManager:
conn = self.get_connection()
try:
# Check if annotations table needs migration
# Pre-schema migrations.
# These must run BEFORE executing schema.sql because schema.sql may
# contain CREATE INDEX statements referencing newly added columns.
#
# 1) Check if annotations table needs migration (may drop an old table)
self._migrate_annotations_table(conn)
# 2) Ensure images table has the required columns (e.g. 'source')
self._migrate_images_table(conn)
conn.commit()
# Read schema file and execute
schema_path = Path(__file__).parent / "schema.sql"
@@ -53,6 +60,19 @@ class DatabaseManager:
finally:
conn.close()
def _migrate_images_table(self, conn: sqlite3.Connection) -> None:
"""Migrate images table to include the 'source' column if missing."""
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='images'")
if not cursor.fetchone():
return
cursor.execute("PRAGMA table_info(images)")
columns = {row[1] for row in cursor.fetchall()}
if "source" not in columns:
cursor.execute("ALTER TABLE images ADD COLUMN source TEXT")
def _migrate_annotations_table(self, conn: sqlite3.Connection) -> None:
"""
Migrate annotations table from old schema (class_name) to new schema (class_id).
@@ -85,6 +105,103 @@ class DatabaseManager:
conn.execute("PRAGMA foreign_keys = ON") # Enable foreign keys
return conn
# ==================== Detection Run Operations ====================
def upsert_detection_run(
self,
image_id: int,
model_id: int,
count: int,
metadata: Optional[Dict] = None,
) -> bool:
"""Insert/update a per-image per-model detection run summary.
This enables the UI to show runs even when zero detections were produced.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO detection_runs (image_id, model_id, detected_at, count, metadata)
VALUES (?, ?, CURRENT_TIMESTAMP, ?, ?)
ON CONFLICT(image_id, model_id) DO UPDATE SET
detected_at = CURRENT_TIMESTAMP,
count = excluded.count,
metadata = excluded.metadata
""",
(
int(image_id),
int(model_id),
int(count),
json.dumps(metadata) if metadata else None,
),
)
conn.commit()
return True
finally:
conn.close()
def get_detection_run_summaries(self, limit: int = 500, offset: int = 0) -> List[Dict]:
"""Return latest detection run summaries grouped by image+model.
Includes runs with 0 detections.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"""
SELECT
dr.image_id,
dr.model_id,
dr.detected_at,
dr.count,
dr.metadata,
i.relative_path AS image_path,
i.filename AS image_filename,
m.model_name,
m.model_version,
GROUP_CONCAT(DISTINCT d.class_name) AS classes
FROM detection_runs dr
JOIN images i ON dr.image_id = i.id
JOIN models m ON dr.model_id = m.id
LEFT JOIN detections d
ON d.image_id = dr.image_id AND d.model_id = dr.model_id
GROUP BY dr.image_id, dr.model_id
ORDER BY dr.detected_at DESC
LIMIT ? OFFSET ?
""",
(int(limit), int(offset)),
)
rows: List[Dict] = []
for row in cursor.fetchall():
item = dict(row)
if item.get("metadata"):
try:
item["metadata"] = json.loads(item["metadata"])
except Exception:
item["metadata"] = None
rows.append(item)
return rows
finally:
conn.close()
def get_detection_run_total(self) -> int:
"""Return total number of detection_runs rows."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) AS cnt FROM detection_runs")
row = cursor.fetchone()
return int(row["cnt"] if row and row["cnt"] is not None else 0)
finally:
conn.close()
# ==================== Model Operations ====================
def add_model(
@@ -233,6 +350,7 @@ class DatabaseManager:
height: int,
captured_at: Optional[datetime] = None,
checksum: Optional[str] = None,
source: Optional[str] = None,
) -> int:
"""
Add a new image to the database.
@@ -253,10 +371,10 @@ class DatabaseManager:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO images (relative_path, filename, width, height, captured_at, checksum)
VALUES (?, ?, ?, ?, ?, ?)
INSERT INTO images (relative_path, filename, width, height, captured_at, checksum, source)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(relative_path, filename, width, height, captured_at, checksum),
(relative_path, filename, width, height, captured_at, checksum, source),
)
conn.commit()
return cursor.lastrowid
@@ -286,6 +404,18 @@ class DatabaseManager:
return existing["id"]
return self.add_image(relative_path, filename, width, height)
def set_image_source(self, image_id: int, source: Optional[str]) -> bool:
"""Set/update the source marker for an image row."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("UPDATE images SET source = ? WHERE id = ?", (source, int(image_id)))
conn.commit()
return cursor.rowcount > 0
finally:
conn.close()
# ==================== Detection Operations ====================
def add_detection(
@@ -494,6 +624,14 @@ class DatabaseManager:
conn = self.get_connection()
try:
cursor = conn.cursor()
# Also clear detection run summaries so the Results tab does not continue
# to show historical runs after detections have been wiped.
try:
cursor.execute("DELETE FROM detection_runs")
except sqlite3.OperationalError:
# Backwards-compatible: table may not exist on older DB files.
pass
cursor.execute("DELETE FROM detections")
conn.commit()
return cursor.rowcount
@@ -658,6 +796,153 @@ class DatabaseManager:
# ==================== Annotation Operations ====================
def get_images_summary(
self,
name_filter: Optional[str] = None,
source_filter: Optional[str] = None,
order_by: str = "filename",
order_dir: str = "ASC",
limit: Optional[int] = None,
offset: int = 0,
) -> List[Dict]:
"""Return all images with annotation counts (including zero).
This is used by the Annotation tab to populate the image list even when
no annotations exist yet.
Args:
name_filter: Optional substring filter applied to filename/relative_path.
order_by: One of: 'filename', 'relative_path', 'annotation_count', 'added_at'.
order_dir: 'ASC' or 'DESC'.
limit: Optional max number of rows.
offset: Pagination offset.
Returns:
List of dicts: {id, relative_path, filename, added_at, annotation_count}
"""
allowed_order_by = {
"filename": "i.filename",
"relative_path": "i.relative_path",
"annotation_count": "annotation_count",
"added_at": "i.added_at",
}
order_expr = allowed_order_by.get(order_by, "i.filename")
dir_norm = str(order_dir).upper().strip()
if dir_norm not in {"ASC", "DESC"}:
dir_norm = "ASC"
conn = self.get_connection()
try:
params: List[Any] = []
where_clauses: List[str] = []
if name_filter:
token = f"%{name_filter}%"
where_clauses.append("(i.filename LIKE ? OR i.relative_path LIKE ?)")
params.extend([token, token])
if source_filter:
where_clauses.append("i.source = ?")
params.append(source_filter)
where_sql = ""
if where_clauses:
where_sql = "WHERE " + " AND ".join(where_clauses)
limit_sql = ""
if limit is not None:
limit_sql = " LIMIT ? OFFSET ?"
params.extend([int(limit), int(offset)])
query = f"""
SELECT
i.id,
i.relative_path,
i.filename,
i.added_at,
COUNT(a.id) AS annotation_count
FROM images i
LEFT JOIN annotations a ON a.image_id = i.id
{where_sql}
GROUP BY i.id
ORDER BY {order_expr} {dir_norm}
{limit_sql}
"""
cursor = conn.cursor()
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
finally:
conn.close()
def get_annotated_images_summary(
self,
name_filter: Optional[str] = None,
order_by: str = "filename",
order_dir: str = "ASC",
limit: Optional[int] = None,
offset: int = 0,
) -> List[Dict]:
"""Return images that have at least one manual annotation.
Args:
name_filter: Optional substring filter applied to filename/relative_path.
order_by: One of: 'filename', 'relative_path', 'annotation_count', 'added_at'.
order_dir: 'ASC' or 'DESC'.
limit: Optional max number of rows.
offset: Pagination offset.
Returns:
List of dicts: {id, relative_path, filename, added_at, annotation_count}
"""
allowed_order_by = {
"filename": "i.filename",
"relative_path": "i.relative_path",
"annotation_count": "annotation_count",
"added_at": "i.added_at",
}
order_expr = allowed_order_by.get(order_by, "i.filename")
dir_norm = str(order_dir).upper().strip()
if dir_norm not in {"ASC", "DESC"}:
dir_norm = "ASC"
conn = self.get_connection()
try:
params: List[Any] = []
where_sql = ""
if name_filter:
# Case-insensitive substring search.
token = f"%{name_filter}%"
where_sql = "WHERE (i.filename LIKE ? OR i.relative_path LIKE ?)"
params.extend([token, token])
limit_sql = ""
if limit is not None:
limit_sql = " LIMIT ? OFFSET ?"
params.extend([int(limit), int(offset)])
query = f"""
SELECT
i.id,
i.relative_path,
i.filename,
i.added_at,
COUNT(a.id) AS annotation_count
FROM images i
JOIN annotations a ON a.image_id = i.id
{where_sql}
GROUP BY i.id
HAVING annotation_count > 0
ORDER BY {order_expr} {dir_norm}
{limit_sql}
"""
cursor = conn.cursor()
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
finally:
conn.close()
def add_annotation(
self,
image_id: int,
@@ -763,6 +1048,27 @@ class DatabaseManager:
finally:
conn.close()
def delete_annotations_for_image(self, image_id: int) -> int:
"""Delete ALL annotations for a specific image.
This is primarily used for import/overwrite workflows.
Args:
image_id: ID of the image whose annotations should be deleted.
Returns:
Number of rows deleted.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM annotations WHERE image_id = ?", (int(image_id),))
conn.commit()
return int(cursor.rowcount or 0)
finally:
conn.close()
# ==================== Object Class Operations ====================
def get_object_classes(self) -> List[Dict]:

View File

@@ -7,7 +7,7 @@ CREATE TABLE IF NOT EXISTS models (
model_name TEXT NOT NULL,
model_version TEXT NOT NULL,
model_path TEXT NOT NULL,
base_model TEXT NOT NULL DEFAULT 'yolov8s.pt',
base_model TEXT NOT NULL DEFAULT 'yolo11s.pt',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
training_params TEXT, -- JSON string of training parameters
metrics TEXT, -- JSON string of validation metrics
@@ -23,7 +23,8 @@ CREATE TABLE IF NOT EXISTS images (
height INTEGER,
captured_at TIMESTAMP,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
checksum TEXT
checksum TEXT,
source TEXT
);
-- Detections table: stores detection results
@@ -44,6 +45,19 @@ CREATE TABLE IF NOT EXISTS detections (
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
);
-- Detection runs table: stores per-image per-model run summaries (including 0 detections)
CREATE TABLE IF NOT EXISTS detection_runs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
image_id INTEGER NOT NULL,
model_id INTEGER NOT NULL,
detected_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
count INTEGER NOT NULL DEFAULT 0,
metadata TEXT,
UNIQUE(image_id, model_id),
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
);
-- Object classes table: stores annotation class definitions with colors
CREATE TABLE IF NOT EXISTS object_classes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -80,9 +94,13 @@ CREATE INDEX IF NOT EXISTS idx_detections_model_id ON detections(model_id);
CREATE INDEX IF NOT EXISTS idx_detections_class_name ON detections(class_name);
CREATE INDEX IF NOT EXISTS idx_detections_detected_at ON detections(detected_at);
CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence);
CREATE INDEX IF NOT EXISTS idx_detection_runs_image_id ON detection_runs(image_id);
CREATE INDEX IF NOT EXISTS idx_detection_runs_model_id ON detection_runs(model_id);
CREATE INDEX IF NOT EXISTS idx_detection_runs_detected_at ON detection_runs(detected_at);
CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path);
CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at);
CREATE INDEX IF NOT EXISTS idx_images_source ON images(source);
CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id);
CREATE INDEX IF NOT EXISTS idx_annotations_class_id ON annotations(class_id);
CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at);
CREATE INDEX IF NOT EXISTS idx_object_classes_class_name ON object_classes(class_name);
CREATE INDEX IF NOT EXISTS idx_object_classes_class_name ON object_classes(class_name);

View File

@@ -199,6 +199,10 @@ class MainWindow(QMainWindow):
self.training_tab.refresh()
if hasattr(self, "results_tab"):
self.results_tab.refresh()
if hasattr(self, "annotation_tab"):
self.annotation_tab.refresh()
if hasattr(self, "validation_tab"):
self.validation_tab.refresh()
except Exception as e:
logger.error(f"Error applying settings: {e}")
@@ -215,6 +219,14 @@ class MainWindow(QMainWindow):
logger.debug(f"Switched to tab: {tab_name}")
self._update_status(f"Viewing: {tab_name}")
# Ensure the Annotation tab always shows up-to-date DB-backed lists.
try:
current_widget = self.tab_widget.widget(index)
if hasattr(self, "annotation_tab") and current_widget is self.annotation_tab:
self.annotation_tab.refresh()
except Exception as exc:
logger.debug(f"Failed to refresh annotation tab on selection: {exc}")
def _show_database_stats(self):
"""Show database statistics dialog."""
try:
@@ -527,6 +539,11 @@ class MainWindow(QMainWindow):
if hasattr(self, "training_tab"):
self.training_tab.shutdown()
if hasattr(self, "annotation_tab"):
# Best-effort refresh so DB-backed UI state is consistent at shutdown.
try:
self.annotation_tab.refresh()
except Exception:
pass
self.annotation_tab.save_state()
logger.info("Application closing")

View File

@@ -13,6 +13,11 @@ from PySide6.QtWidgets import (
QFileDialog,
QMessageBox,
QSplitter,
QLineEdit,
QTableWidget,
QTableWidgetItem,
QHeaderView,
QAbstractItemView,
)
from PySide6.QtCore import Qt, QSettings
from pathlib import Path
@@ -50,6 +55,32 @@ class AnnotationTab(QWidget):
self.main_splitter = QSplitter(Qt.Horizontal)
self.main_splitter.setHandleWidth(10)
# { Left-most pane: annotated images list
annotated_group = QGroupBox("Annotated Images")
annotated_layout = QVBoxLayout()
filter_row = QHBoxLayout()
filter_row.addWidget(QLabel("Filter:"))
self.annotated_filter_edit = QLineEdit()
self.annotated_filter_edit.setPlaceholderText("Type to filter by image name…")
self.annotated_filter_edit.textChanged.connect(self._refresh_annotated_images_list)
filter_row.addWidget(self.annotated_filter_edit, 1)
annotated_layout.addLayout(filter_row)
self.annotated_images_table = QTableWidget(0, 2)
self.annotated_images_table.setHorizontalHeaderLabels(["Image", "Annotations"])
self.annotated_images_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
self.annotated_images_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeToContents)
self.annotated_images_table.setSelectionBehavior(QAbstractItemView.SelectRows)
self.annotated_images_table.setSelectionMode(QAbstractItemView.SingleSelection)
self.annotated_images_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
self.annotated_images_table.setSortingEnabled(True)
self.annotated_images_table.itemSelectionChanged.connect(self._on_annotated_image_selected)
annotated_layout.addWidget(self.annotated_images_table, 1)
annotated_group.setLayout(annotated_layout)
# }
# { Left splitter for image display and zoom info
self.left_splitter = QSplitter(Qt.Vertical)
self.left_splitter.setHandleWidth(10)
@@ -104,11 +135,27 @@ class AnnotationTab(QWidget):
load_group = QGroupBox("Image Loading")
load_layout = QVBoxLayout()
# Load image button
# Buttons row
button_layout = QHBoxLayout()
self.load_image_btn = QPushButton("Load Image")
self.load_image_btn.clicked.connect(self._load_image)
button_layout.addWidget(self.load_image_btn)
self.import_images_btn = QPushButton("Import Images")
self.import_images_btn.setToolTip(
"Import one or more images into the database.\n" "Images already present in the DB are skipped."
)
self.import_images_btn.clicked.connect(self._import_images)
button_layout.addWidget(self.import_images_btn)
self.import_annotations_btn = QPushButton("Import Annotations")
self.import_annotations_btn.setToolTip(
"Import YOLO .txt annotation files and register them with their corresponding images.\n"
"Existing annotations for those images will be overwritten."
)
self.import_annotations_btn.clicked.connect(self._import_annotations)
button_layout.addWidget(self.import_annotations_btn)
button_layout.addStretch()
load_layout.addLayout(button_layout)
@@ -120,12 +167,13 @@ class AnnotationTab(QWidget):
self.right_splitter.addWidget(load_group)
# }
# Add both splitters to the main horizontal splitter
# Add list + both splitters to the main horizontal splitter
self.main_splitter.addWidget(annotated_group)
self.main_splitter.addWidget(self.left_splitter)
self.main_splitter.addWidget(self.right_splitter)
# Set initial sizes: 75% for left (image), 25% for right (controls)
self.main_splitter.setSizes([750, 250])
# Set initial sizes: list (left), canvas (middle), controls (right)
self.main_splitter.setSizes([320, 650, 280])
layout.addWidget(self.main_splitter)
self.setLayout(layout)
@@ -133,6 +181,375 @@ class AnnotationTab(QWidget):
# Restore splitter positions from settings
self._restore_state()
# Populate list on startup.
self._refresh_annotated_images_list()
def _import_images(self) -> None:
"""Import one or more images into the database and refresh the list."""
settings = QSettings("microscopy_app", "object_detection")
last_dir = settings.value("annotation_tab/last_image_import_directory", None)
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
if last_dir and Path(str(last_dir)).exists():
start_dir = str(last_dir)
elif repo_root and Path(repo_root).exists():
start_dir = repo_root
else:
start_dir = str(Path.home())
# Build filter string for supported extensions
patterns = " ".join(f"*{ext}" for ext in Image.SUPPORTED_EXTENSIONS)
file_paths, _ = QFileDialog.getOpenFileNames(
self,
"Select Image File(s)",
start_dir,
f"Images ({patterns})",
)
if not file_paths:
return
try:
settings.setValue("annotation_tab/last_image_import_directory", str(Path(file_paths[0]).parent))
# Keep compatibility with the existing image resolver fallback (it checks last_directory).
settings.setValue("annotation_tab/last_directory", str(Path(file_paths[0]).parent))
except Exception:
pass
imported = 0
tagged_existing = 0
skipped = 0
errors: list[str] = []
for fp in file_paths:
try:
img_path = Path(fp)
img = Image(str(img_path))
relative_path = self._compute_relative_path_for_repo(img_path)
# Skip if already present
existing = self.db_manager.get_image_by_path(relative_path)
if existing:
# If the image already exists (e.g. created earlier by other workflows),
# tag it as being managed by the Annotation tab so it becomes visible
# in the left list.
try:
self.db_manager.set_image_source(int(existing["id"]), "annotation_tab")
tagged_existing += 1
except Exception:
# If tagging fails, fall back to treating as skipped.
skipped += 1
continue
image_id = self.db_manager.add_image(
relative_path,
img_path.name,
img.width,
img.height,
source="annotation_tab",
)
try:
# In case the DB row was created by an older schema/migration path.
self.db_manager.set_image_source(image_id, "annotation_tab")
except Exception:
pass
imported += 1
except ImageLoadError as exc:
skipped += 1
errors.append(f"Failed to load image {fp}: {exc}")
except Exception as exc:
skipped += 1
errors.append(f"Failed to import image {fp}: {exc}")
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
msg = (
f"Imported: {imported}\n"
f"Already in DB (tagged for Annotation tab): {tagged_existing}\n"
f"Skipped (errors): {skipped}"
)
if errors:
details = "\n".join(errors[:25])
if len(errors) > 25:
details += f"\n... and {len(errors) - 25} more"
msg += "\n\nDetails:\n" + details
QMessageBox.information(self, "Import Images", msg)
# ==================== Import annotations (YOLO .txt) ====================
def _import_annotations(self) -> None:
"""Import YOLO segmentation/bbox annotations from one or more .txt files."""
settings = QSettings("microscopy_app", "object_detection")
last_dir = settings.value("annotation_tab/last_annotation_directory", None)
# Default start dir: repo root if set, otherwise last used, otherwise home.
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
if last_dir and Path(str(last_dir)).exists():
start_dir = str(last_dir)
elif repo_root and Path(repo_root).exists():
start_dir = repo_root
else:
start_dir = str(Path.home())
file_paths, _ = QFileDialog.getOpenFileNames(
self,
"Select YOLO Annotation File(s)",
start_dir,
"YOLO annotations (*.txt)",
)
if not file_paths:
return
# Persist last annotation directory for the next import.
try:
settings.setValue("annotation_tab/last_annotation_directory", str(Path(file_paths[0]).parent))
except Exception:
pass
imported_images = 0
imported_annotations = 0
overwritten_images = 0
skipped = 0
errors: list[str] = []
for label_file in file_paths:
label_path = Path(label_file)
try:
image_path = self._infer_corresponding_image_path(label_path)
if not image_path:
skipped += 1
errors.append(f"Image not found for label file: {label_path}")
continue
# Load image to obtain width/height for DB entry.
img = Image(str(image_path))
# Store in DB using a repo-relative path if possible.
relative_path = self._compute_relative_path_for_repo(image_path)
image_id = self.db_manager.get_or_create_image(relative_path, image_path.name, img.width, img.height)
try:
self.db_manager.set_image_source(image_id, "annotation_tab")
except Exception:
pass
# Overwrite existing annotations for this image.
try:
deleted = self.db_manager.delete_annotations_for_image(image_id)
except AttributeError:
# Safety fallback if older DBManager is used.
deleted = 0
if deleted > 0:
overwritten_images += 1
# Parse YOLO lines and insert as annotations.
parsed = self._parse_yolo_annotation_file(label_path)
if not parsed:
# Empty/invalid label file: treat as "clear" operation (already deleted above)
imported_images += 1
continue
db_classes = self.db_manager.get_object_classes() or []
classes_by_index = {idx: row for idx, row in enumerate(db_classes)}
for class_index, bbox, poly in parsed:
class_row = classes_by_index.get(int(class_index))
if not class_row:
skipped += 1
errors.append(
f"Unknown class index {class_index} in {label_path.name}. "
"Create object classes in the UI first (class index is based on DB ordering)."
)
continue
ann_id = self.db_manager.add_annotation(
image_id=image_id,
class_id=int(class_row["id"]),
bbox=bbox,
annotator="import",
segmentation_mask=poly,
verified=False,
)
if ann_id:
imported_annotations += 1
imported_images += 1
# If we imported for the currently open image, reload.
if self.current_image_id and int(self.current_image_id) == int(image_id):
self._load_annotations_for_current_image()
except ImageLoadError as exc:
skipped += 1
errors.append(f"Failed to load image for {label_path.name}: {exc}")
except Exception as exc:
skipped += 1
errors.append(f"Import failed for {label_path.name}: {exc}")
# Refresh annotated images list.
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
summary = (
f"Imported files: {len(file_paths)}\n"
f"Images processed: {imported_images}\n"
f"Annotations inserted: {imported_annotations}\n"
f"Images overwritten (had existing annotations): {overwritten_images}\n"
f"Skipped: {skipped}"
)
if errors:
# Cap error details to avoid huge dialogs.
details = "\n".join(errors[:25])
if len(errors) > 25:
details += f"\n... and {len(errors) - 25} more"
summary += "\n\nDetails:\n" + details
QMessageBox.information(self, "Import Annotations", summary)
def _infer_corresponding_image_path(self, label_path: Path) -> Path | None:
"""Infer image path from YOLO label file path.
Requirement: image(s) live in an `images/` folder located in the label file's parent directory.
Example:
/dataset/train/labels/img123.txt -> /dataset/train/images/img123.(any supported ext)
"""
parent = label_path.parent
images_dir = parent.parent / "images"
stem = label_path.stem
# 1) Direct stem match in images dir (any supported extension)
for ext in Image.SUPPORTED_EXTENSIONS:
candidate = images_dir / f"{stem}{ext}"
if candidate.exists() and candidate.is_file():
return candidate
# 2) Fallback: repository-root search by filename
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
if repo_root:
root = Path(repo_root).expanduser()
try:
if root.exists():
for ext in Image.SUPPORTED_EXTENSIONS:
filename = f"{stem}{ext}"
for match in root.rglob(filename):
if match.is_file():
return match.resolve()
except Exception:
pass
return None
def _compute_relative_path_for_repo(self, image_path: Path) -> str:
"""Compute a stable `relative_path` suitable for DB storage.
Policy:
- If an image repository root is configured and the image is under it, store a repo-relative path.
- Otherwise, store an absolute resolved path so the image can be reopened later.
"""
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
try:
if repo_root:
repo_root_path = Path(repo_root).expanduser().resolve()
img_resolved = image_path.expanduser().resolve()
if img_resolved.is_relative_to(repo_root_path):
return img_resolved.relative_to(repo_root_path).as_posix()
except Exception:
pass
try:
return str(image_path.expanduser().resolve())
except Exception:
return str(image_path)
def _parse_yolo_annotation_file(
self, label_path: Path
) -> list[tuple[int, tuple[float, float, float, float], list[list[float]] | None]]:
"""Parse a YOLO .txt label file.
Supports:
- YOLO segmentation polygon format: "class x1 y1 x2 y2 ..." (normalized)
- YOLO bbox format: "class x_center y_center width height" (normalized)
Returns:
List of (class_index, bbox_xyxy_norm, segmentation_mask_db)
Where segmentation_mask_db is [[y_norm, x_norm], ...] or None.
"""
out: list[tuple[int, tuple[float, float, float, float], list[list[float]] | None]] = []
try:
raw = label_path.read_text(encoding="utf-8").splitlines()
except OSError as exc:
logger.error(f"Failed to read label file {label_path}: {exc}")
return out
for line in raw:
stripped = line.strip()
if not stripped:
continue
parts = stripped.split()
if len(parts) < 5:
# not enough for bbox
continue
try:
class_idx = int(float(parts[0]))
coords = [float(x) for x in parts[1:]]
except Exception:
continue
# Segmentation polygon format (>= 6 values)
if len(coords) >= 6:
# bbox is not explicitly present in this format in our importer; compute from polygon.
xs = coords[0::2]
ys = coords[1::2]
if not xs or not ys:
continue
x_min, x_max = min(xs), max(xs)
y_min, y_max = min(ys), max(ys)
bbox = (
self._clamp01(x_min),
self._clamp01(y_min),
self._clamp01(x_max),
self._clamp01(y_max),
)
# Convert to DB polyline convention: [[y_norm, x_norm], ...]
poly: list[list[float]] = []
for x, y in zip(xs, ys):
poly.append([self._clamp01(float(y)), self._clamp01(float(x))])
# Ensure closure for consistency (optional)
if poly and poly[0] != poly[-1]:
poly.append(list(poly[0]))
out.append((class_idx, bbox, poly))
continue
# bbox format: xc yc w h
if len(coords) >= 4:
xc, yc, w, h = coords[:4]
x_min = xc - w / 2.0
y_min = yc - h / 2.0
x_max = xc + w / 2.0
y_max = yc + h / 2.0
bbox = (
self._clamp01(float(x_min)),
self._clamp01(float(y_min)),
self._clamp01(float(x_max)),
self._clamp01(float(y_max)),
)
out.append((class_idx, bbox, None))
return out
@staticmethod
def _clamp01(value: float) -> float:
if value < 0.0:
return 0.0
if value > 1.0:
return 1.0
return float(value)
def _load_image(self):
"""Load and display an image file."""
# Get last opened directory from QSettings
@@ -166,13 +583,32 @@ class AnnotationTab(QWidget):
settings.setValue("annotation_tab/last_directory", str(Path(file_path).parent))
# Get or create image in database
relative_path = str(Path(file_path).name) # Simplified for now
repo_root = self.config_manager.get_image_repository_path()
relative_path: str
try:
if repo_root:
repo_root_path = Path(repo_root).expanduser().resolve()
file_resolved = Path(file_path).expanduser().resolve()
if file_resolved.is_relative_to(repo_root_path):
relative_path = file_resolved.relative_to(repo_root_path).as_posix()
else:
# Fallback: store filename only to avoid leaking absolute paths.
relative_path = file_resolved.name
else:
relative_path = str(Path(file_path).name)
except Exception:
relative_path = str(Path(file_path).name)
self.current_image_id = self.db_manager.get_or_create_image(
relative_path,
Path(file_path).name,
self.current_image.width,
self.current_image.height,
)
# Mark as managed by Annotation tab so it appears in the left list.
try:
self.db_manager.set_image_source(int(self.current_image_id), "annotation_tab")
except Exception:
pass
# Display image using the AnnotationCanvasWidget
self.annotation_canvas.load_image(self.current_image)
@@ -180,6 +616,9 @@ class AnnotationTab(QWidget):
# Load and display any existing annotations for this image
self._load_annotations_for_current_image()
# Update annotated images list (newly annotated image added/selected).
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
# Update info label
self._update_image_info()
@@ -275,6 +714,9 @@ class AnnotationTab(QWidget):
# Reload annotations from DB and redraw (respecting current class filter)
self._load_annotations_for_current_image()
# Update list counts.
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
except Exception as e:
logger.error(f"Failed to save annotation: {e}")
QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}")
@@ -405,6 +847,9 @@ class AnnotationTab(QWidget):
self.annotation_tools.set_has_selected_annotation(False)
self._load_annotations_for_current_image()
# Update list counts.
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
except Exception as e:
logger.error(f"Failed to delete annotations: {e}")
QMessageBox.critical(
@@ -521,4 +966,173 @@ class AnnotationTab(QWidget):
def refresh(self):
"""Refresh the tab."""
pass
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
# ==================== Annotated images list ====================
def _refresh_annotated_images_list(self, select_image_id: int | None = None) -> None:
"""Reload annotated-images list from the database."""
if not hasattr(self, "annotated_images_table"):
return
# Preserve selection if possible
desired_id = select_image_id if select_image_id is not None else self.current_image_id
name_filter = ""
if hasattr(self, "annotated_filter_edit"):
name_filter = self.annotated_filter_edit.text().strip()
try:
rows = self.db_manager.get_images_summary(name_filter=name_filter, source_filter="annotation_tab")
except Exception as exc:
logger.error(f"Failed to load images summary: {exc}")
rows = []
sorting_enabled = self.annotated_images_table.isSortingEnabled()
self.annotated_images_table.setSortingEnabled(False)
self.annotated_images_table.blockSignals(True)
try:
self.annotated_images_table.setRowCount(len(rows))
for r, entry in enumerate(rows):
image_name = str(entry.get("filename") or "")
count = int(entry.get("annotation_count") or 0)
rel_path = str(entry.get("relative_path") or "")
name_item = QTableWidgetItem(image_name)
# Tooltip shows full path of the image (best-effort: repository_root + relative_path)
full_path = rel_path
repo_root = self.config_manager.get_image_repository_path()
if repo_root and rel_path and not Path(rel_path).is_absolute():
try:
full_path = str((Path(repo_root) / rel_path).resolve())
except Exception:
full_path = str(Path(repo_root) / rel_path)
name_item.setToolTip(full_path)
name_item.setData(Qt.UserRole, int(entry.get("id")))
name_item.setData(Qt.UserRole + 1, rel_path)
count_item = QTableWidgetItem()
# Use EditRole to ensure numeric sorting.
count_item.setData(Qt.EditRole, count)
count_item.setData(Qt.UserRole, int(entry.get("id")))
count_item.setData(Qt.UserRole + 1, rel_path)
self.annotated_images_table.setItem(r, 0, name_item)
self.annotated_images_table.setItem(r, 1, count_item)
# Re-select desired row
if desired_id is not None:
for r in range(self.annotated_images_table.rowCount()):
item = self.annotated_images_table.item(r, 0)
if item and item.data(Qt.UserRole) == desired_id:
self.annotated_images_table.selectRow(r)
break
finally:
self.annotated_images_table.blockSignals(False)
self.annotated_images_table.setSortingEnabled(sorting_enabled)
def _on_annotated_image_selected(self) -> None:
"""When user clicks an item in the list, load that image in the annotation canvas."""
selected = self.annotated_images_table.selectedItems()
if not selected:
return
# Row selection -> take the first column item
row = self.annotated_images_table.currentRow()
item = self.annotated_images_table.item(row, 0)
if not item:
return
image_id = item.data(Qt.UserRole)
rel_path = item.data(Qt.UserRole + 1) or ""
if not image_id:
return
image_path = self._resolve_image_path_for_relative_path(rel_path)
if not image_path:
QMessageBox.warning(
self,
"Image Not Found",
"Unable to locate image on disk for:\n"
f"{rel_path}\n\n"
"Tip: set Settings → Image repository path to the folder containing your images.",
)
return
try:
self.current_image = Image(image_path)
self.current_image_path = image_path
self.current_image_id = int(image_id)
self.annotation_canvas.load_image(self.current_image)
self._load_annotations_for_current_image()
self._update_image_info()
except ImageLoadError as exc:
logger.error(f"Failed to load image '{image_path}': {exc}")
QMessageBox.critical(self, "Error Loading Image", f"Failed to load image:\n{exc}")
except Exception as exc:
logger.error(f"Unexpected error loading image '{image_path}': {exc}")
QMessageBox.critical(self, "Error", f"Unexpected error:\n{exc}")
def _resolve_image_path_for_relative_path(self, relative_path: str) -> str | None:
"""Best-effort conversion from a DB relative_path to an on-disk file path."""
rel = (relative_path or "").strip()
if not rel:
return None
candidates: list[Path] = []
# 1) Repository root + relative
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
if repo_root:
candidates.append(Path(repo_root) / rel)
# 2) If the DB path is absolute, try it directly.
candidates.append(Path(rel))
# 3) Try the directory of the currently loaded image (helps when DB stores only filenames)
if self.current_image_path:
try:
candidates.append(Path(self.current_image_path).expanduser().resolve().parent / Path(rel).name)
except Exception:
pass
# 4) Try the last directory used by the annotation file picker
try:
settings = QSettings("microscopy_app", "object_detection")
last_dir = settings.value("annotation_tab/last_directory", None)
if last_dir:
candidates.append(Path(str(last_dir)) / Path(rel).name)
except Exception:
pass
# 4b) Try the last directory used by the image import picker
try:
settings = QSettings("microscopy_app", "object_detection")
last_import_dir = settings.value("annotation_tab/last_image_import_directory", None)
if last_import_dir:
candidates.append(Path(str(last_import_dir)) / Path(rel).name)
except Exception:
pass
for p in candidates:
try:
expanded = p.expanduser()
if expanded.exists() and expanded.is_file():
return str(expanded.resolve())
except Exception:
continue
# 5) Fallback: search by filename within repository root.
filename = Path(rel).name
if repo_root and filename:
root = Path(repo_root).expanduser()
try:
if root.exists():
for match in root.rglob(filename):
if match.is_file():
return str(match.resolve())
except Exception as exc:
logger.debug(f"Search for {filename} under {root} failed: {exc}")
return None

View File

@@ -40,6 +40,12 @@ class ResultsTab(QWidget):
self.db_manager = db_manager
self.config_manager = config_manager
# Pagination
self.page_size = 200
self.current_page = 0 # 0-based
self.total_runs = 0
self.total_pages = 0
self.detection_summary: List[Dict] = []
self.current_selection: Optional[Dict] = None
self.current_image: Optional[Image] = None
@@ -66,6 +72,20 @@ class ResultsTab(QWidget):
self.refresh_btn.clicked.connect(self.refresh)
controls_layout.addWidget(self.refresh_btn)
self.prev_page_btn = QPushButton("◀ Prev")
self.prev_page_btn.setToolTip("Previous page")
self.prev_page_btn.clicked.connect(self._prev_page)
controls_layout.addWidget(self.prev_page_btn)
self.next_page_btn = QPushButton("Next ▶")
self.next_page_btn.setToolTip("Next page")
self.next_page_btn.clicked.connect(self._next_page)
controls_layout.addWidget(self.next_page_btn)
self.page_label = QLabel("Page 1/1")
self.page_label.setMinimumWidth(100)
controls_layout.addWidget(self.page_label)
self.delete_all_btn = QPushButton("Delete All Detections")
self.delete_all_btn.setToolTip(
"Permanently delete ALL detections from the database.\n" "This cannot be undone."
@@ -183,6 +203,8 @@ class ResultsTab(QWidget):
def refresh(self):
"""Refresh the detection list and preview."""
# Reset to first page on refresh.
self.current_page = 0
self._load_detection_summary()
self._populate_results_table()
self.current_selection = None
@@ -193,57 +215,135 @@ class ResultsTab(QWidget):
if hasattr(self, "export_labels_btn"):
self.export_labels_btn.setEnabled(False)
self._update_pagination_controls()
def _prev_page(self):
"""Go to previous results page."""
if self.current_page <= 0:
return
self.current_page -= 1
self._load_detection_summary()
self._populate_results_table()
self._update_pagination_controls()
def _next_page(self):
"""Go to next results page."""
if self.total_pages and self.current_page >= (self.total_pages - 1):
return
self.current_page += 1
self._load_detection_summary()
self._populate_results_table()
self._update_pagination_controls()
def _update_pagination_controls(self):
"""Update pagination label/button enabled state."""
# Default state (safe)
total_pages = max(int(self.total_pages or 0), 1)
current_page = min(max(int(self.current_page or 0), 0), total_pages - 1)
self.current_page = current_page
self.page_label.setText(f"Page {current_page + 1}/{total_pages}")
self.prev_page_btn.setEnabled(current_page > 0)
self.next_page_btn.setEnabled(current_page < (total_pages - 1))
def _load_detection_summary(self):
"""Load latest detection summaries grouped by image + model."""
try:
detections = self.db_manager.get_detections(limit=500)
summary_map: Dict[tuple, Dict] = {}
# Prefer run summaries (supports zero-detection runs). Fall back to legacy
# detection aggregation if the DB/table isn't available.
try:
self.total_runs = int(self.db_manager.get_detection_run_total())
except Exception:
self.total_runs = 0
for det in detections:
key = (det["image_id"], det["model_id"])
metadata = det.get("metadata") or {}
entry = summary_map.setdefault(
key,
self.total_pages = (self.total_runs + self.page_size - 1) // self.page_size if self.total_runs > 0 else 1
offset = int(self.current_page) * int(self.page_size)
runs = self.db_manager.get_detection_run_summaries(
limit=int(self.page_size),
offset=offset,
)
summary: List[Dict] = []
for run in runs:
meta = run.get("metadata") or {}
classes_raw = run.get("classes")
classes = set()
if isinstance(classes_raw, str) and classes_raw.strip():
classes = {c.strip() for c in classes_raw.split(",") if c.strip()}
summary.append(
{
"image_id": det["image_id"],
"model_id": det["model_id"],
"image_path": det.get("image_path"),
"image_filename": det.get("image_filename") or det.get("image_path"),
"model_name": det.get("model_name", ""),
"model_version": det.get("model_version", ""),
"last_detected": det.get("detected_at"),
"count": 0,
"classes": set(),
"source_path": metadata.get("source_path"),
"repository_root": metadata.get("repository_root"),
},
"image_id": run.get("image_id"),
"model_id": run.get("model_id"),
"image_path": run.get("image_path"),
"image_filename": run.get("image_filename") or run.get("image_path"),
"model_name": run.get("model_name", ""),
"model_version": run.get("model_version", ""),
"last_detected": run.get("detected_at"),
"count": int(run.get("count") or 0),
"classes": classes,
"source_path": meta.get("source_path"),
"repository_root": meta.get("repository_root"),
}
)
entry["count"] += 1
if det.get("detected_at") and (
not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
):
entry["last_detected"] = det.get("detected_at")
if det.get("class_name"):
entry["classes"].add(det["class_name"])
if metadata.get("source_path") and not entry.get("source_path"):
entry["source_path"] = metadata.get("source_path")
if metadata.get("repository_root") and not entry.get("repository_root"):
entry["repository_root"] = metadata.get("repository_root")
self.detection_summary = sorted(
summary_map.values(),
key=lambda x: str(x.get("last_detected") or ""),
reverse=True,
)
self.detection_summary = summary
except Exception as e:
logger.error(f"Failed to load detection summary: {e}")
QMessageBox.critical(
self,
"Error",
f"Failed to load detection results:\n{str(e)}",
)
self.detection_summary = []
logger.error(f"Failed to load detection run summaries, falling back: {e}")
# Disable pagination if we can't page via detection_runs.
self.total_runs = 0
self.total_pages = 1
self.current_page = 0
try:
detections = self.db_manager.get_detections(limit=int(self.page_size))
summary_map: Dict[tuple, Dict] = {}
for det in detections:
key = (det["image_id"], det["model_id"])
metadata = det.get("metadata") or {}
entry = summary_map.setdefault(
key,
{
"image_id": det["image_id"],
"model_id": det["model_id"],
"image_path": det.get("image_path"),
"image_filename": det.get("image_filename") or det.get("image_path"),
"model_name": det.get("model_name", ""),
"model_version": det.get("model_version", ""),
"last_detected": det.get("detected_at"),
"count": 0,
"classes": set(),
"source_path": metadata.get("source_path"),
"repository_root": metadata.get("repository_root"),
},
)
entry["count"] += 1
if det.get("detected_at") and (
not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
):
entry["last_detected"] = det.get("detected_at")
if det.get("class_name"):
entry["classes"].add(det["class_name"])
if metadata.get("source_path") and not entry.get("source_path"):
entry["source_path"] = metadata.get("source_path")
if metadata.get("repository_root") and not entry.get("repository_root"):
entry["repository_root"] = metadata.get("repository_root")
self.detection_summary = sorted(
summary_map.values(),
key=lambda x: str(x.get("last_detected") or ""),
reverse=True,
)
except Exception as inner:
logger.error(f"Failed to load detection summary: {inner}")
QMessageBox.critical(
self,
"Error",
f"Failed to load detection results:\n{str(inner)}",
)
self.detection_summary = []
def _populate_results_table(self):
"""Populate the table widget with detection summaries."""

View File

@@ -905,9 +905,14 @@ class TrainingTab(QWidget):
if stats["registered_images"]:
message += f" {stats['registered_images']} image(s) had database-backed annotations."
if stats["missing_records"]:
message += (
f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
)
preserved = stats.get("preserved_existing_labels", 0)
if preserved:
message += (
f" {stats['missing_records']} image(s) had no database annotations; "
f"preserved {preserved} existing label file(s) (no overwrite)."
)
else:
message += f" {stats['missing_records']} image(s) had no database annotations; empty label files were written."
split_messages.append(message)
for msg in split_messages:
@@ -929,6 +934,7 @@ class TrainingTab(QWidget):
processed_images = 0
registered_images = 0
missing_records = 0
preserved_existing_labels = 0
total_annotations = 0
for image_file in images_dir.rglob("*"):
@@ -950,6 +956,12 @@ class TrainingTab(QWidget):
else:
missing_records += 1
# If the database has no entry for this image, do not overwrite an existing label file
# with an empty one (preserve any manually created labels on disk).
if not found and label_path.exists():
preserved_existing_labels += 1
continue
annotations_written = 0
with open(label_path, "w", encoding="utf-8") as handle:
for entry in annotation_entries:
@@ -979,6 +991,7 @@ class TrainingTab(QWidget):
"processed_images": processed_images,
"registered_images": registered_images,
"missing_records": missing_records,
"preserved_existing_labels": preserved_existing_labels,
"total_annotations": total_annotations,
}
@@ -1008,6 +1021,10 @@ class TrainingTab(QWidget):
resolved_image = image_path.resolve()
candidates: List[str] = []
# Some DBs store absolute paths in `images.relative_path`.
# Include the absolute resolved path as a lookup candidate.
candidates.append(resolved_image.as_posix())
for base in (dataset_root, images_dir):
try:
relative = resolved_image.relative_to(base.resolve()).as_posix()
@@ -1032,6 +1049,13 @@ class TrainingTab(QWidget):
return False, []
annotations = self.db_manager.get_annotations_for_image(image_row["id"]) or []
# Treat "found" as "has database-backed annotations".
# If the image exists in DB but has no annotations yet, we don't want to overwrite
# an existing label file on disk with an empty one.
if not annotations:
return False, []
yolo_entries: List[Dict[str, Any]] = []
for ann in annotations:

View File

@@ -84,25 +84,19 @@ class InferenceEngine:
# Save detections to database, replacing any previous results for this image/model
if save_to_db:
deleted_count = self.db_manager.delete_detections_for_image(
image_id, self.model_id
)
deleted_count = self.db_manager.delete_detections_for_image(image_id, self.model_id)
if detections:
detection_records = []
for det in detections:
# Use normalized bbox from detection
bbox_normalized = det[
"bbox_normalized"
] # [x_min, y_min, x_max, y_max]
bbox_normalized = det["bbox_normalized"] # [x_min, y_min, x_max, y_max]
metadata = {
"class_id": det["class_id"],
"source_path": str(Path(image_path).resolve()),
}
if repository_root:
metadata["repository_root"] = str(
Path(repository_root).resolve()
)
metadata["repository_root"] = str(Path(repository_root).resolve())
record = {
"image_id": image_id,
@@ -115,16 +109,27 @@ class InferenceEngine:
}
detection_records.append(record)
inserted_count = self.db_manager.add_detections_batch(
detection_records
)
logger.info(
f"Saved {inserted_count} detections to database (replaced {deleted_count})"
)
inserted_count = self.db_manager.add_detections_batch(detection_records)
logger.info(f"Saved {inserted_count} detections to database (replaced {deleted_count})")
else:
logger.info(
f"Detection run removed {deleted_count} stale entries but produced no new detections"
logger.info(f"Detection run removed {deleted_count} stale entries but produced no new detections")
# Always store a run summary so the Results tab can show zero-detection runs.
try:
run_metadata = {
"source_path": str(Path(image_path).resolve()),
}
if repository_root:
run_metadata["repository_root"] = str(Path(repository_root).resolve())
self.db_manager.upsert_detection_run(
image_id=image_id,
model_id=self.model_id,
count=len(detections),
metadata=run_metadata,
)
except Exception as exc:
# Non-fatal: detection records may still be present.
logger.warning(f"Failed to store detection run summary: {exc}")
return {
"success": True,
@@ -232,9 +237,7 @@ class InferenceEngine:
for det in detections:
# Get color for this class
class_name = det["class_name"]
color_hex = bbox_colors.get(
class_name, bbox_colors.get("default", "#00FF00")
)
color_hex = bbox_colors.get(class_name, bbox_colors.get("default", "#00FF00"))
color = self._hex_to_bgr(color_hex)
# Draw segmentation mask if available and requested
@@ -243,10 +246,7 @@ class InferenceEngine:
if mask_normalized and len(mask_normalized) > 0:
# Convert normalized coordinates to absolute pixels
mask_points = np.array(
[
[int(pt[0] * width), int(pt[1] * height)]
for pt in mask_normalized
],
[[int(pt[0] * width), int(pt[1] * height)] for pt in mask_normalized],
dtype=np.int32,
)
@@ -270,9 +270,7 @@ class InferenceEngine:
label = f"{class_name} {det['confidence']:.2f}"
# Draw label background
(label_w, label_h), baseline = cv2.getTextSize(
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
)
(label_w, label_h), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
cv2.rectangle(
img,
(x1, y1 - label_h - baseline - 5),

View File

@@ -186,7 +186,7 @@ class YOLOWrapper:
raise RuntimeError(f"Failed to load model from {self.model_path}")
prepared_source, cleanup_path = self._prepare_source(source)
imgsz = 1088
imgsz = 640 # 1088
try:
logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
results = self.model.predict(

View File

@@ -6,6 +6,9 @@ import cv2
import numpy as np
from pathlib import Path
from typing import Optional, Tuple, Union
from skimage.restoration import rolling_ball
from skimage.filters import threshold_otsu
from scipy.ndimage import median_filter, gaussian_filter
from src.utils.logger import get_logger
from src.utils.file_utils import validate_file_path, is_image_file
@@ -37,6 +40,8 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
a1[a1 > p999] = p999
a1 /= a1.max()
# print("Using get_pseudo_rgb")
if 1:
a2 = a1.copy()
a2 = a2**gamma
@@ -46,14 +51,75 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
p9999 = np.percentile(a3, 99.99)
a3[a3 > p9999] = p9999
a3 /= a3.max()
out = np.stack([a1, a2, a3], axis=0)
else:
out = np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
return out
# return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
out = np.stack([a1, a2, a3], axis=0)
# print(any(np.isnan(out).flatten()))
def _get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.3) -> np.ndarray:
"""
Convert a grayscale image to a pseudo-RGB image using a gamma correction.
Args:
arr: Input grayscale image as numpy array
Returns:
Pseudo-RGB image as numpy array
"""
if arr.ndim != 2:
raise ValueError("Input array must be a grayscale image with shape (H, W)")
radius = 80
# bg = rolling_ball(arr, radius=radius)
a1 = arr.copy().astype(np.float32)
a1 = a1.astype(np.float32)
# a1 -= bg
# a1[a1 < 0] = 0
# a1 -= np.percentile(a1, 2)
# a1[a1 < 0] = 0
p999 = np.percentile(a1, 99.99)
a1[a1 > p999] = p999
a1 /= a1.max()
print("Using get_pseudo_rgb")
if 1:
a2 = a1.copy()
_a2 = a2**gamma
thr = threshold_otsu(_a2)
mask = gaussian_filter((_a2 > thr).astype(np.float32), sigma=5)
mask[mask > 0.0001] = 1
mask[mask <= 0.0001] = 0
a2 *= mask
# bg2 = rolling_ball(a2, radius=radius)
# a2 -= bg2
a2 -= np.percentile(a2, 2)
a2[a2 < 0] = 0
a2 /= a2.max()
a3 = a1.copy()
p9999 = np.percentile(a3, 99.99)
a3[a3 > p9999] = p9999
a3 /= a3.max()
out = np.stack([a1, a2, _a2], axis=0)
else:
out = np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
return out
# return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# print(any(np.isnan(out).flatten()))
class ImageLoadError(Exception):
"""Exception raised when an image cannot be loaded."""
@@ -330,17 +396,20 @@ class Image:
return self._channels >= 3
def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None:
# print("Image.save", self.data.shape)
if self.channels == 1:
# print("Image.save grayscale")
if pseudo_rgb:
img = get_pseudo_rgb(self.data)
print("Image.save", img.shape)
# print("Image.save", img.shape)
else:
img = np.repeat(self.data, 3, axis=2)
# print("Image.save no pseudo", img.shape)
else:
raise NotImplementedError("Only grayscale images are supported for now.")
# print("Image.save imwrite", img.shape)
imwrite(path, data=img)
def __repr__(self) -> str:

View File

@@ -97,7 +97,7 @@ class UT:
class_index: int = 0,
):
"""Export rois to a file"""
with open(path / subfolder / f"{self.stem}.txt", "w") as f:
with open(path / subfolder / f"{PREFIX}-{self.stem}.txt", "w") as f:
for i, roi in enumerate(self.rois):
rc = roi.subpixel_coordinates
if rc is None:
@@ -129,8 +129,8 @@ class UT:
self.image = np.max(self.image[channel], axis=0)
print(self.image.shape)
print(path / subfolder / f"{self.stem}.tif")
with TiffWriter(path / subfolder / f"{self.stem}.tif") as tif:
print(path / subfolder / f"{PREFIX}_{self.stem}.tif")
with TiffWriter(path / subfolder / f"{PREFIX}-{self.stem}.tif") as tif:
tif.write(self.image)
@@ -145,11 +145,19 @@ if __name__ == "__main__":
action="store_false",
help="Source does not have labels, export only images",
)
parser.add_argument("--prefix", help="Prefix for output files")
args = parser.parse_args()
PREFIX = args.prefix
# print(args)
# aa
# for path in args.input:
# print(path)
# ut = UT(path, args.no_labels)
# ut.export_image(args.output, plane_mode="max projection", channel=0)
# ut.export_rois(args.output, class_index=0)
for path in args.input:
print("Path:", path)
if not args.no_labels:

View File

@@ -5,7 +5,7 @@ from tifffile import imread, imwrite
from shapely.geometry import LineString
from copy import deepcopy
from scipy.ndimage import zoom
from skimage.restoration import rolling_ball
# debug
from src.utils.image import Image
@@ -160,8 +160,11 @@ class YoloLabelReader:
class ImageSplitter:
def __init__(self, image_path: Path, label_path: Path):
def __init__(self, image_path: Path, label_path: Path, subtract_bg: bool = False):
self.image = imread(image_path)
if subtract_bg:
self.image = self.image - rolling_ball(self.image, radius=12)
self.image[self.image < 0] = 0
self.image_path = image_path
self.label_path = label_path
if not label_path.exists():
@@ -273,13 +276,14 @@ def main(args):
if args.output:
args.output.mkdir(exist_ok=True, parents=True)
(args.output / "images").mkdir(exist_ok=True)
(args.output / "images-zoomed").mkdir(exist_ok=True)
# (args.output / "images-zoomed").mkdir(exist_ok=True)
(args.output / "labels").mkdir(exist_ok=True)
for image_path in (args.input / "images").glob("*.tif"):
data = ImageSplitter(
image_path=image_path,
label_path=(args.input / "labels" / image_path.stem).with_suffix(".txt"),
subtract_bg=args.subtract_bg,
)
if args.split_around_label:
@@ -332,10 +336,20 @@ def main(args):
if labels is not None:
with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f:
print(
f"Writing {len(labels)} labels to {args.output / 'labels' / f'{image_path.stem}_{tile_reference}.txt'}"
)
for label in labels:
# label.offset_label(tile.shape[1], tile.shape[0])
f.write(label.to_string() + "\n")
# { debug
if debug:
print(label.to_string())
# } debug
# break
# break
if __name__ == "__main__":
import argparse
@@ -363,6 +377,7 @@ if __name__ == "__main__":
default=67,
help="Padding around the label when splitting around the label.",
)
parser.add_argument("-bg", "--subtract-bg", action="store_true", help="Subtract background from the image.")
args = parser.parse_args()
main(args)