Compare commits
8 Commits
506c74e53a
...
monkey-pat
| Author | SHA1 | Date | |
|---|---|---|---|
| 98bc89691b | |||
| 3c8247b3bc | |||
| d03ffdc4d0 | |||
| 8d30e6bb7a | |||
| f810fec4d8 | |||
| 9c8931e6f3 | |||
| 20578c1fdf | |||
| 2c494dac49 |
@@ -40,8 +40,15 @@ class DatabaseManager:
|
|||||||
|
|
||||||
conn = self.get_connection()
|
conn = self.get_connection()
|
||||||
try:
|
try:
|
||||||
# Check if annotations table needs migration
|
# Pre-schema migrations.
|
||||||
|
# These must run BEFORE executing schema.sql because schema.sql may
|
||||||
|
# contain CREATE INDEX statements referencing newly added columns.
|
||||||
|
#
|
||||||
|
# 1) Check if annotations table needs migration (may drop an old table)
|
||||||
self._migrate_annotations_table(conn)
|
self._migrate_annotations_table(conn)
|
||||||
|
# 2) Ensure images table has the required columns (e.g. 'source')
|
||||||
|
self._migrate_images_table(conn)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
# Read schema file and execute
|
# Read schema file and execute
|
||||||
schema_path = Path(__file__).parent / "schema.sql"
|
schema_path = Path(__file__).parent / "schema.sql"
|
||||||
@@ -53,6 +60,19 @@ class DatabaseManager:
|
|||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
def _migrate_images_table(self, conn: sqlite3.Connection) -> None:
|
||||||
|
"""Migrate images table to include the 'source' column if missing."""
|
||||||
|
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='images'")
|
||||||
|
if not cursor.fetchone():
|
||||||
|
return
|
||||||
|
|
||||||
|
cursor.execute("PRAGMA table_info(images)")
|
||||||
|
columns = {row[1] for row in cursor.fetchall()}
|
||||||
|
if "source" not in columns:
|
||||||
|
cursor.execute("ALTER TABLE images ADD COLUMN source TEXT")
|
||||||
|
|
||||||
def _migrate_annotations_table(self, conn: sqlite3.Connection) -> None:
|
def _migrate_annotations_table(self, conn: sqlite3.Connection) -> None:
|
||||||
"""
|
"""
|
||||||
Migrate annotations table from old schema (class_name) to new schema (class_id).
|
Migrate annotations table from old schema (class_name) to new schema (class_id).
|
||||||
@@ -85,6 +105,103 @@ class DatabaseManager:
|
|||||||
conn.execute("PRAGMA foreign_keys = ON") # Enable foreign keys
|
conn.execute("PRAGMA foreign_keys = ON") # Enable foreign keys
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
|
# ==================== Detection Run Operations ====================
|
||||||
|
|
||||||
|
def upsert_detection_run(
|
||||||
|
self,
|
||||||
|
image_id: int,
|
||||||
|
model_id: int,
|
||||||
|
count: int,
|
||||||
|
metadata: Optional[Dict] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Insert/update a per-image per-model detection run summary.
|
||||||
|
|
||||||
|
This enables the UI to show runs even when zero detections were produced.
|
||||||
|
"""
|
||||||
|
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO detection_runs (image_id, model_id, detected_at, count, metadata)
|
||||||
|
VALUES (?, ?, CURRENT_TIMESTAMP, ?, ?)
|
||||||
|
ON CONFLICT(image_id, model_id) DO UPDATE SET
|
||||||
|
detected_at = CURRENT_TIMESTAMP,
|
||||||
|
count = excluded.count,
|
||||||
|
metadata = excluded.metadata
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
int(image_id),
|
||||||
|
int(model_id),
|
||||||
|
int(count),
|
||||||
|
json.dumps(metadata) if metadata else None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
return True
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def get_detection_run_summaries(self, limit: int = 500, offset: int = 0) -> List[Dict]:
|
||||||
|
"""Return latest detection run summaries grouped by image+model.
|
||||||
|
|
||||||
|
Includes runs with 0 detections.
|
||||||
|
"""
|
||||||
|
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT
|
||||||
|
dr.image_id,
|
||||||
|
dr.model_id,
|
||||||
|
dr.detected_at,
|
||||||
|
dr.count,
|
||||||
|
dr.metadata,
|
||||||
|
i.relative_path AS image_path,
|
||||||
|
i.filename AS image_filename,
|
||||||
|
m.model_name,
|
||||||
|
m.model_version,
|
||||||
|
GROUP_CONCAT(DISTINCT d.class_name) AS classes
|
||||||
|
FROM detection_runs dr
|
||||||
|
JOIN images i ON dr.image_id = i.id
|
||||||
|
JOIN models m ON dr.model_id = m.id
|
||||||
|
LEFT JOIN detections d
|
||||||
|
ON d.image_id = dr.image_id AND d.model_id = dr.model_id
|
||||||
|
GROUP BY dr.image_id, dr.model_id
|
||||||
|
ORDER BY dr.detected_at DESC
|
||||||
|
LIMIT ? OFFSET ?
|
||||||
|
""",
|
||||||
|
(int(limit), int(offset)),
|
||||||
|
)
|
||||||
|
|
||||||
|
rows: List[Dict] = []
|
||||||
|
for row in cursor.fetchall():
|
||||||
|
item = dict(row)
|
||||||
|
if item.get("metadata"):
|
||||||
|
try:
|
||||||
|
item["metadata"] = json.loads(item["metadata"])
|
||||||
|
except Exception:
|
||||||
|
item["metadata"] = None
|
||||||
|
rows.append(item)
|
||||||
|
return rows
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def get_detection_run_total(self) -> int:
|
||||||
|
"""Return total number of detection_runs rows."""
|
||||||
|
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT COUNT(*) AS cnt FROM detection_runs")
|
||||||
|
row = cursor.fetchone()
|
||||||
|
return int(row["cnt"] if row and row["cnt"] is not None else 0)
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
# ==================== Model Operations ====================
|
# ==================== Model Operations ====================
|
||||||
|
|
||||||
def add_model(
|
def add_model(
|
||||||
@@ -201,6 +318,28 @@ class DatabaseManager:
|
|||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
def delete_model(self, model_id: int) -> bool:
|
||||||
|
"""Delete a model from the database.
|
||||||
|
|
||||||
|
Note: detections referencing this model are deleted automatically via
|
||||||
|
the `detections.model_id` foreign key (ON DELETE CASCADE).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: ID of the model to delete.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if a model row was deleted, False otherwise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("DELETE FROM models WHERE id = ?", (model_id,))
|
||||||
|
conn.commit()
|
||||||
|
return cursor.rowcount > 0
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
# ==================== Image Operations ====================
|
# ==================== Image Operations ====================
|
||||||
|
|
||||||
def add_image(
|
def add_image(
|
||||||
@@ -211,6 +350,7 @@ class DatabaseManager:
|
|||||||
height: int,
|
height: int,
|
||||||
captured_at: Optional[datetime] = None,
|
captured_at: Optional[datetime] = None,
|
||||||
checksum: Optional[str] = None,
|
checksum: Optional[str] = None,
|
||||||
|
source: Optional[str] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Add a new image to the database.
|
Add a new image to the database.
|
||||||
@@ -231,10 +371,10 @@ class DatabaseManager:
|
|||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO images (relative_path, filename, width, height, captured_at, checksum)
|
INSERT INTO images (relative_path, filename, width, height, captured_at, checksum, source)
|
||||||
VALUES (?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
""",
|
""",
|
||||||
(relative_path, filename, width, height, captured_at, checksum),
|
(relative_path, filename, width, height, captured_at, checksum, source),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return cursor.lastrowid
|
return cursor.lastrowid
|
||||||
@@ -264,6 +404,18 @@ class DatabaseManager:
|
|||||||
return existing["id"]
|
return existing["id"]
|
||||||
return self.add_image(relative_path, filename, width, height)
|
return self.add_image(relative_path, filename, width, height)
|
||||||
|
|
||||||
|
def set_image_source(self, image_id: int, source: Optional[str]) -> bool:
|
||||||
|
"""Set/update the source marker for an image row."""
|
||||||
|
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("UPDATE images SET source = ? WHERE id = ?", (source, int(image_id)))
|
||||||
|
conn.commit()
|
||||||
|
return cursor.rowcount > 0
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
# ==================== Detection Operations ====================
|
# ==================== Detection Operations ====================
|
||||||
|
|
||||||
def add_detection(
|
def add_detection(
|
||||||
@@ -462,6 +614,30 @@ class DatabaseManager:
|
|||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
def delete_all_detections(self) -> int:
|
||||||
|
"""Delete all detections from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of rows deleted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
# Also clear detection run summaries so the Results tab does not continue
|
||||||
|
# to show historical runs after detections have been wiped.
|
||||||
|
try:
|
||||||
|
cursor.execute("DELETE FROM detection_runs")
|
||||||
|
except sqlite3.OperationalError:
|
||||||
|
# Backwards-compatible: table may not exist on older DB files.
|
||||||
|
pass
|
||||||
|
|
||||||
|
cursor.execute("DELETE FROM detections")
|
||||||
|
conn.commit()
|
||||||
|
return cursor.rowcount
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
# ==================== Statistics Operations ====================
|
# ==================== Statistics Operations ====================
|
||||||
|
|
||||||
def get_detection_statistics(
|
def get_detection_statistics(
|
||||||
@@ -620,6 +796,153 @@ class DatabaseManager:
|
|||||||
|
|
||||||
# ==================== Annotation Operations ====================
|
# ==================== Annotation Operations ====================
|
||||||
|
|
||||||
|
def get_images_summary(
|
||||||
|
self,
|
||||||
|
name_filter: Optional[str] = None,
|
||||||
|
source_filter: Optional[str] = None,
|
||||||
|
order_by: str = "filename",
|
||||||
|
order_dir: str = "ASC",
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""Return all images with annotation counts (including zero).
|
||||||
|
|
||||||
|
This is used by the Annotation tab to populate the image list even when
|
||||||
|
no annotations exist yet.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name_filter: Optional substring filter applied to filename/relative_path.
|
||||||
|
order_by: One of: 'filename', 'relative_path', 'annotation_count', 'added_at'.
|
||||||
|
order_dir: 'ASC' or 'DESC'.
|
||||||
|
limit: Optional max number of rows.
|
||||||
|
offset: Pagination offset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dicts: {id, relative_path, filename, added_at, annotation_count}
|
||||||
|
"""
|
||||||
|
|
||||||
|
allowed_order_by = {
|
||||||
|
"filename": "i.filename",
|
||||||
|
"relative_path": "i.relative_path",
|
||||||
|
"annotation_count": "annotation_count",
|
||||||
|
"added_at": "i.added_at",
|
||||||
|
}
|
||||||
|
order_expr = allowed_order_by.get(order_by, "i.filename")
|
||||||
|
dir_norm = str(order_dir).upper().strip()
|
||||||
|
if dir_norm not in {"ASC", "DESC"}:
|
||||||
|
dir_norm = "ASC"
|
||||||
|
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
params: List[Any] = []
|
||||||
|
where_clauses: List[str] = []
|
||||||
|
if name_filter:
|
||||||
|
token = f"%{name_filter}%"
|
||||||
|
where_clauses.append("(i.filename LIKE ? OR i.relative_path LIKE ?)")
|
||||||
|
params.extend([token, token])
|
||||||
|
if source_filter:
|
||||||
|
where_clauses.append("i.source = ?")
|
||||||
|
params.append(source_filter)
|
||||||
|
|
||||||
|
where_sql = ""
|
||||||
|
if where_clauses:
|
||||||
|
where_sql = "WHERE " + " AND ".join(where_clauses)
|
||||||
|
|
||||||
|
limit_sql = ""
|
||||||
|
if limit is not None:
|
||||||
|
limit_sql = " LIMIT ? OFFSET ?"
|
||||||
|
params.extend([int(limit), int(offset)])
|
||||||
|
|
||||||
|
query = f"""
|
||||||
|
SELECT
|
||||||
|
i.id,
|
||||||
|
i.relative_path,
|
||||||
|
i.filename,
|
||||||
|
i.added_at,
|
||||||
|
COUNT(a.id) AS annotation_count
|
||||||
|
FROM images i
|
||||||
|
LEFT JOIN annotations a ON a.image_id = i.id
|
||||||
|
{where_sql}
|
||||||
|
GROUP BY i.id
|
||||||
|
ORDER BY {order_expr} {dir_norm}
|
||||||
|
{limit_sql}
|
||||||
|
"""
|
||||||
|
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(query, params)
|
||||||
|
return [dict(row) for row in cursor.fetchall()]
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def get_annotated_images_summary(
|
||||||
|
self,
|
||||||
|
name_filter: Optional[str] = None,
|
||||||
|
order_by: str = "filename",
|
||||||
|
order_dir: str = "ASC",
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""Return images that have at least one manual annotation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name_filter: Optional substring filter applied to filename/relative_path.
|
||||||
|
order_by: One of: 'filename', 'relative_path', 'annotation_count', 'added_at'.
|
||||||
|
order_dir: 'ASC' or 'DESC'.
|
||||||
|
limit: Optional max number of rows.
|
||||||
|
offset: Pagination offset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dicts: {id, relative_path, filename, added_at, annotation_count}
|
||||||
|
"""
|
||||||
|
|
||||||
|
allowed_order_by = {
|
||||||
|
"filename": "i.filename",
|
||||||
|
"relative_path": "i.relative_path",
|
||||||
|
"annotation_count": "annotation_count",
|
||||||
|
"added_at": "i.added_at",
|
||||||
|
}
|
||||||
|
order_expr = allowed_order_by.get(order_by, "i.filename")
|
||||||
|
dir_norm = str(order_dir).upper().strip()
|
||||||
|
if dir_norm not in {"ASC", "DESC"}:
|
||||||
|
dir_norm = "ASC"
|
||||||
|
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
params: List[Any] = []
|
||||||
|
where_sql = ""
|
||||||
|
if name_filter:
|
||||||
|
# Case-insensitive substring search.
|
||||||
|
token = f"%{name_filter}%"
|
||||||
|
where_sql = "WHERE (i.filename LIKE ? OR i.relative_path LIKE ?)"
|
||||||
|
params.extend([token, token])
|
||||||
|
|
||||||
|
limit_sql = ""
|
||||||
|
if limit is not None:
|
||||||
|
limit_sql = " LIMIT ? OFFSET ?"
|
||||||
|
params.extend([int(limit), int(offset)])
|
||||||
|
|
||||||
|
query = f"""
|
||||||
|
SELECT
|
||||||
|
i.id,
|
||||||
|
i.relative_path,
|
||||||
|
i.filename,
|
||||||
|
i.added_at,
|
||||||
|
COUNT(a.id) AS annotation_count
|
||||||
|
FROM images i
|
||||||
|
JOIN annotations a ON a.image_id = i.id
|
||||||
|
{where_sql}
|
||||||
|
GROUP BY i.id
|
||||||
|
HAVING annotation_count > 0
|
||||||
|
ORDER BY {order_expr} {dir_norm}
|
||||||
|
{limit_sql}
|
||||||
|
"""
|
||||||
|
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(query, params)
|
||||||
|
return [dict(row) for row in cursor.fetchall()]
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
def add_annotation(
|
def add_annotation(
|
||||||
self,
|
self,
|
||||||
image_id: int,
|
image_id: int,
|
||||||
@@ -725,6 +1048,27 @@ class DatabaseManager:
|
|||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
def delete_annotations_for_image(self, image_id: int) -> int:
|
||||||
|
"""Delete ALL annotations for a specific image.
|
||||||
|
|
||||||
|
This is primarily used for import/overwrite workflows.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_id: ID of the image whose annotations should be deleted.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of rows deleted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("DELETE FROM annotations WHERE image_id = ?", (int(image_id),))
|
||||||
|
conn.commit()
|
||||||
|
return int(cursor.rowcount or 0)
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
# ==================== Object Class Operations ====================
|
# ==================== Object Class Operations ====================
|
||||||
|
|
||||||
def get_object_classes(self) -> List[Dict]:
|
def get_object_classes(self) -> List[Dict]:
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ CREATE TABLE IF NOT EXISTS models (
|
|||||||
model_name TEXT NOT NULL,
|
model_name TEXT NOT NULL,
|
||||||
model_version TEXT NOT NULL,
|
model_version TEXT NOT NULL,
|
||||||
model_path TEXT NOT NULL,
|
model_path TEXT NOT NULL,
|
||||||
base_model TEXT NOT NULL DEFAULT 'yolov8s.pt',
|
base_model TEXT NOT NULL DEFAULT 'yolo11s.pt',
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
training_params TEXT, -- JSON string of training parameters
|
training_params TEXT, -- JSON string of training parameters
|
||||||
metrics TEXT, -- JSON string of validation metrics
|
metrics TEXT, -- JSON string of validation metrics
|
||||||
@@ -23,7 +23,8 @@ CREATE TABLE IF NOT EXISTS images (
|
|||||||
height INTEGER,
|
height INTEGER,
|
||||||
captured_at TIMESTAMP,
|
captured_at TIMESTAMP,
|
||||||
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
checksum TEXT
|
checksum TEXT,
|
||||||
|
source TEXT
|
||||||
);
|
);
|
||||||
|
|
||||||
-- Detections table: stores detection results
|
-- Detections table: stores detection results
|
||||||
@@ -44,6 +45,19 @@ CREATE TABLE IF NOT EXISTS detections (
|
|||||||
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
|
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
|
||||||
);
|
);
|
||||||
|
|
||||||
|
-- Detection runs table: stores per-image per-model run summaries (including 0 detections)
|
||||||
|
CREATE TABLE IF NOT EXISTS detection_runs (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
image_id INTEGER NOT NULL,
|
||||||
|
model_id INTEGER NOT NULL,
|
||||||
|
detected_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
metadata TEXT,
|
||||||
|
UNIQUE(image_id, model_id),
|
||||||
|
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
|
||||||
|
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
-- Object classes table: stores annotation class definitions with colors
|
-- Object classes table: stores annotation class definitions with colors
|
||||||
CREATE TABLE IF NOT EXISTS object_classes (
|
CREATE TABLE IF NOT EXISTS object_classes (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
@@ -55,10 +69,7 @@ CREATE TABLE IF NOT EXISTS object_classes (
|
|||||||
|
|
||||||
-- Insert default object classes
|
-- Insert default object classes
|
||||||
INSERT OR IGNORE INTO object_classes (class_name, color, description) VALUES
|
INSERT OR IGNORE INTO object_classes (class_name, color, description) VALUES
|
||||||
('cell', '#FF0000', 'Cell object'),
|
('terminal', '#FFFF00', 'Axion terminal');
|
||||||
('nucleus', '#00FF00', 'Cell nucleus'),
|
|
||||||
('mitochondria', '#0000FF', 'Mitochondria'),
|
|
||||||
('vesicle', '#FFFF00', 'Vesicle');
|
|
||||||
|
|
||||||
-- Annotations table: stores manual annotations
|
-- Annotations table: stores manual annotations
|
||||||
CREATE TABLE IF NOT EXISTS annotations (
|
CREATE TABLE IF NOT EXISTS annotations (
|
||||||
@@ -83,9 +94,13 @@ CREATE INDEX IF NOT EXISTS idx_detections_model_id ON detections(model_id);
|
|||||||
CREATE INDEX IF NOT EXISTS idx_detections_class_name ON detections(class_name);
|
CREATE INDEX IF NOT EXISTS idx_detections_class_name ON detections(class_name);
|
||||||
CREATE INDEX IF NOT EXISTS idx_detections_detected_at ON detections(detected_at);
|
CREATE INDEX IF NOT EXISTS idx_detections_detected_at ON detections(detected_at);
|
||||||
CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence);
|
CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_detection_runs_image_id ON detection_runs(image_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_detection_runs_model_id ON detection_runs(model_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_detection_runs_detected_at ON detection_runs(detected_at);
|
||||||
CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path);
|
CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path);
|
||||||
CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at);
|
CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_images_source ON images(source);
|
||||||
CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id);
|
CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_annotations_class_id ON annotations(class_id);
|
CREATE INDEX IF NOT EXISTS idx_annotations_class_id ON annotations(class_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at);
|
CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at);
|
||||||
CREATE INDEX IF NOT EXISTS idx_object_classes_class_name ON object_classes(class_name);
|
CREATE INDEX IF NOT EXISTS idx_object_classes_class_name ON object_classes(class_name);
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""Main window for the microscopy object detection application."""
|
||||||
Main window for the microscopy object detection application.
|
|
||||||
"""
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from PySide6.QtWidgets import (
|
from PySide6.QtWidgets import (
|
||||||
QMainWindow,
|
QMainWindow,
|
||||||
@@ -20,6 +21,7 @@ from src.database.db_manager import DatabaseManager
|
|||||||
from src.utils.config_manager import ConfigManager
|
from src.utils.config_manager import ConfigManager
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
from src.gui.dialogs.config_dialog import ConfigDialog
|
from src.gui.dialogs.config_dialog import ConfigDialog
|
||||||
|
from src.gui.dialogs.delete_model_dialog import DeleteModelDialog
|
||||||
from src.gui.tabs.detection_tab import DetectionTab
|
from src.gui.tabs.detection_tab import DetectionTab
|
||||||
from src.gui.tabs.training_tab import TrainingTab
|
from src.gui.tabs.training_tab import TrainingTab
|
||||||
from src.gui.tabs.validation_tab import ValidationTab
|
from src.gui.tabs.validation_tab import ValidationTab
|
||||||
@@ -91,6 +93,12 @@ class MainWindow(QMainWindow):
|
|||||||
db_stats_action.triggered.connect(self._show_database_stats)
|
db_stats_action.triggered.connect(self._show_database_stats)
|
||||||
tools_menu.addAction(db_stats_action)
|
tools_menu.addAction(db_stats_action)
|
||||||
|
|
||||||
|
tools_menu.addSeparator()
|
||||||
|
|
||||||
|
delete_model_action = QAction("Delete &Model…", self)
|
||||||
|
delete_model_action.triggered.connect(self._show_delete_model_dialog)
|
||||||
|
tools_menu.addAction(delete_model_action)
|
||||||
|
|
||||||
# Help menu
|
# Help menu
|
||||||
help_menu = menubar.addMenu("&Help")
|
help_menu = menubar.addMenu("&Help")
|
||||||
|
|
||||||
@@ -117,10 +125,10 @@ class MainWindow(QMainWindow):
|
|||||||
|
|
||||||
# Add tabs to widget
|
# Add tabs to widget
|
||||||
self.tab_widget.addTab(self.detection_tab, "Detection")
|
self.tab_widget.addTab(self.detection_tab, "Detection")
|
||||||
|
self.tab_widget.addTab(self.results_tab, "Results")
|
||||||
|
self.tab_widget.addTab(self.annotation_tab, "Annotation")
|
||||||
self.tab_widget.addTab(self.training_tab, "Training")
|
self.tab_widget.addTab(self.training_tab, "Training")
|
||||||
self.tab_widget.addTab(self.validation_tab, "Validation")
|
self.tab_widget.addTab(self.validation_tab, "Validation")
|
||||||
self.tab_widget.addTab(self.results_tab, "Results")
|
|
||||||
self.tab_widget.addTab(self.annotation_tab, "Annotation (Future)")
|
|
||||||
|
|
||||||
# Connect tab change signal
|
# Connect tab change signal
|
||||||
self.tab_widget.currentChanged.connect(self._on_tab_changed)
|
self.tab_widget.currentChanged.connect(self._on_tab_changed)
|
||||||
@@ -152,9 +160,7 @@ class MainWindow(QMainWindow):
|
|||||||
"""Center window on screen."""
|
"""Center window on screen."""
|
||||||
screen = self.screen().geometry()
|
screen = self.screen().geometry()
|
||||||
size = self.geometry()
|
size = self.geometry()
|
||||||
self.move(
|
self.move((screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2)
|
||||||
(screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2
|
|
||||||
)
|
|
||||||
|
|
||||||
def _restore_window_state(self):
|
def _restore_window_state(self):
|
||||||
"""Restore window geometry from settings or center window."""
|
"""Restore window geometry from settings or center window."""
|
||||||
@@ -193,6 +199,10 @@ class MainWindow(QMainWindow):
|
|||||||
self.training_tab.refresh()
|
self.training_tab.refresh()
|
||||||
if hasattr(self, "results_tab"):
|
if hasattr(self, "results_tab"):
|
||||||
self.results_tab.refresh()
|
self.results_tab.refresh()
|
||||||
|
if hasattr(self, "annotation_tab"):
|
||||||
|
self.annotation_tab.refresh()
|
||||||
|
if hasattr(self, "validation_tab"):
|
||||||
|
self.validation_tab.refresh()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error applying settings: {e}")
|
logger.error(f"Error applying settings: {e}")
|
||||||
|
|
||||||
@@ -209,6 +219,14 @@ class MainWindow(QMainWindow):
|
|||||||
logger.debug(f"Switched to tab: {tab_name}")
|
logger.debug(f"Switched to tab: {tab_name}")
|
||||||
self._update_status(f"Viewing: {tab_name}")
|
self._update_status(f"Viewing: {tab_name}")
|
||||||
|
|
||||||
|
# Ensure the Annotation tab always shows up-to-date DB-backed lists.
|
||||||
|
try:
|
||||||
|
current_widget = self.tab_widget.widget(index)
|
||||||
|
if hasattr(self, "annotation_tab") and current_widget is self.annotation_tab:
|
||||||
|
self.annotation_tab.refresh()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug(f"Failed to refresh annotation tab on selection: {exc}")
|
||||||
|
|
||||||
def _show_database_stats(self):
|
def _show_database_stats(self):
|
||||||
"""Show database statistics dialog."""
|
"""Show database statistics dialog."""
|
||||||
try:
|
try:
|
||||||
@@ -231,9 +249,229 @@ class MainWindow(QMainWindow):
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting database stats: {e}")
|
logger.error(f"Error getting database stats: {e}")
|
||||||
QMessageBox.warning(
|
QMessageBox.warning(self, "Error", f"Failed to get database statistics:\n{str(e)}")
|
||||||
self, "Error", f"Failed to get database statistics:\n{str(e)}"
|
|
||||||
)
|
def _show_delete_model_dialog(self) -> None:
|
||||||
|
"""Open the model deletion dialog."""
|
||||||
|
dialog = DeleteModelDialog(self.db_manager, self)
|
||||||
|
if not dialog.exec():
|
||||||
|
return
|
||||||
|
|
||||||
|
model_ids = dialog.selected_model_ids
|
||||||
|
if not model_ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._delete_models(model_ids)
|
||||||
|
|
||||||
|
def _delete_models(self, model_ids: list[int]) -> None:
|
||||||
|
"""Delete one or more models from the database and remove artifacts from disk."""
|
||||||
|
|
||||||
|
deleted_count = 0
|
||||||
|
removed_paths: list[str] = []
|
||||||
|
remove_errors: list[str] = []
|
||||||
|
|
||||||
|
for model_id in model_ids:
|
||||||
|
model = None
|
||||||
|
try:
|
||||||
|
model = self.db_manager.get_model_by_id(int(model_id))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Failed to load model {model_id} before deletion: {exc}")
|
||||||
|
|
||||||
|
if not model:
|
||||||
|
remove_errors.append(f"Model id {model_id} not found in database.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
deleted = self.db_manager.delete_model(int(model_id))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Failed to delete model {model_id}: {exc}")
|
||||||
|
remove_errors.append(f"Failed to delete model id {model_id} from DB: {exc}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not deleted:
|
||||||
|
remove_errors.append(f"Model id {model_id} was not deleted (already removed?).")
|
||||||
|
continue
|
||||||
|
|
||||||
|
deleted_count += 1
|
||||||
|
removed, errors = self._delete_model_artifacts_from_disk(model)
|
||||||
|
removed_paths.extend(removed)
|
||||||
|
remove_errors.extend(errors)
|
||||||
|
|
||||||
|
# Refresh tabs to reflect the deletion(s).
|
||||||
|
try:
|
||||||
|
if hasattr(self, "detection_tab"):
|
||||||
|
self.detection_tab.refresh()
|
||||||
|
if hasattr(self, "results_tab"):
|
||||||
|
self.results_tab.refresh()
|
||||||
|
if hasattr(self, "validation_tab"):
|
||||||
|
self.validation_tab.refresh()
|
||||||
|
if hasattr(self, "training_tab"):
|
||||||
|
self.training_tab.refresh()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"Failed to refresh tabs after model deletion: {exc}")
|
||||||
|
|
||||||
|
details: list[str] = []
|
||||||
|
if removed_paths:
|
||||||
|
details.append("Removed from disk:\n" + "\n".join(removed_paths))
|
||||||
|
if remove_errors:
|
||||||
|
details.append("\nDisk cleanup warnings:\n" + "\n".join(remove_errors))
|
||||||
|
|
||||||
|
QMessageBox.information(
|
||||||
|
self,
|
||||||
|
"Delete Model",
|
||||||
|
f"Deleted {deleted_count} model(s) from database." + ("\n\n" + "\n".join(details) if details else ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _delete_model(self, model_id: int) -> None:
|
||||||
|
"""Delete a model from the database and remove its artifacts from disk."""
|
||||||
|
|
||||||
|
model = None
|
||||||
|
try:
|
||||||
|
model = self.db_manager.get_model_by_id(model_id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Failed to load model {model_id} before deletion: {exc}")
|
||||||
|
|
||||||
|
if not model:
|
||||||
|
QMessageBox.warning(self, "Delete Model", "Selected model was not found in the database.")
|
||||||
|
return
|
||||||
|
|
||||||
|
model_path = str(model.get("model_path") or "")
|
||||||
|
|
||||||
|
try:
|
||||||
|
deleted = self.db_manager.delete_model(model_id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Failed to delete model {model_id}: {exc}")
|
||||||
|
QMessageBox.critical(self, "Delete Model", f"Failed to delete model from database:\n{exc}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not deleted:
|
||||||
|
QMessageBox.warning(self, "Delete Model", "No model was deleted (it may have already been removed).")
|
||||||
|
return
|
||||||
|
|
||||||
|
removed_paths, remove_errors = self._delete_model_artifacts_from_disk(model)
|
||||||
|
|
||||||
|
# Refresh tabs to reflect the deletion.
|
||||||
|
try:
|
||||||
|
if hasattr(self, "detection_tab"):
|
||||||
|
self.detection_tab.refresh()
|
||||||
|
if hasattr(self, "results_tab"):
|
||||||
|
self.results_tab.refresh()
|
||||||
|
if hasattr(self, "validation_tab"):
|
||||||
|
self.validation_tab.refresh()
|
||||||
|
if hasattr(self, "training_tab"):
|
||||||
|
self.training_tab.refresh()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"Failed to refresh tabs after model deletion: {exc}")
|
||||||
|
|
||||||
|
details = []
|
||||||
|
if model_path:
|
||||||
|
details.append(f"Deleted model record for: {model_path}")
|
||||||
|
if removed_paths:
|
||||||
|
details.append("\nRemoved from disk:\n" + "\n".join(removed_paths))
|
||||||
|
if remove_errors:
|
||||||
|
details.append("\nDisk cleanup warnings:\n" + "\n".join(remove_errors))
|
||||||
|
|
||||||
|
QMessageBox.information(
|
||||||
|
self,
|
||||||
|
"Delete Model",
|
||||||
|
"Model deleted from database." + ("\n\n" + "\n".join(details) if details else ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _delete_model_artifacts_from_disk(self, model: dict) -> tuple[list[str], list[str]]:
|
||||||
|
"""Best-effort removal of model artifacts on disk.
|
||||||
|
|
||||||
|
Strategy:
|
||||||
|
- Remove run directories inferred from:
|
||||||
|
- model.model_path (…/<run>/weights/*.pt => <run>)
|
||||||
|
- training_params.stage_results[].results.save_dir
|
||||||
|
but only if they are under the configured models directory.
|
||||||
|
- If the weights file itself exists and is outside the models directory, delete only the file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(removed_paths, errors)
|
||||||
|
"""
|
||||||
|
|
||||||
|
removed: list[str] = []
|
||||||
|
errors: list[str] = []
|
||||||
|
|
||||||
|
models_root = Path(self.config_manager.get_models_directory() or "data/models").expanduser()
|
||||||
|
try:
|
||||||
|
models_root_resolved = models_root.resolve()
|
||||||
|
except Exception:
|
||||||
|
models_root_resolved = models_root
|
||||||
|
|
||||||
|
inferred_dirs: list[Path] = []
|
||||||
|
|
||||||
|
# 1) From model_path
|
||||||
|
model_path_value = model.get("model_path")
|
||||||
|
if model_path_value:
|
||||||
|
try:
|
||||||
|
p = Path(str(model_path_value)).expanduser()
|
||||||
|
p_resolved = p.resolve() if p.exists() else p
|
||||||
|
if p_resolved.is_file():
|
||||||
|
if p_resolved.parent.name == "weights" and p_resolved.parent.parent.exists():
|
||||||
|
inferred_dirs.append(p_resolved.parent.parent)
|
||||||
|
elif p_resolved.parent.exists():
|
||||||
|
inferred_dirs.append(p_resolved.parent)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 2) From training_params.stage_results[].results.save_dir
|
||||||
|
training_params = model.get("training_params") or {}
|
||||||
|
if isinstance(training_params, dict):
|
||||||
|
stage_results = training_params.get("stage_results")
|
||||||
|
if isinstance(stage_results, list):
|
||||||
|
for stage in stage_results:
|
||||||
|
results = (stage or {}).get("results")
|
||||||
|
save_dir = (results or {}).get("save_dir") if isinstance(results, dict) else None
|
||||||
|
if not save_dir:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
d = Path(str(save_dir)).expanduser()
|
||||||
|
if d.exists() and d.is_dir():
|
||||||
|
inferred_dirs.append(d)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Deduplicate inferred_dirs
|
||||||
|
unique_dirs: list[Path] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for d in inferred_dirs:
|
||||||
|
try:
|
||||||
|
key = str(d.resolve())
|
||||||
|
except Exception:
|
||||||
|
key = str(d)
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
unique_dirs.append(d)
|
||||||
|
|
||||||
|
# Delete directories under models_root
|
||||||
|
for d in unique_dirs:
|
||||||
|
try:
|
||||||
|
d_resolved = d.resolve()
|
||||||
|
except Exception:
|
||||||
|
d_resolved = d
|
||||||
|
try:
|
||||||
|
if d_resolved.exists() and d_resolved.is_dir() and d_resolved.is_relative_to(models_root_resolved):
|
||||||
|
shutil.rmtree(d_resolved)
|
||||||
|
removed.append(str(d_resolved))
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Failed to remove directory {d_resolved}: {exc}")
|
||||||
|
|
||||||
|
# If nothing matched (e.g., model_path outside models_root), delete just the file.
|
||||||
|
if model_path_value:
|
||||||
|
try:
|
||||||
|
p = Path(str(model_path_value)).expanduser()
|
||||||
|
if p.exists() and p.is_file():
|
||||||
|
p_resolved = p.resolve()
|
||||||
|
if not p_resolved.is_relative_to(models_root_resolved):
|
||||||
|
p_resolved.unlink()
|
||||||
|
removed.append(str(p_resolved))
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Failed to remove model file {model_path_value}: {exc}")
|
||||||
|
|
||||||
|
return removed, errors
|
||||||
|
|
||||||
def _show_about(self):
|
def _show_about(self):
|
||||||
"""Show about dialog."""
|
"""Show about dialog."""
|
||||||
@@ -301,6 +539,11 @@ class MainWindow(QMainWindow):
|
|||||||
if hasattr(self, "training_tab"):
|
if hasattr(self, "training_tab"):
|
||||||
self.training_tab.shutdown()
|
self.training_tab.shutdown()
|
||||||
if hasattr(self, "annotation_tab"):
|
if hasattr(self, "annotation_tab"):
|
||||||
|
# Best-effort refresh so DB-backed UI state is consistent at shutdown.
|
||||||
|
try:
|
||||||
|
self.annotation_tab.refresh()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
self.annotation_tab.save_state()
|
self.annotation_tab.save_state()
|
||||||
|
|
||||||
logger.info("Application closing")
|
logger.info("Application closing")
|
||||||
|
|||||||
@@ -13,6 +13,11 @@ from PySide6.QtWidgets import (
|
|||||||
QFileDialog,
|
QFileDialog,
|
||||||
QMessageBox,
|
QMessageBox,
|
||||||
QSplitter,
|
QSplitter,
|
||||||
|
QLineEdit,
|
||||||
|
QTableWidget,
|
||||||
|
QTableWidgetItem,
|
||||||
|
QHeaderView,
|
||||||
|
QAbstractItemView,
|
||||||
)
|
)
|
||||||
from PySide6.QtCore import Qt, QSettings
|
from PySide6.QtCore import Qt, QSettings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -29,9 +34,7 @@ logger = get_logger(__name__)
|
|||||||
class AnnotationTab(QWidget):
|
class AnnotationTab(QWidget):
|
||||||
"""Annotation tab for manual image annotation."""
|
"""Annotation tab for manual image annotation."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
|
||||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
|
||||||
):
|
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
self.db_manager = db_manager
|
self.db_manager = db_manager
|
||||||
self.config_manager = config_manager
|
self.config_manager = config_manager
|
||||||
@@ -52,6 +55,32 @@ class AnnotationTab(QWidget):
|
|||||||
self.main_splitter = QSplitter(Qt.Horizontal)
|
self.main_splitter = QSplitter(Qt.Horizontal)
|
||||||
self.main_splitter.setHandleWidth(10)
|
self.main_splitter.setHandleWidth(10)
|
||||||
|
|
||||||
|
# { Left-most pane: annotated images list
|
||||||
|
annotated_group = QGroupBox("Annotated Images")
|
||||||
|
annotated_layout = QVBoxLayout()
|
||||||
|
|
||||||
|
filter_row = QHBoxLayout()
|
||||||
|
filter_row.addWidget(QLabel("Filter:"))
|
||||||
|
self.annotated_filter_edit = QLineEdit()
|
||||||
|
self.annotated_filter_edit.setPlaceholderText("Type to filter by image name…")
|
||||||
|
self.annotated_filter_edit.textChanged.connect(self._refresh_annotated_images_list)
|
||||||
|
filter_row.addWidget(self.annotated_filter_edit, 1)
|
||||||
|
annotated_layout.addLayout(filter_row)
|
||||||
|
|
||||||
|
self.annotated_images_table = QTableWidget(0, 2)
|
||||||
|
self.annotated_images_table.setHorizontalHeaderLabels(["Image", "Annotations"])
|
||||||
|
self.annotated_images_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
|
||||||
|
self.annotated_images_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeToContents)
|
||||||
|
self.annotated_images_table.setSelectionBehavior(QAbstractItemView.SelectRows)
|
||||||
|
self.annotated_images_table.setSelectionMode(QAbstractItemView.SingleSelection)
|
||||||
|
self.annotated_images_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
|
||||||
|
self.annotated_images_table.setSortingEnabled(True)
|
||||||
|
self.annotated_images_table.itemSelectionChanged.connect(self._on_annotated_image_selected)
|
||||||
|
annotated_layout.addWidget(self.annotated_images_table, 1)
|
||||||
|
|
||||||
|
annotated_group.setLayout(annotated_layout)
|
||||||
|
# }
|
||||||
|
|
||||||
# { Left splitter for image display and zoom info
|
# { Left splitter for image display and zoom info
|
||||||
self.left_splitter = QSplitter(Qt.Vertical)
|
self.left_splitter = QSplitter(Qt.Vertical)
|
||||||
self.left_splitter.setHandleWidth(10)
|
self.left_splitter.setHandleWidth(10)
|
||||||
@@ -62,6 +91,9 @@ class AnnotationTab(QWidget):
|
|||||||
|
|
||||||
# Use the AnnotationCanvasWidget
|
# Use the AnnotationCanvasWidget
|
||||||
self.annotation_canvas = AnnotationCanvasWidget()
|
self.annotation_canvas = AnnotationCanvasWidget()
|
||||||
|
# Auto-zoom so newly loaded images fill the available canvas viewport.
|
||||||
|
# (Matches the behavior used in ResultsTab.)
|
||||||
|
self.annotation_canvas.set_auto_fit_to_view(True)
|
||||||
self.annotation_canvas.zoom_changed.connect(self._on_zoom_changed)
|
self.annotation_canvas.zoom_changed.connect(self._on_zoom_changed)
|
||||||
self.annotation_canvas.annotation_drawn.connect(self._on_annotation_drawn)
|
self.annotation_canvas.annotation_drawn.connect(self._on_annotation_drawn)
|
||||||
# Selection of existing polylines (when tool is not in drawing mode)
|
# Selection of existing polylines (when tool is not in drawing mode)
|
||||||
@@ -72,9 +104,7 @@ class AnnotationTab(QWidget):
|
|||||||
self.left_splitter.addWidget(canvas_group)
|
self.left_splitter.addWidget(canvas_group)
|
||||||
|
|
||||||
# Controls info
|
# Controls info
|
||||||
controls_info = QLabel(
|
controls_info = QLabel("Zoom: Mouse wheel or +/- keys | Drawing: Enable pen and drag mouse")
|
||||||
"Zoom: Mouse wheel or +/- keys | Drawing: Enable pen and drag mouse"
|
|
||||||
)
|
|
||||||
controls_info.setStyleSheet("QLabel { color: #888; font-style: italic; }")
|
controls_info.setStyleSheet("QLabel { color: #888; font-style: italic; }")
|
||||||
self.left_splitter.addWidget(controls_info)
|
self.left_splitter.addWidget(controls_info)
|
||||||
# }
|
# }
|
||||||
@@ -85,47 +115,47 @@ class AnnotationTab(QWidget):
|
|||||||
|
|
||||||
# Annotation tools section
|
# Annotation tools section
|
||||||
self.annotation_tools = AnnotationToolsWidget(self.db_manager)
|
self.annotation_tools = AnnotationToolsWidget(self.db_manager)
|
||||||
self.annotation_tools.polyline_enabled_changed.connect(
|
self.annotation_tools.polyline_enabled_changed.connect(self.annotation_canvas.set_polyline_enabled)
|
||||||
self.annotation_canvas.set_polyline_enabled
|
self.annotation_tools.polyline_pen_color_changed.connect(self.annotation_canvas.set_polyline_pen_color)
|
||||||
)
|
self.annotation_tools.polyline_pen_width_changed.connect(self.annotation_canvas.set_polyline_pen_width)
|
||||||
self.annotation_tools.polyline_pen_color_changed.connect(
|
|
||||||
self.annotation_canvas.set_polyline_pen_color
|
|
||||||
)
|
|
||||||
self.annotation_tools.polyline_pen_width_changed.connect(
|
|
||||||
self.annotation_canvas.set_polyline_pen_width
|
|
||||||
)
|
|
||||||
# Show / hide bounding boxes
|
# Show / hide bounding boxes
|
||||||
self.annotation_tools.show_bboxes_changed.connect(
|
self.annotation_tools.show_bboxes_changed.connect(self.annotation_canvas.set_show_bboxes)
|
||||||
self.annotation_canvas.set_show_bboxes
|
|
||||||
)
|
|
||||||
# RDP simplification controls
|
# RDP simplification controls
|
||||||
self.annotation_tools.simplify_on_finish_changed.connect(
|
self.annotation_tools.simplify_on_finish_changed.connect(self._on_simplify_on_finish_changed)
|
||||||
self._on_simplify_on_finish_changed
|
self.annotation_tools.simplify_epsilon_changed.connect(self._on_simplify_epsilon_changed)
|
||||||
)
|
|
||||||
self.annotation_tools.simplify_epsilon_changed.connect(
|
|
||||||
self._on_simplify_epsilon_changed
|
|
||||||
)
|
|
||||||
# Class selection and class-color changes
|
# Class selection and class-color changes
|
||||||
self.annotation_tools.class_selected.connect(self._on_class_selected)
|
self.annotation_tools.class_selected.connect(self._on_class_selected)
|
||||||
self.annotation_tools.class_color_changed.connect(self._on_class_color_changed)
|
self.annotation_tools.class_color_changed.connect(self._on_class_color_changed)
|
||||||
self.annotation_tools.clear_annotations_requested.connect(
|
self.annotation_tools.clear_annotations_requested.connect(self._on_clear_annotations)
|
||||||
self._on_clear_annotations
|
|
||||||
)
|
|
||||||
# Delete selected annotation on canvas
|
# Delete selected annotation on canvas
|
||||||
self.annotation_tools.delete_selected_annotation_requested.connect(
|
self.annotation_tools.delete_selected_annotation_requested.connect(self._on_delete_selected_annotation)
|
||||||
self._on_delete_selected_annotation
|
|
||||||
)
|
|
||||||
self.right_splitter.addWidget(self.annotation_tools)
|
self.right_splitter.addWidget(self.annotation_tools)
|
||||||
|
|
||||||
# Image loading section
|
# Image loading section
|
||||||
load_group = QGroupBox("Image Loading")
|
load_group = QGroupBox("Image Loading")
|
||||||
load_layout = QVBoxLayout()
|
load_layout = QVBoxLayout()
|
||||||
|
|
||||||
# Load image button
|
# Buttons row
|
||||||
button_layout = QHBoxLayout()
|
button_layout = QHBoxLayout()
|
||||||
self.load_image_btn = QPushButton("Load Image")
|
self.load_image_btn = QPushButton("Load Image")
|
||||||
self.load_image_btn.clicked.connect(self._load_image)
|
self.load_image_btn.clicked.connect(self._load_image)
|
||||||
button_layout.addWidget(self.load_image_btn)
|
button_layout.addWidget(self.load_image_btn)
|
||||||
|
|
||||||
|
self.import_images_btn = QPushButton("Import Images")
|
||||||
|
self.import_images_btn.setToolTip(
|
||||||
|
"Import one or more images into the database.\n" "Images already present in the DB are skipped."
|
||||||
|
)
|
||||||
|
self.import_images_btn.clicked.connect(self._import_images)
|
||||||
|
button_layout.addWidget(self.import_images_btn)
|
||||||
|
|
||||||
|
self.import_annotations_btn = QPushButton("Import Annotations")
|
||||||
|
self.import_annotations_btn.setToolTip(
|
||||||
|
"Import YOLO .txt annotation files and register them with their corresponding images.\n"
|
||||||
|
"Existing annotations for those images will be overwritten."
|
||||||
|
)
|
||||||
|
self.import_annotations_btn.clicked.connect(self._import_annotations)
|
||||||
|
button_layout.addWidget(self.import_annotations_btn)
|
||||||
|
|
||||||
button_layout.addStretch()
|
button_layout.addStretch()
|
||||||
load_layout.addLayout(button_layout)
|
load_layout.addLayout(button_layout)
|
||||||
|
|
||||||
@@ -137,12 +167,13 @@ class AnnotationTab(QWidget):
|
|||||||
self.right_splitter.addWidget(load_group)
|
self.right_splitter.addWidget(load_group)
|
||||||
# }
|
# }
|
||||||
|
|
||||||
# Add both splitters to the main horizontal splitter
|
# Add list + both splitters to the main horizontal splitter
|
||||||
|
self.main_splitter.addWidget(annotated_group)
|
||||||
self.main_splitter.addWidget(self.left_splitter)
|
self.main_splitter.addWidget(self.left_splitter)
|
||||||
self.main_splitter.addWidget(self.right_splitter)
|
self.main_splitter.addWidget(self.right_splitter)
|
||||||
|
|
||||||
# Set initial sizes: 75% for left (image), 25% for right (controls)
|
# Set initial sizes: list (left), canvas (middle), controls (right)
|
||||||
self.main_splitter.setSizes([750, 250])
|
self.main_splitter.setSizes([320, 650, 280])
|
||||||
|
|
||||||
layout.addWidget(self.main_splitter)
|
layout.addWidget(self.main_splitter)
|
||||||
self.setLayout(layout)
|
self.setLayout(layout)
|
||||||
@@ -150,6 +181,375 @@ class AnnotationTab(QWidget):
|
|||||||
# Restore splitter positions from settings
|
# Restore splitter positions from settings
|
||||||
self._restore_state()
|
self._restore_state()
|
||||||
|
|
||||||
|
# Populate list on startup.
|
||||||
|
self._refresh_annotated_images_list()
|
||||||
|
|
||||||
|
def _import_images(self) -> None:
|
||||||
|
"""Import one or more images into the database and refresh the list."""
|
||||||
|
|
||||||
|
settings = QSettings("microscopy_app", "object_detection")
|
||||||
|
last_dir = settings.value("annotation_tab/last_image_import_directory", None)
|
||||||
|
|
||||||
|
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
|
||||||
|
if last_dir and Path(str(last_dir)).exists():
|
||||||
|
start_dir = str(last_dir)
|
||||||
|
elif repo_root and Path(repo_root).exists():
|
||||||
|
start_dir = repo_root
|
||||||
|
else:
|
||||||
|
start_dir = str(Path.home())
|
||||||
|
|
||||||
|
# Build filter string for supported extensions
|
||||||
|
patterns = " ".join(f"*{ext}" for ext in Image.SUPPORTED_EXTENSIONS)
|
||||||
|
file_paths, _ = QFileDialog.getOpenFileNames(
|
||||||
|
self,
|
||||||
|
"Select Image File(s)",
|
||||||
|
start_dir,
|
||||||
|
f"Images ({patterns})",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not file_paths:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
settings.setValue("annotation_tab/last_image_import_directory", str(Path(file_paths[0]).parent))
|
||||||
|
# Keep compatibility with the existing image resolver fallback (it checks last_directory).
|
||||||
|
settings.setValue("annotation_tab/last_directory", str(Path(file_paths[0]).parent))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
imported = 0
|
||||||
|
tagged_existing = 0
|
||||||
|
skipped = 0
|
||||||
|
errors: list[str] = []
|
||||||
|
|
||||||
|
for fp in file_paths:
|
||||||
|
try:
|
||||||
|
img_path = Path(fp)
|
||||||
|
img = Image(str(img_path))
|
||||||
|
relative_path = self._compute_relative_path_for_repo(img_path)
|
||||||
|
|
||||||
|
# Skip if already present
|
||||||
|
existing = self.db_manager.get_image_by_path(relative_path)
|
||||||
|
if existing:
|
||||||
|
# If the image already exists (e.g. created earlier by other workflows),
|
||||||
|
# tag it as being managed by the Annotation tab so it becomes visible
|
||||||
|
# in the left list.
|
||||||
|
try:
|
||||||
|
self.db_manager.set_image_source(int(existing["id"]), "annotation_tab")
|
||||||
|
tagged_existing += 1
|
||||||
|
except Exception:
|
||||||
|
# If tagging fails, fall back to treating as skipped.
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
image_id = self.db_manager.add_image(
|
||||||
|
relative_path,
|
||||||
|
img_path.name,
|
||||||
|
img.width,
|
||||||
|
img.height,
|
||||||
|
source="annotation_tab",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
# In case the DB row was created by an older schema/migration path.
|
||||||
|
self.db_manager.set_image_source(image_id, "annotation_tab")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
imported += 1
|
||||||
|
except ImageLoadError as exc:
|
||||||
|
skipped += 1
|
||||||
|
errors.append(f"Failed to load image {fp}: {exc}")
|
||||||
|
except Exception as exc:
|
||||||
|
skipped += 1
|
||||||
|
errors.append(f"Failed to import image {fp}: {exc}")
|
||||||
|
|
||||||
|
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
|
||||||
|
|
||||||
|
msg = (
|
||||||
|
f"Imported: {imported}\n"
|
||||||
|
f"Already in DB (tagged for Annotation tab): {tagged_existing}\n"
|
||||||
|
f"Skipped (errors): {skipped}"
|
||||||
|
)
|
||||||
|
if errors:
|
||||||
|
details = "\n".join(errors[:25])
|
||||||
|
if len(errors) > 25:
|
||||||
|
details += f"\n... and {len(errors) - 25} more"
|
||||||
|
msg += "\n\nDetails:\n" + details
|
||||||
|
QMessageBox.information(self, "Import Images", msg)
|
||||||
|
|
||||||
|
# ==================== Import annotations (YOLO .txt) ====================
|
||||||
|
|
||||||
|
def _import_annotations(self) -> None:
|
||||||
|
"""Import YOLO segmentation/bbox annotations from one or more .txt files."""
|
||||||
|
|
||||||
|
settings = QSettings("microscopy_app", "object_detection")
|
||||||
|
last_dir = settings.value("annotation_tab/last_annotation_directory", None)
|
||||||
|
|
||||||
|
# Default start dir: repo root if set, otherwise last used, otherwise home.
|
||||||
|
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
|
||||||
|
if last_dir and Path(str(last_dir)).exists():
|
||||||
|
start_dir = str(last_dir)
|
||||||
|
elif repo_root and Path(repo_root).exists():
|
||||||
|
start_dir = repo_root
|
||||||
|
else:
|
||||||
|
start_dir = str(Path.home())
|
||||||
|
|
||||||
|
file_paths, _ = QFileDialog.getOpenFileNames(
|
||||||
|
self,
|
||||||
|
"Select YOLO Annotation File(s)",
|
||||||
|
start_dir,
|
||||||
|
"YOLO annotations (*.txt)",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not file_paths:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Persist last annotation directory for the next import.
|
||||||
|
try:
|
||||||
|
settings.setValue("annotation_tab/last_annotation_directory", str(Path(file_paths[0]).parent))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
imported_images = 0
|
||||||
|
imported_annotations = 0
|
||||||
|
overwritten_images = 0
|
||||||
|
skipped = 0
|
||||||
|
errors: list[str] = []
|
||||||
|
|
||||||
|
for label_file in file_paths:
|
||||||
|
label_path = Path(label_file)
|
||||||
|
try:
|
||||||
|
image_path = self._infer_corresponding_image_path(label_path)
|
||||||
|
if not image_path:
|
||||||
|
skipped += 1
|
||||||
|
errors.append(f"Image not found for label file: {label_path}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Load image to obtain width/height for DB entry.
|
||||||
|
img = Image(str(image_path))
|
||||||
|
|
||||||
|
# Store in DB using a repo-relative path if possible.
|
||||||
|
relative_path = self._compute_relative_path_for_repo(image_path)
|
||||||
|
image_id = self.db_manager.get_or_create_image(relative_path, image_path.name, img.width, img.height)
|
||||||
|
try:
|
||||||
|
self.db_manager.set_image_source(image_id, "annotation_tab")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Overwrite existing annotations for this image.
|
||||||
|
try:
|
||||||
|
deleted = self.db_manager.delete_annotations_for_image(image_id)
|
||||||
|
except AttributeError:
|
||||||
|
# Safety fallback if older DBManager is used.
|
||||||
|
deleted = 0
|
||||||
|
if deleted > 0:
|
||||||
|
overwritten_images += 1
|
||||||
|
|
||||||
|
# Parse YOLO lines and insert as annotations.
|
||||||
|
parsed = self._parse_yolo_annotation_file(label_path)
|
||||||
|
if not parsed:
|
||||||
|
# Empty/invalid label file: treat as "clear" operation (already deleted above)
|
||||||
|
imported_images += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
db_classes = self.db_manager.get_object_classes() or []
|
||||||
|
classes_by_index = {idx: row for idx, row in enumerate(db_classes)}
|
||||||
|
|
||||||
|
for class_index, bbox, poly in parsed:
|
||||||
|
class_row = classes_by_index.get(int(class_index))
|
||||||
|
if not class_row:
|
||||||
|
skipped += 1
|
||||||
|
errors.append(
|
||||||
|
f"Unknown class index {class_index} in {label_path.name}. "
|
||||||
|
"Create object classes in the UI first (class index is based on DB ordering)."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
ann_id = self.db_manager.add_annotation(
|
||||||
|
image_id=image_id,
|
||||||
|
class_id=int(class_row["id"]),
|
||||||
|
bbox=bbox,
|
||||||
|
annotator="import",
|
||||||
|
segmentation_mask=poly,
|
||||||
|
verified=False,
|
||||||
|
)
|
||||||
|
if ann_id:
|
||||||
|
imported_annotations += 1
|
||||||
|
|
||||||
|
imported_images += 1
|
||||||
|
|
||||||
|
# If we imported for the currently open image, reload.
|
||||||
|
if self.current_image_id and int(self.current_image_id) == int(image_id):
|
||||||
|
self._load_annotations_for_current_image()
|
||||||
|
|
||||||
|
except ImageLoadError as exc:
|
||||||
|
skipped += 1
|
||||||
|
errors.append(f"Failed to load image for {label_path.name}: {exc}")
|
||||||
|
except Exception as exc:
|
||||||
|
skipped += 1
|
||||||
|
errors.append(f"Import failed for {label_path.name}: {exc}")
|
||||||
|
|
||||||
|
# Refresh annotated images list.
|
||||||
|
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
|
||||||
|
|
||||||
|
summary = (
|
||||||
|
f"Imported files: {len(file_paths)}\n"
|
||||||
|
f"Images processed: {imported_images}\n"
|
||||||
|
f"Annotations inserted: {imported_annotations}\n"
|
||||||
|
f"Images overwritten (had existing annotations): {overwritten_images}\n"
|
||||||
|
f"Skipped: {skipped}"
|
||||||
|
)
|
||||||
|
if errors:
|
||||||
|
# Cap error details to avoid huge dialogs.
|
||||||
|
details = "\n".join(errors[:25])
|
||||||
|
if len(errors) > 25:
|
||||||
|
details += f"\n... and {len(errors) - 25} more"
|
||||||
|
summary += "\n\nDetails:\n" + details
|
||||||
|
|
||||||
|
QMessageBox.information(self, "Import Annotations", summary)
|
||||||
|
|
||||||
|
def _infer_corresponding_image_path(self, label_path: Path) -> Path | None:
|
||||||
|
"""Infer image path from YOLO label file path.
|
||||||
|
|
||||||
|
Requirement: image(s) live in an `images/` folder located in the label file's parent directory.
|
||||||
|
Example:
|
||||||
|
/dataset/train/labels/img123.txt -> /dataset/train/images/img123.(any supported ext)
|
||||||
|
"""
|
||||||
|
|
||||||
|
parent = label_path.parent
|
||||||
|
images_dir = parent.parent / "images"
|
||||||
|
stem = label_path.stem
|
||||||
|
|
||||||
|
# 1) Direct stem match in images dir (any supported extension)
|
||||||
|
for ext in Image.SUPPORTED_EXTENSIONS:
|
||||||
|
candidate = images_dir / f"{stem}{ext}"
|
||||||
|
if candidate.exists() and candidate.is_file():
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
# 2) Fallback: repository-root search by filename
|
||||||
|
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
|
||||||
|
if repo_root:
|
||||||
|
root = Path(repo_root).expanduser()
|
||||||
|
try:
|
||||||
|
if root.exists():
|
||||||
|
for ext in Image.SUPPORTED_EXTENSIONS:
|
||||||
|
filename = f"{stem}{ext}"
|
||||||
|
for match in root.rglob(filename):
|
||||||
|
if match.is_file():
|
||||||
|
return match.resolve()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _compute_relative_path_for_repo(self, image_path: Path) -> str:
|
||||||
|
"""Compute a stable `relative_path` suitable for DB storage.
|
||||||
|
|
||||||
|
Policy:
|
||||||
|
- If an image repository root is configured and the image is under it, store a repo-relative path.
|
||||||
|
- Otherwise, store an absolute resolved path so the image can be reopened later.
|
||||||
|
"""
|
||||||
|
|
||||||
|
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
|
||||||
|
try:
|
||||||
|
if repo_root:
|
||||||
|
repo_root_path = Path(repo_root).expanduser().resolve()
|
||||||
|
img_resolved = image_path.expanduser().resolve()
|
||||||
|
if img_resolved.is_relative_to(repo_root_path):
|
||||||
|
return img_resolved.relative_to(repo_root_path).as_posix()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
return str(image_path.expanduser().resolve())
|
||||||
|
except Exception:
|
||||||
|
return str(image_path)
|
||||||
|
|
||||||
|
def _parse_yolo_annotation_file(
|
||||||
|
self, label_path: Path
|
||||||
|
) -> list[tuple[int, tuple[float, float, float, float], list[list[float]] | None]]:
|
||||||
|
"""Parse a YOLO .txt label file.
|
||||||
|
|
||||||
|
Supports:
|
||||||
|
- YOLO segmentation polygon format: "class x1 y1 x2 y2 ..." (normalized)
|
||||||
|
- YOLO bbox format: "class x_center y_center width height" (normalized)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (class_index, bbox_xyxy_norm, segmentation_mask_db)
|
||||||
|
Where segmentation_mask_db is [[y_norm, x_norm], ...] or None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
out: list[tuple[int, tuple[float, float, float, float], list[list[float]] | None]] = []
|
||||||
|
try:
|
||||||
|
raw = label_path.read_text(encoding="utf-8").splitlines()
|
||||||
|
except OSError as exc:
|
||||||
|
logger.error(f"Failed to read label file {label_path}: {exc}")
|
||||||
|
return out
|
||||||
|
|
||||||
|
for line in raw:
|
||||||
|
stripped = line.strip()
|
||||||
|
if not stripped:
|
||||||
|
continue
|
||||||
|
parts = stripped.split()
|
||||||
|
if len(parts) < 5:
|
||||||
|
# not enough for bbox
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
class_idx = int(float(parts[0]))
|
||||||
|
coords = [float(x) for x in parts[1:]]
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Segmentation polygon format (>= 6 values)
|
||||||
|
if len(coords) >= 6:
|
||||||
|
# bbox is not explicitly present in this format in our importer; compute from polygon.
|
||||||
|
xs = coords[0::2]
|
||||||
|
ys = coords[1::2]
|
||||||
|
if not xs or not ys:
|
||||||
|
continue
|
||||||
|
x_min, x_max = min(xs), max(xs)
|
||||||
|
y_min, y_max = min(ys), max(ys)
|
||||||
|
bbox = (
|
||||||
|
self._clamp01(x_min),
|
||||||
|
self._clamp01(y_min),
|
||||||
|
self._clamp01(x_max),
|
||||||
|
self._clamp01(y_max),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to DB polyline convention: [[y_norm, x_norm], ...]
|
||||||
|
poly: list[list[float]] = []
|
||||||
|
for x, y in zip(xs, ys):
|
||||||
|
poly.append([self._clamp01(float(y)), self._clamp01(float(x))])
|
||||||
|
# Ensure closure for consistency (optional)
|
||||||
|
if poly and poly[0] != poly[-1]:
|
||||||
|
poly.append(list(poly[0]))
|
||||||
|
out.append((class_idx, bbox, poly))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# bbox format: xc yc w h
|
||||||
|
if len(coords) >= 4:
|
||||||
|
xc, yc, w, h = coords[:4]
|
||||||
|
x_min = xc - w / 2.0
|
||||||
|
y_min = yc - h / 2.0
|
||||||
|
x_max = xc + w / 2.0
|
||||||
|
y_max = yc + h / 2.0
|
||||||
|
bbox = (
|
||||||
|
self._clamp01(float(x_min)),
|
||||||
|
self._clamp01(float(y_min)),
|
||||||
|
self._clamp01(float(x_max)),
|
||||||
|
self._clamp01(float(y_max)),
|
||||||
|
)
|
||||||
|
out.append((class_idx, bbox, None))
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _clamp01(value: float) -> float:
|
||||||
|
if value < 0.0:
|
||||||
|
return 0.0
|
||||||
|
if value > 1.0:
|
||||||
|
return 1.0
|
||||||
|
return float(value)
|
||||||
|
|
||||||
def _load_image(self):
|
def _load_image(self):
|
||||||
"""Load and display an image file."""
|
"""Load and display an image file."""
|
||||||
# Get last opened directory from QSettings
|
# Get last opened directory from QSettings
|
||||||
@@ -180,18 +580,35 @@ class AnnotationTab(QWidget):
|
|||||||
self.current_image_path = file_path
|
self.current_image_path = file_path
|
||||||
|
|
||||||
# Store the directory for next time
|
# Store the directory for next time
|
||||||
settings.setValue(
|
settings.setValue("annotation_tab/last_directory", str(Path(file_path).parent))
|
||||||
"annotation_tab/last_directory", str(Path(file_path).parent)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get or create image in database
|
# Get or create image in database
|
||||||
relative_path = str(Path(file_path).name) # Simplified for now
|
repo_root = self.config_manager.get_image_repository_path()
|
||||||
|
relative_path: str
|
||||||
|
try:
|
||||||
|
if repo_root:
|
||||||
|
repo_root_path = Path(repo_root).expanduser().resolve()
|
||||||
|
file_resolved = Path(file_path).expanduser().resolve()
|
||||||
|
if file_resolved.is_relative_to(repo_root_path):
|
||||||
|
relative_path = file_resolved.relative_to(repo_root_path).as_posix()
|
||||||
|
else:
|
||||||
|
# Fallback: store filename only to avoid leaking absolute paths.
|
||||||
|
relative_path = file_resolved.name
|
||||||
|
else:
|
||||||
|
relative_path = str(Path(file_path).name)
|
||||||
|
except Exception:
|
||||||
|
relative_path = str(Path(file_path).name)
|
||||||
self.current_image_id = self.db_manager.get_or_create_image(
|
self.current_image_id = self.db_manager.get_or_create_image(
|
||||||
relative_path,
|
relative_path,
|
||||||
Path(file_path).name,
|
Path(file_path).name,
|
||||||
self.current_image.width,
|
self.current_image.width,
|
||||||
self.current_image.height,
|
self.current_image.height,
|
||||||
)
|
)
|
||||||
|
# Mark as managed by Annotation tab so it appears in the left list.
|
||||||
|
try:
|
||||||
|
self.db_manager.set_image_source(int(self.current_image_id), "annotation_tab")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Display image using the AnnotationCanvasWidget
|
# Display image using the AnnotationCanvasWidget
|
||||||
self.annotation_canvas.load_image(self.current_image)
|
self.annotation_canvas.load_image(self.current_image)
|
||||||
@@ -199,6 +616,9 @@ class AnnotationTab(QWidget):
|
|||||||
# Load and display any existing annotations for this image
|
# Load and display any existing annotations for this image
|
||||||
self._load_annotations_for_current_image()
|
self._load_annotations_for_current_image()
|
||||||
|
|
||||||
|
# Update annotated images list (newly annotated image added/selected).
|
||||||
|
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
|
||||||
|
|
||||||
# Update info label
|
# Update info label
|
||||||
self._update_image_info()
|
self._update_image_info()
|
||||||
|
|
||||||
@@ -206,9 +626,7 @@ class AnnotationTab(QWidget):
|
|||||||
|
|
||||||
except ImageLoadError as e:
|
except ImageLoadError as e:
|
||||||
logger.error(f"Failed to load image: {e}")
|
logger.error(f"Failed to load image: {e}")
|
||||||
QMessageBox.critical(
|
QMessageBox.critical(self, "Error Loading Image", f"Failed to load image:\n{str(e)}")
|
||||||
self, "Error Loading Image", f"Failed to load image:\n{str(e)}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error loading image: {e}")
|
logger.error(f"Unexpected error loading image: {e}")
|
||||||
QMessageBox.critical(self, "Error", f"Unexpected error:\n{str(e)}")
|
QMessageBox.critical(self, "Error", f"Unexpected error:\n{str(e)}")
|
||||||
@@ -296,6 +714,9 @@ class AnnotationTab(QWidget):
|
|||||||
# Reload annotations from DB and redraw (respecting current class filter)
|
# Reload annotations from DB and redraw (respecting current class filter)
|
||||||
self._load_annotations_for_current_image()
|
self._load_annotations_for_current_image()
|
||||||
|
|
||||||
|
# Update list counts.
|
||||||
|
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to save annotation: {e}")
|
logger.error(f"Failed to save annotation: {e}")
|
||||||
QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}")
|
QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}")
|
||||||
@@ -340,9 +761,7 @@ class AnnotationTab(QWidget):
|
|||||||
if not self.current_image_id:
|
if not self.current_image_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(f"Class color changed; reloading annotations for image ID {self.current_image_id}")
|
||||||
f"Class color changed; reloading annotations for image ID {self.current_image_id}"
|
|
||||||
)
|
|
||||||
self._load_annotations_for_current_image()
|
self._load_annotations_for_current_image()
|
||||||
|
|
||||||
def _on_class_selected(self, class_data):
|
def _on_class_selected(self, class_data):
|
||||||
@@ -355,9 +774,7 @@ class AnnotationTab(QWidget):
|
|||||||
if class_data:
|
if class_data:
|
||||||
logger.debug(f"Object class selected: {class_data['class_name']}")
|
logger.debug(f"Object class selected: {class_data['class_name']}")
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug('No class selected ("-- Select Class --"), showing all annotations')
|
||||||
'No class selected ("-- Select Class --"), showing all annotations'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Changing the class filter invalidates any previous selection
|
# Changing the class filter invalidates any previous selection
|
||||||
self.selected_annotation_ids = []
|
self.selected_annotation_ids = []
|
||||||
@@ -390,9 +807,7 @@ class AnnotationTab(QWidget):
|
|||||||
question = "Are you sure you want to delete the selected annotation?"
|
question = "Are you sure you want to delete the selected annotation?"
|
||||||
title = "Delete Annotation"
|
title = "Delete Annotation"
|
||||||
else:
|
else:
|
||||||
question = (
|
question = f"Are you sure you want to delete the {count} selected annotations?"
|
||||||
f"Are you sure you want to delete the {count} selected annotations?"
|
|
||||||
)
|
|
||||||
title = "Delete Annotations"
|
title = "Delete Annotations"
|
||||||
|
|
||||||
reply = QMessageBox.question(
|
reply = QMessageBox.question(
|
||||||
@@ -420,13 +835,11 @@ class AnnotationTab(QWidget):
|
|||||||
QMessageBox.warning(
|
QMessageBox.warning(
|
||||||
self,
|
self,
|
||||||
"Partial Failure",
|
"Partial Failure",
|
||||||
"Some annotations could not be deleted:\n"
|
"Some annotations could not be deleted:\n" + ", ".join(str(a) for a in failed_ids),
|
||||||
+ ", ".join(str(a) for a in failed_ids),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Deleted {count} annotation(s): "
|
f"Deleted {count} annotation(s): " + ", ".join(str(a) for a in self.selected_annotation_ids)
|
||||||
+ ", ".join(str(a) for a in self.selected_annotation_ids)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Clear selection and reload annotations for the current image from DB
|
# Clear selection and reload annotations for the current image from DB
|
||||||
@@ -434,6 +847,9 @@ class AnnotationTab(QWidget):
|
|||||||
self.annotation_tools.set_has_selected_annotation(False)
|
self.annotation_tools.set_has_selected_annotation(False)
|
||||||
self._load_annotations_for_current_image()
|
self._load_annotations_for_current_image()
|
||||||
|
|
||||||
|
# Update list counts.
|
||||||
|
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to delete annotations: {e}")
|
logger.error(f"Failed to delete annotations: {e}")
|
||||||
QMessageBox.critical(
|
QMessageBox.critical(
|
||||||
@@ -456,17 +872,13 @@ class AnnotationTab(QWidget):
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.current_annotations = self.db_manager.get_annotations_for_image(
|
self.current_annotations = self.db_manager.get_annotations_for_image(self.current_image_id)
|
||||||
self.current_image_id
|
|
||||||
)
|
|
||||||
# New annotations loaded; reset any selection
|
# New annotations loaded; reset any selection
|
||||||
self.selected_annotation_ids = []
|
self.selected_annotation_ids = []
|
||||||
self.annotation_tools.set_has_selected_annotation(False)
|
self.annotation_tools.set_has_selected_annotation(False)
|
||||||
self._redraw_annotations_for_current_filter()
|
self._redraw_annotations_for_current_filter()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"Failed to load annotations for image {self.current_image_id}: {e}")
|
||||||
f"Failed to load annotations for image {self.current_image_id}: {e}"
|
|
||||||
)
|
|
||||||
QMessageBox.critical(
|
QMessageBox.critical(
|
||||||
self,
|
self,
|
||||||
"Error",
|
"Error",
|
||||||
@@ -490,10 +902,7 @@ class AnnotationTab(QWidget):
|
|||||||
drawn_count = 0
|
drawn_count = 0
|
||||||
for ann in self.current_annotations:
|
for ann in self.current_annotations:
|
||||||
# Filter by class if one is selected
|
# Filter by class if one is selected
|
||||||
if (
|
if selected_class_id is not None and ann.get("class_id") != selected_class_id:
|
||||||
selected_class_id is not None
|
|
||||||
and ann.get("class_id") != selected_class_id
|
|
||||||
):
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if ann.get("segmentation_mask"):
|
if ann.get("segmentation_mask"):
|
||||||
@@ -545,22 +954,185 @@ class AnnotationTab(QWidget):
|
|||||||
settings = QSettings("microscopy_app", "object_detection")
|
settings = QSettings("microscopy_app", "object_detection")
|
||||||
|
|
||||||
# Save main splitter state
|
# Save main splitter state
|
||||||
settings.setValue(
|
settings.setValue("annotation_tab/main_splitter_state", self.main_splitter.saveState())
|
||||||
"annotation_tab/main_splitter_state", self.main_splitter.saveState()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save left splitter state
|
# Save left splitter state
|
||||||
settings.setValue(
|
settings.setValue("annotation_tab/left_splitter_state", self.left_splitter.saveState())
|
||||||
"annotation_tab/left_splitter_state", self.left_splitter.saveState()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save right splitter state
|
# Save right splitter state
|
||||||
settings.setValue(
|
settings.setValue("annotation_tab/right_splitter_state", self.right_splitter.saveState())
|
||||||
"annotation_tab/right_splitter_state", self.right_splitter.saveState()
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("Saved annotation tab splitter states")
|
logger.debug("Saved annotation tab splitter states")
|
||||||
|
|
||||||
def refresh(self):
|
def refresh(self):
|
||||||
"""Refresh the tab."""
|
"""Refresh the tab."""
|
||||||
pass
|
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
|
||||||
|
|
||||||
|
# ==================== Annotated images list ====================
|
||||||
|
|
||||||
|
def _refresh_annotated_images_list(self, select_image_id: int | None = None) -> None:
|
||||||
|
"""Reload annotated-images list from the database."""
|
||||||
|
if not hasattr(self, "annotated_images_table"):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Preserve selection if possible
|
||||||
|
desired_id = select_image_id if select_image_id is not None else self.current_image_id
|
||||||
|
|
||||||
|
name_filter = ""
|
||||||
|
if hasattr(self, "annotated_filter_edit"):
|
||||||
|
name_filter = self.annotated_filter_edit.text().strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
rows = self.db_manager.get_images_summary(name_filter=name_filter, source_filter="annotation_tab")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Failed to load images summary: {exc}")
|
||||||
|
rows = []
|
||||||
|
|
||||||
|
sorting_enabled = self.annotated_images_table.isSortingEnabled()
|
||||||
|
self.annotated_images_table.setSortingEnabled(False)
|
||||||
|
self.annotated_images_table.blockSignals(True)
|
||||||
|
try:
|
||||||
|
self.annotated_images_table.setRowCount(len(rows))
|
||||||
|
for r, entry in enumerate(rows):
|
||||||
|
image_name = str(entry.get("filename") or "")
|
||||||
|
count = int(entry.get("annotation_count") or 0)
|
||||||
|
rel_path = str(entry.get("relative_path") or "")
|
||||||
|
|
||||||
|
name_item = QTableWidgetItem(image_name)
|
||||||
|
# Tooltip shows full path of the image (best-effort: repository_root + relative_path)
|
||||||
|
full_path = rel_path
|
||||||
|
repo_root = self.config_manager.get_image_repository_path()
|
||||||
|
if repo_root and rel_path and not Path(rel_path).is_absolute():
|
||||||
|
try:
|
||||||
|
full_path = str((Path(repo_root) / rel_path).resolve())
|
||||||
|
except Exception:
|
||||||
|
full_path = str(Path(repo_root) / rel_path)
|
||||||
|
name_item.setToolTip(full_path)
|
||||||
|
name_item.setData(Qt.UserRole, int(entry.get("id")))
|
||||||
|
name_item.setData(Qt.UserRole + 1, rel_path)
|
||||||
|
|
||||||
|
count_item = QTableWidgetItem()
|
||||||
|
# Use EditRole to ensure numeric sorting.
|
||||||
|
count_item.setData(Qt.EditRole, count)
|
||||||
|
count_item.setData(Qt.UserRole, int(entry.get("id")))
|
||||||
|
count_item.setData(Qt.UserRole + 1, rel_path)
|
||||||
|
|
||||||
|
self.annotated_images_table.setItem(r, 0, name_item)
|
||||||
|
self.annotated_images_table.setItem(r, 1, count_item)
|
||||||
|
|
||||||
|
# Re-select desired row
|
||||||
|
if desired_id is not None:
|
||||||
|
for r in range(self.annotated_images_table.rowCount()):
|
||||||
|
item = self.annotated_images_table.item(r, 0)
|
||||||
|
if item and item.data(Qt.UserRole) == desired_id:
|
||||||
|
self.annotated_images_table.selectRow(r)
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
self.annotated_images_table.blockSignals(False)
|
||||||
|
self.annotated_images_table.setSortingEnabled(sorting_enabled)
|
||||||
|
|
||||||
|
def _on_annotated_image_selected(self) -> None:
|
||||||
|
"""When user clicks an item in the list, load that image in the annotation canvas."""
|
||||||
|
selected = self.annotated_images_table.selectedItems()
|
||||||
|
if not selected:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Row selection -> take the first column item
|
||||||
|
row = self.annotated_images_table.currentRow()
|
||||||
|
item = self.annotated_images_table.item(row, 0)
|
||||||
|
if not item:
|
||||||
|
return
|
||||||
|
|
||||||
|
image_id = item.data(Qt.UserRole)
|
||||||
|
rel_path = item.data(Qt.UserRole + 1) or ""
|
||||||
|
if not image_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
image_path = self._resolve_image_path_for_relative_path(rel_path)
|
||||||
|
if not image_path:
|
||||||
|
QMessageBox.warning(
|
||||||
|
self,
|
||||||
|
"Image Not Found",
|
||||||
|
"Unable to locate image on disk for:\n"
|
||||||
|
f"{rel_path}\n\n"
|
||||||
|
"Tip: set Settings → Image repository path to the folder containing your images.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.current_image = Image(image_path)
|
||||||
|
self.current_image_path = image_path
|
||||||
|
self.current_image_id = int(image_id)
|
||||||
|
self.annotation_canvas.load_image(self.current_image)
|
||||||
|
self._load_annotations_for_current_image()
|
||||||
|
self._update_image_info()
|
||||||
|
except ImageLoadError as exc:
|
||||||
|
logger.error(f"Failed to load image '{image_path}': {exc}")
|
||||||
|
QMessageBox.critical(self, "Error Loading Image", f"Failed to load image:\n{exc}")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Unexpected error loading image '{image_path}': {exc}")
|
||||||
|
QMessageBox.critical(self, "Error", f"Unexpected error:\n{exc}")
|
||||||
|
|
||||||
|
def _resolve_image_path_for_relative_path(self, relative_path: str) -> str | None:
|
||||||
|
"""Best-effort conversion from a DB relative_path to an on-disk file path."""
|
||||||
|
|
||||||
|
rel = (relative_path or "").strip()
|
||||||
|
if not rel:
|
||||||
|
return None
|
||||||
|
|
||||||
|
candidates: list[Path] = []
|
||||||
|
|
||||||
|
# 1) Repository root + relative
|
||||||
|
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
|
||||||
|
if repo_root:
|
||||||
|
candidates.append(Path(repo_root) / rel)
|
||||||
|
|
||||||
|
# 2) If the DB path is absolute, try it directly.
|
||||||
|
candidates.append(Path(rel))
|
||||||
|
|
||||||
|
# 3) Try the directory of the currently loaded image (helps when DB stores only filenames)
|
||||||
|
if self.current_image_path:
|
||||||
|
try:
|
||||||
|
candidates.append(Path(self.current_image_path).expanduser().resolve().parent / Path(rel).name)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 4) Try the last directory used by the annotation file picker
|
||||||
|
try:
|
||||||
|
settings = QSettings("microscopy_app", "object_detection")
|
||||||
|
last_dir = settings.value("annotation_tab/last_directory", None)
|
||||||
|
if last_dir:
|
||||||
|
candidates.append(Path(str(last_dir)) / Path(rel).name)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 4b) Try the last directory used by the image import picker
|
||||||
|
try:
|
||||||
|
settings = QSettings("microscopy_app", "object_detection")
|
||||||
|
last_import_dir = settings.value("annotation_tab/last_image_import_directory", None)
|
||||||
|
if last_import_dir:
|
||||||
|
candidates.append(Path(str(last_import_dir)) / Path(rel).name)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
for p in candidates:
|
||||||
|
try:
|
||||||
|
expanded = p.expanduser()
|
||||||
|
if expanded.exists() and expanded.is_file():
|
||||||
|
return str(expanded.resolve())
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 5) Fallback: search by filename within repository root.
|
||||||
|
filename = Path(rel).name
|
||||||
|
if repo_root and filename:
|
||||||
|
root = Path(repo_root).expanduser()
|
||||||
|
try:
|
||||||
|
if root.exists():
|
||||||
|
for match in root.rglob(filename):
|
||||||
|
if match.is_file():
|
||||||
|
return str(match.resolve())
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug(f"Search for {filename} under {root} failed: {exc}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ Results tab for browsing stored detections and visualizing overlays.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from PySide6.QtWidgets import (
|
from PySide6.QtWidgets import (
|
||||||
QWidget,
|
QWidget,
|
||||||
@@ -40,6 +40,12 @@ class ResultsTab(QWidget):
|
|||||||
self.db_manager = db_manager
|
self.db_manager = db_manager
|
||||||
self.config_manager = config_manager
|
self.config_manager = config_manager
|
||||||
|
|
||||||
|
# Pagination
|
||||||
|
self.page_size = 200
|
||||||
|
self.current_page = 0 # 0-based
|
||||||
|
self.total_runs = 0
|
||||||
|
self.total_pages = 0
|
||||||
|
|
||||||
self.detection_summary: List[Dict] = []
|
self.detection_summary: List[Dict] = []
|
||||||
self.current_selection: Optional[Dict] = None
|
self.current_selection: Optional[Dict] = None
|
||||||
self.current_image: Optional[Image] = None
|
self.current_image: Optional[Image] = None
|
||||||
@@ -65,6 +71,36 @@ class ResultsTab(QWidget):
|
|||||||
self.refresh_btn = QPushButton("Refresh")
|
self.refresh_btn = QPushButton("Refresh")
|
||||||
self.refresh_btn.clicked.connect(self.refresh)
|
self.refresh_btn.clicked.connect(self.refresh)
|
||||||
controls_layout.addWidget(self.refresh_btn)
|
controls_layout.addWidget(self.refresh_btn)
|
||||||
|
|
||||||
|
self.prev_page_btn = QPushButton("◀ Prev")
|
||||||
|
self.prev_page_btn.setToolTip("Previous page")
|
||||||
|
self.prev_page_btn.clicked.connect(self._prev_page)
|
||||||
|
controls_layout.addWidget(self.prev_page_btn)
|
||||||
|
|
||||||
|
self.next_page_btn = QPushButton("Next ▶")
|
||||||
|
self.next_page_btn.setToolTip("Next page")
|
||||||
|
self.next_page_btn.clicked.connect(self._next_page)
|
||||||
|
controls_layout.addWidget(self.next_page_btn)
|
||||||
|
|
||||||
|
self.page_label = QLabel("Page 1/1")
|
||||||
|
self.page_label.setMinimumWidth(100)
|
||||||
|
controls_layout.addWidget(self.page_label)
|
||||||
|
|
||||||
|
self.delete_all_btn = QPushButton("Delete All Detections")
|
||||||
|
self.delete_all_btn.setToolTip(
|
||||||
|
"Permanently delete ALL detections from the database.\n" "This cannot be undone."
|
||||||
|
)
|
||||||
|
self.delete_all_btn.clicked.connect(self._delete_all_detections)
|
||||||
|
controls_layout.addWidget(self.delete_all_btn)
|
||||||
|
|
||||||
|
self.export_labels_btn = QPushButton("Export Labels")
|
||||||
|
self.export_labels_btn.setToolTip(
|
||||||
|
"Export YOLO .txt labels for the selected image/model run.\n"
|
||||||
|
"Output path is inferred from the image path (images/ -> labels/)."
|
||||||
|
)
|
||||||
|
self.export_labels_btn.clicked.connect(self._export_labels_for_current_selection)
|
||||||
|
controls_layout.addWidget(self.export_labels_btn)
|
||||||
|
|
||||||
controls_layout.addStretch()
|
controls_layout.addStretch()
|
||||||
left_layout.addLayout(controls_layout)
|
left_layout.addLayout(controls_layout)
|
||||||
|
|
||||||
@@ -130,8 +166,45 @@ class ResultsTab(QWidget):
|
|||||||
layout.addWidget(splitter)
|
layout.addWidget(splitter)
|
||||||
self.setLayout(layout)
|
self.setLayout(layout)
|
||||||
|
|
||||||
|
def _delete_all_detections(self):
|
||||||
|
"""Delete all detections from the database after user confirmation."""
|
||||||
|
confirm = QMessageBox.warning(
|
||||||
|
self,
|
||||||
|
"Delete All Detections",
|
||||||
|
"This will permanently delete ALL detections from the database.\n\n"
|
||||||
|
"This action cannot be undone.\n\n"
|
||||||
|
"Do you want to continue?",
|
||||||
|
QMessageBox.Yes | QMessageBox.No,
|
||||||
|
QMessageBox.No,
|
||||||
|
)
|
||||||
|
|
||||||
|
if confirm != QMessageBox.Yes:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
deleted = self.db_manager.delete_all_detections()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Failed to delete all detections: {exc}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Error",
|
||||||
|
f"Failed to delete detections:\n{exc}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
QMessageBox.information(
|
||||||
|
self,
|
||||||
|
"Delete All Detections",
|
||||||
|
f"Deleted {deleted} detection(s) from the database.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reset UI state.
|
||||||
|
self.refresh()
|
||||||
|
|
||||||
def refresh(self):
|
def refresh(self):
|
||||||
"""Refresh the detection list and preview."""
|
"""Refresh the detection list and preview."""
|
||||||
|
# Reset to first page on refresh.
|
||||||
|
self.current_page = 0
|
||||||
self._load_detection_summary()
|
self._load_detection_summary()
|
||||||
self._populate_results_table()
|
self._populate_results_table()
|
||||||
self.current_selection = None
|
self.current_selection = None
|
||||||
@@ -139,58 +212,138 @@ class ResultsTab(QWidget):
|
|||||||
self.current_detections = []
|
self.current_detections = []
|
||||||
self.preview_canvas.clear()
|
self.preview_canvas.clear()
|
||||||
self.summary_label.setText("Select a detection result to preview.")
|
self.summary_label.setText("Select a detection result to preview.")
|
||||||
|
if hasattr(self, "export_labels_btn"):
|
||||||
|
self.export_labels_btn.setEnabled(False)
|
||||||
|
|
||||||
|
self._update_pagination_controls()
|
||||||
|
|
||||||
|
def _prev_page(self):
|
||||||
|
"""Go to previous results page."""
|
||||||
|
if self.current_page <= 0:
|
||||||
|
return
|
||||||
|
self.current_page -= 1
|
||||||
|
self._load_detection_summary()
|
||||||
|
self._populate_results_table()
|
||||||
|
self._update_pagination_controls()
|
||||||
|
|
||||||
|
def _next_page(self):
|
||||||
|
"""Go to next results page."""
|
||||||
|
if self.total_pages and self.current_page >= (self.total_pages - 1):
|
||||||
|
return
|
||||||
|
self.current_page += 1
|
||||||
|
self._load_detection_summary()
|
||||||
|
self._populate_results_table()
|
||||||
|
self._update_pagination_controls()
|
||||||
|
|
||||||
|
def _update_pagination_controls(self):
|
||||||
|
"""Update pagination label/button enabled state."""
|
||||||
|
|
||||||
|
# Default state (safe)
|
||||||
|
total_pages = max(int(self.total_pages or 0), 1)
|
||||||
|
current_page = min(max(int(self.current_page or 0), 0), total_pages - 1)
|
||||||
|
self.current_page = current_page
|
||||||
|
|
||||||
|
self.page_label.setText(f"Page {current_page + 1}/{total_pages}")
|
||||||
|
self.prev_page_btn.setEnabled(current_page > 0)
|
||||||
|
self.next_page_btn.setEnabled(current_page < (total_pages - 1))
|
||||||
|
|
||||||
def _load_detection_summary(self):
|
def _load_detection_summary(self):
|
||||||
"""Load latest detection summaries grouped by image + model."""
|
"""Load latest detection summaries grouped by image + model."""
|
||||||
try:
|
try:
|
||||||
detections = self.db_manager.get_detections(limit=500)
|
# Prefer run summaries (supports zero-detection runs). Fall back to legacy
|
||||||
summary_map: Dict[tuple, Dict] = {}
|
# detection aggregation if the DB/table isn't available.
|
||||||
|
try:
|
||||||
|
self.total_runs = int(self.db_manager.get_detection_run_total())
|
||||||
|
except Exception:
|
||||||
|
self.total_runs = 0
|
||||||
|
|
||||||
for det in detections:
|
self.total_pages = (self.total_runs + self.page_size - 1) // self.page_size if self.total_runs > 0 else 1
|
||||||
key = (det["image_id"], det["model_id"])
|
|
||||||
metadata = det.get("metadata") or {}
|
offset = int(self.current_page) * int(self.page_size)
|
||||||
entry = summary_map.setdefault(
|
runs = self.db_manager.get_detection_run_summaries(
|
||||||
key,
|
limit=int(self.page_size),
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
summary: List[Dict] = []
|
||||||
|
for run in runs:
|
||||||
|
meta = run.get("metadata") or {}
|
||||||
|
classes_raw = run.get("classes")
|
||||||
|
classes = set()
|
||||||
|
if isinstance(classes_raw, str) and classes_raw.strip():
|
||||||
|
classes = {c.strip() for c in classes_raw.split(",") if c.strip()}
|
||||||
|
|
||||||
|
summary.append(
|
||||||
{
|
{
|
||||||
"image_id": det["image_id"],
|
"image_id": run.get("image_id"),
|
||||||
"model_id": det["model_id"],
|
"model_id": run.get("model_id"),
|
||||||
"image_path": det.get("image_path"),
|
"image_path": run.get("image_path"),
|
||||||
"image_filename": det.get("image_filename") or det.get("image_path"),
|
"image_filename": run.get("image_filename") or run.get("image_path"),
|
||||||
"model_name": det.get("model_name", ""),
|
"model_name": run.get("model_name", ""),
|
||||||
"model_version": det.get("model_version", ""),
|
"model_version": run.get("model_version", ""),
|
||||||
"last_detected": det.get("detected_at"),
|
"last_detected": run.get("detected_at"),
|
||||||
"count": 0,
|
"count": int(run.get("count") or 0),
|
||||||
"classes": set(),
|
"classes": classes,
|
||||||
"source_path": metadata.get("source_path"),
|
"source_path": meta.get("source_path"),
|
||||||
"repository_root": metadata.get("repository_root"),
|
"repository_root": meta.get("repository_root"),
|
||||||
},
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
entry["count"] += 1
|
self.detection_summary = summary
|
||||||
if det.get("detected_at") and (
|
|
||||||
not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
|
|
||||||
):
|
|
||||||
entry["last_detected"] = det.get("detected_at")
|
|
||||||
if det.get("class_name"):
|
|
||||||
entry["classes"].add(det["class_name"])
|
|
||||||
if metadata.get("source_path") and not entry.get("source_path"):
|
|
||||||
entry["source_path"] = metadata.get("source_path")
|
|
||||||
if metadata.get("repository_root") and not entry.get("repository_root"):
|
|
||||||
entry["repository_root"] = metadata.get("repository_root")
|
|
||||||
|
|
||||||
self.detection_summary = sorted(
|
|
||||||
summary_map.values(),
|
|
||||||
key=lambda x: str(x.get("last_detected") or ""),
|
|
||||||
reverse=True,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to load detection summary: {e}")
|
logger.error(f"Failed to load detection run summaries, falling back: {e}")
|
||||||
QMessageBox.critical(
|
# Disable pagination if we can't page via detection_runs.
|
||||||
self,
|
self.total_runs = 0
|
||||||
"Error",
|
self.total_pages = 1
|
||||||
f"Failed to load detection results:\n{str(e)}",
|
self.current_page = 0
|
||||||
)
|
try:
|
||||||
self.detection_summary = []
|
detections = self.db_manager.get_detections(limit=int(self.page_size))
|
||||||
|
summary_map: Dict[tuple, Dict] = {}
|
||||||
|
|
||||||
|
for det in detections:
|
||||||
|
key = (det["image_id"], det["model_id"])
|
||||||
|
metadata = det.get("metadata") or {}
|
||||||
|
entry = summary_map.setdefault(
|
||||||
|
key,
|
||||||
|
{
|
||||||
|
"image_id": det["image_id"],
|
||||||
|
"model_id": det["model_id"],
|
||||||
|
"image_path": det.get("image_path"),
|
||||||
|
"image_filename": det.get("image_filename") or det.get("image_path"),
|
||||||
|
"model_name": det.get("model_name", ""),
|
||||||
|
"model_version": det.get("model_version", ""),
|
||||||
|
"last_detected": det.get("detected_at"),
|
||||||
|
"count": 0,
|
||||||
|
"classes": set(),
|
||||||
|
"source_path": metadata.get("source_path"),
|
||||||
|
"repository_root": metadata.get("repository_root"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
entry["count"] += 1
|
||||||
|
if det.get("detected_at") and (
|
||||||
|
not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
|
||||||
|
):
|
||||||
|
entry["last_detected"] = det.get("detected_at")
|
||||||
|
if det.get("class_name"):
|
||||||
|
entry["classes"].add(det["class_name"])
|
||||||
|
if metadata.get("source_path") and not entry.get("source_path"):
|
||||||
|
entry["source_path"] = metadata.get("source_path")
|
||||||
|
if metadata.get("repository_root") and not entry.get("repository_root"):
|
||||||
|
entry["repository_root"] = metadata.get("repository_root")
|
||||||
|
|
||||||
|
self.detection_summary = sorted(
|
||||||
|
summary_map.values(),
|
||||||
|
key=lambda x: str(x.get("last_detected") or ""),
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
except Exception as inner:
|
||||||
|
logger.error(f"Failed to load detection summary: {inner}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Error",
|
||||||
|
f"Failed to load detection results:\n{str(inner)}",
|
||||||
|
)
|
||||||
|
self.detection_summary = []
|
||||||
|
|
||||||
def _populate_results_table(self):
|
def _populate_results_table(self):
|
||||||
"""Populate the table widget with detection summaries."""
|
"""Populate the table widget with detection summaries."""
|
||||||
@@ -258,6 +411,231 @@ class ResultsTab(QWidget):
|
|||||||
self._load_detections_for_selection(entry)
|
self._load_detections_for_selection(entry)
|
||||||
self._apply_detection_overlays()
|
self._apply_detection_overlays()
|
||||||
self._update_summary_label(entry)
|
self._update_summary_label(entry)
|
||||||
|
if hasattr(self, "export_labels_btn"):
|
||||||
|
self.export_labels_btn.setEnabled(True)
|
||||||
|
|
||||||
|
def _export_labels_for_current_selection(self):
|
||||||
|
"""Export YOLO label file(s) for the currently selected image/model."""
|
||||||
|
if not self.current_selection:
|
||||||
|
QMessageBox.information(self, "Export Labels", "Select a detection result first.")
|
||||||
|
return
|
||||||
|
|
||||||
|
entry = self.current_selection
|
||||||
|
|
||||||
|
image_path_str = self._resolve_image_path(entry)
|
||||||
|
if not image_path_str:
|
||||||
|
QMessageBox.warning(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
"Unable to locate the image file for this detection; cannot infer labels path.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Ensure we have the detections for the selection.
|
||||||
|
if not self.current_detections:
|
||||||
|
self._load_detections_for_selection(entry)
|
||||||
|
|
||||||
|
if not self.current_detections:
|
||||||
|
QMessageBox.information(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
"No detections found for this image/model selection.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
image_path = Path(image_path_str)
|
||||||
|
try:
|
||||||
|
label_path = self._infer_yolo_label_path(image_path)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Failed to infer label path for {image_path}: {exc}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
f"Failed to infer export path for labels:\n{exc}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
class_map = self._build_detection_class_index_map(self.current_detections)
|
||||||
|
if not class_map:
|
||||||
|
QMessageBox.warning(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
"Unable to build class->index mapping (missing class names).",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
lines_written = 0
|
||||||
|
skipped = 0
|
||||||
|
label_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
try:
|
||||||
|
with open(label_path, "w", encoding="utf-8") as handle:
|
||||||
|
print("writing to", label_path)
|
||||||
|
for det in self.current_detections:
|
||||||
|
yolo_line = self._format_detection_as_yolo_line(det, class_map)
|
||||||
|
if not yolo_line:
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
handle.write(yolo_line + "\n")
|
||||||
|
lines_written += 1
|
||||||
|
except OSError as exc:
|
||||||
|
logger.error(f"Failed to write labels file {label_path}: {exc}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
f"Failed to write label file:\n{label_path}\n\n{exc}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
return
|
||||||
|
# Optional: write a classes.txt next to the labels root to make the mapping discoverable.
|
||||||
|
# This is not required by Ultralytics (data.yaml usually holds class names), but helps reuse.
|
||||||
|
try:
|
||||||
|
classes_txt = label_path.parent.parent / "classes.txt"
|
||||||
|
classes_txt.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
inv = {idx: name for name, idx in class_map.items()}
|
||||||
|
with open(classes_txt, "w", encoding="utf-8") as handle:
|
||||||
|
for idx in range(len(inv)):
|
||||||
|
handle.write(f"{inv[idx]}\n")
|
||||||
|
except Exception:
|
||||||
|
# Non-fatal
|
||||||
|
pass
|
||||||
|
|
||||||
|
QMessageBox.information(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
f"Exported {lines_written} label line(s) to:\n{label_path}\n\nSkipped {skipped} invalid detection(s).",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _infer_yolo_label_path(self, image_path: Path) -> Path:
|
||||||
|
"""Infer a YOLO label path from an image path.
|
||||||
|
|
||||||
|
If the image lives under an `images/` directory (anywhere in the path), we mirror the
|
||||||
|
subpath under a sibling `labels/` directory at the same level.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
/dataset/train/images/sub/img.jpg -> /dataset/train/labels/sub/img.txt
|
||||||
|
"""
|
||||||
|
|
||||||
|
resolved = image_path.expanduser().resolve()
|
||||||
|
|
||||||
|
# Find the nearest ancestor directory named 'images'
|
||||||
|
images_dir: Optional[Path] = None
|
||||||
|
for parent in [resolved.parent, *resolved.parents]:
|
||||||
|
if parent.name.lower() == "images":
|
||||||
|
images_dir = parent
|
||||||
|
break
|
||||||
|
|
||||||
|
if images_dir is not None:
|
||||||
|
rel = resolved.relative_to(images_dir)
|
||||||
|
labels_dir = images_dir.parent / "labels"
|
||||||
|
return (labels_dir / rel).with_suffix(".txt")
|
||||||
|
|
||||||
|
# Fallback: create a local sibling labels folder next to the image.
|
||||||
|
return (resolved.parent / "labels" / resolved.name).with_suffix(".txt")
|
||||||
|
|
||||||
|
def _build_detection_class_index_map(self, detections: List[Dict]) -> Dict[str, int]:
|
||||||
|
"""Build a stable class_name -> YOLO class index mapping.
|
||||||
|
|
||||||
|
Preference order:
|
||||||
|
1) Database object_classes table (alphabetical class_name order)
|
||||||
|
2) Fallback to class_name values present in the detections (alphabetical)
|
||||||
|
"""
|
||||||
|
|
||||||
|
names: List[str] = []
|
||||||
|
try:
|
||||||
|
db_classes = self.db_manager.get_object_classes() or []
|
||||||
|
names = [str(row.get("class_name")) for row in db_classes if row.get("class_name")]
|
||||||
|
except Exception:
|
||||||
|
names = []
|
||||||
|
|
||||||
|
if not names:
|
||||||
|
observed = sorted({str(det.get("class_name")) for det in detections if det.get("class_name")})
|
||||||
|
names = list(observed)
|
||||||
|
|
||||||
|
return {name: idx for idx, name in enumerate(names)}
|
||||||
|
|
||||||
|
def _format_detection_as_yolo_line(self, det: Dict, class_map: Dict[str, int]) -> Optional[str]:
|
||||||
|
"""Convert a detection row to a YOLO label line.
|
||||||
|
|
||||||
|
- If segmentation_mask is present, exports segmentation polygon format:
|
||||||
|
class x1 y1 x2 y2 ...
|
||||||
|
(normalized coordinates)
|
||||||
|
- Otherwise exports bbox format:
|
||||||
|
class x_center y_center width height
|
||||||
|
(normalized coordinates)
|
||||||
|
"""
|
||||||
|
|
||||||
|
class_name = det.get("class_name")
|
||||||
|
if not class_name or class_name not in class_map:
|
||||||
|
return None
|
||||||
|
class_idx = class_map[class_name]
|
||||||
|
|
||||||
|
mask = det.get("segmentation_mask")
|
||||||
|
polygon = self._convert_segmentation_mask_to_polygon(mask)
|
||||||
|
if polygon:
|
||||||
|
coords = " ".join(f"{value:.6f}" for value in polygon)
|
||||||
|
return f"{class_idx} {coords}".strip()
|
||||||
|
|
||||||
|
bbox = self._convert_bbox_to_yolo_xywh(det)
|
||||||
|
if bbox is None:
|
||||||
|
return None
|
||||||
|
x_center, y_center, width, height = bbox
|
||||||
|
return f"{class_idx} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
|
||||||
|
|
||||||
|
def _convert_bbox_to_yolo_xywh(self, det: Dict) -> Optional[Tuple[float, float, float, float]]:
|
||||||
|
"""Convert stored xyxy (normalized) bbox to YOLO xywh (normalized)."""
|
||||||
|
|
||||||
|
x_min = det.get("x_min")
|
||||||
|
y_min = det.get("y_min")
|
||||||
|
x_max = det.get("x_max")
|
||||||
|
y_max = det.get("y_max")
|
||||||
|
if any(v is None for v in (x_min, y_min, x_max, y_max)):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
x_min_f = self._clamp01(float(x_min))
|
||||||
|
y_min_f = self._clamp01(float(y_min))
|
||||||
|
x_max_f = self._clamp01(float(x_max))
|
||||||
|
y_max_f = self._clamp01(float(y_max))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
width = max(0.0, x_max_f - x_min_f)
|
||||||
|
height = max(0.0, y_max_f - y_min_f)
|
||||||
|
if width <= 0.0 or height <= 0.0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
x_center = x_min_f + width / 2.0
|
||||||
|
y_center = y_min_f + height / 2.0
|
||||||
|
return x_center, y_center, width, height
|
||||||
|
|
||||||
|
def _convert_segmentation_mask_to_polygon(self, mask_data) -> List[float]:
|
||||||
|
"""Convert stored segmentation_mask [[x,y], ...] to YOLO polygon coords [x1,y1,...]."""
|
||||||
|
|
||||||
|
if not isinstance(mask_data, list):
|
||||||
|
return []
|
||||||
|
|
||||||
|
coords: List[float] = []
|
||||||
|
for point in mask_data:
|
||||||
|
if not isinstance(point, (list, tuple)) or len(point) < 2:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
x = self._clamp01(float(point[0]))
|
||||||
|
y = self._clamp01(float(point[1]))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
coords.extend([x, y])
|
||||||
|
|
||||||
|
# Need at least 3 points => 6 values.
|
||||||
|
return coords if len(coords) >= 6 else []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _clamp01(value: float) -> float:
|
||||||
|
if value < 0.0:
|
||||||
|
return 0.0
|
||||||
|
if value > 1.0:
|
||||||
|
return 1.0
|
||||||
|
return value
|
||||||
|
|
||||||
def _load_detections_for_selection(self, entry: Dict):
|
def _load_detections_for_selection(self, entry: Dict):
|
||||||
"""Load detection records for the selected image/model pair."""
|
"""Load detection records for the selected image/model pair."""
|
||||||
|
|||||||
@@ -905,9 +905,14 @@ class TrainingTab(QWidget):
|
|||||||
if stats["registered_images"]:
|
if stats["registered_images"]:
|
||||||
message += f" {stats['registered_images']} image(s) had database-backed annotations."
|
message += f" {stats['registered_images']} image(s) had database-backed annotations."
|
||||||
if stats["missing_records"]:
|
if stats["missing_records"]:
|
||||||
message += (
|
preserved = stats.get("preserved_existing_labels", 0)
|
||||||
f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
|
if preserved:
|
||||||
)
|
message += (
|
||||||
|
f" {stats['missing_records']} image(s) had no database annotations; "
|
||||||
|
f"preserved {preserved} existing label file(s) (no overwrite)."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
message += f" {stats['missing_records']} image(s) had no database annotations; empty label files were written."
|
||||||
split_messages.append(message)
|
split_messages.append(message)
|
||||||
|
|
||||||
for msg in split_messages:
|
for msg in split_messages:
|
||||||
@@ -929,6 +934,7 @@ class TrainingTab(QWidget):
|
|||||||
processed_images = 0
|
processed_images = 0
|
||||||
registered_images = 0
|
registered_images = 0
|
||||||
missing_records = 0
|
missing_records = 0
|
||||||
|
preserved_existing_labels = 0
|
||||||
total_annotations = 0
|
total_annotations = 0
|
||||||
|
|
||||||
for image_file in images_dir.rglob("*"):
|
for image_file in images_dir.rglob("*"):
|
||||||
@@ -950,6 +956,12 @@ class TrainingTab(QWidget):
|
|||||||
else:
|
else:
|
||||||
missing_records += 1
|
missing_records += 1
|
||||||
|
|
||||||
|
# If the database has no entry for this image, do not overwrite an existing label file
|
||||||
|
# with an empty one (preserve any manually created labels on disk).
|
||||||
|
if not found and label_path.exists():
|
||||||
|
preserved_existing_labels += 1
|
||||||
|
continue
|
||||||
|
|
||||||
annotations_written = 0
|
annotations_written = 0
|
||||||
with open(label_path, "w", encoding="utf-8") as handle:
|
with open(label_path, "w", encoding="utf-8") as handle:
|
||||||
for entry in annotation_entries:
|
for entry in annotation_entries:
|
||||||
@@ -979,6 +991,7 @@ class TrainingTab(QWidget):
|
|||||||
"processed_images": processed_images,
|
"processed_images": processed_images,
|
||||||
"registered_images": registered_images,
|
"registered_images": registered_images,
|
||||||
"missing_records": missing_records,
|
"missing_records": missing_records,
|
||||||
|
"preserved_existing_labels": preserved_existing_labels,
|
||||||
"total_annotations": total_annotations,
|
"total_annotations": total_annotations,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1008,6 +1021,10 @@ class TrainingTab(QWidget):
|
|||||||
resolved_image = image_path.resolve()
|
resolved_image = image_path.resolve()
|
||||||
candidates: List[str] = []
|
candidates: List[str] = []
|
||||||
|
|
||||||
|
# Some DBs store absolute paths in `images.relative_path`.
|
||||||
|
# Include the absolute resolved path as a lookup candidate.
|
||||||
|
candidates.append(resolved_image.as_posix())
|
||||||
|
|
||||||
for base in (dataset_root, images_dir):
|
for base in (dataset_root, images_dir):
|
||||||
try:
|
try:
|
||||||
relative = resolved_image.relative_to(base.resolve()).as_posix()
|
relative = resolved_image.relative_to(base.resolve()).as_posix()
|
||||||
@@ -1032,6 +1049,13 @@ class TrainingTab(QWidget):
|
|||||||
return False, []
|
return False, []
|
||||||
|
|
||||||
annotations = self.db_manager.get_annotations_for_image(image_row["id"]) or []
|
annotations = self.db_manager.get_annotations_for_image(image_row["id"]) or []
|
||||||
|
|
||||||
|
# Treat "found" as "has database-backed annotations".
|
||||||
|
# If the image exists in DB but has no annotations yet, we don't want to overwrite
|
||||||
|
# an existing label file on disk with an empty one.
|
||||||
|
if not annotations:
|
||||||
|
return False, []
|
||||||
|
|
||||||
yolo_entries: List[Dict[str, Any]] = []
|
yolo_entries: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
for ann in annotations:
|
for ann in annotations:
|
||||||
|
|||||||
@@ -2,45 +2,554 @@
|
|||||||
Validation tab for the microscopy object detection application.
|
Validation tab for the microscopy object detection application.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
from PySide6.QtCore import Qt, QSize
|
||||||
|
from PySide6.QtGui import QPainter, QPixmap
|
||||||
|
from PySide6.QtWidgets import (
|
||||||
|
QWidget,
|
||||||
|
QVBoxLayout,
|
||||||
|
QLabel,
|
||||||
|
QGroupBox,
|
||||||
|
QHBoxLayout,
|
||||||
|
QPushButton,
|
||||||
|
QComboBox,
|
||||||
|
QFormLayout,
|
||||||
|
QScrollArea,
|
||||||
|
QGridLayout,
|
||||||
|
QFrame,
|
||||||
|
QTableWidget,
|
||||||
|
QTableWidgetItem,
|
||||||
|
QHeaderView,
|
||||||
|
QSplitter,
|
||||||
|
QListWidget,
|
||||||
|
QListWidgetItem,
|
||||||
|
QAbstractItemView,
|
||||||
|
QGraphicsView,
|
||||||
|
QGraphicsScene,
|
||||||
|
QGraphicsPixmapItem,
|
||||||
|
)
|
||||||
|
|
||||||
from src.database.db_manager import DatabaseManager
|
from src.database.db_manager import DatabaseManager
|
||||||
from src.utils.config_manager import ConfigManager
|
from src.utils.config_manager import ConfigManager
|
||||||
|
from src.utils.logger import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class _PlotItem:
|
||||||
|
label: str
|
||||||
|
path: Path
|
||||||
|
|
||||||
|
|
||||||
|
class _ZoomableImageView(QGraphicsView):
|
||||||
|
"""Zoomable image viewer.
|
||||||
|
|
||||||
|
- Mouse wheel: zoom in/out
|
||||||
|
- Left mouse drag: pan (ScrollHandDrag)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, parent: Optional[QWidget] = None):
|
||||||
|
super().__init__(parent)
|
||||||
|
self._scene = QGraphicsScene(self)
|
||||||
|
self.setScene(self._scene)
|
||||||
|
self._pixmap_item = QGraphicsPixmapItem()
|
||||||
|
self._scene.addItem(self._pixmap_item)
|
||||||
|
|
||||||
|
# QGraphicsView render hints are QPainter.RenderHints.
|
||||||
|
self.setRenderHints(self.renderHints() | QPainter.RenderHint.SmoothPixmapTransform)
|
||||||
|
self.setDragMode(QGraphicsView.DragMode.ScrollHandDrag)
|
||||||
|
self.setTransformationAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
|
||||||
|
self.setResizeAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
|
||||||
|
|
||||||
|
self._has_pixmap = False
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
self._pixmap_item.setPixmap(QPixmap())
|
||||||
|
self._scene.setSceneRect(0, 0, 1, 1)
|
||||||
|
self.resetTransform()
|
||||||
|
self._has_pixmap = False
|
||||||
|
|
||||||
|
def set_pixmap(self, pixmap: QPixmap, *, fit: bool = True) -> None:
|
||||||
|
self._pixmap_item.setPixmap(pixmap)
|
||||||
|
self._scene.setSceneRect(pixmap.rect())
|
||||||
|
self._has_pixmap = not pixmap.isNull()
|
||||||
|
self.resetTransform()
|
||||||
|
if fit and self._has_pixmap:
|
||||||
|
self.fitInView(self._pixmap_item, Qt.AspectRatioMode.KeepAspectRatio)
|
||||||
|
|
||||||
|
def wheelEvent(self, event) -> None: # type: ignore[override]
|
||||||
|
if not self._has_pixmap:
|
||||||
|
return
|
||||||
|
zoom_in_factor = 1.25
|
||||||
|
zoom_out_factor = 1.0 / zoom_in_factor
|
||||||
|
factor = zoom_in_factor if event.angleDelta().y() > 0 else zoom_out_factor
|
||||||
|
self.scale(factor, factor)
|
||||||
|
|
||||||
|
|
||||||
class ValidationTab(QWidget):
|
class ValidationTab(QWidget):
|
||||||
"""Validation tab placeholder."""
|
"""Validation tab that shows stored validation metrics + plots for a selected model."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
|
||||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
|
||||||
):
|
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
self.db_manager = db_manager
|
self.db_manager = db_manager
|
||||||
self.config_manager = config_manager
|
self.config_manager = config_manager
|
||||||
|
|
||||||
|
self._models: List[Dict[str, Any]] = []
|
||||||
|
self._selected_model_id: Optional[int] = None
|
||||||
|
self._plot_widgets: List[QWidget] = []
|
||||||
|
self._plot_items: List[_PlotItem] = []
|
||||||
|
|
||||||
self._setup_ui()
|
self._setup_ui()
|
||||||
|
self.refresh()
|
||||||
|
|
||||||
def _setup_ui(self):
|
def _setup_ui(self):
|
||||||
"""Setup user interface."""
|
"""Setup user interface."""
|
||||||
layout = QVBoxLayout()
|
layout = QVBoxLayout(self)
|
||||||
|
|
||||||
group = QGroupBox("Validation")
|
# ===== Header controls =====
|
||||||
group_layout = QVBoxLayout()
|
header = QGroupBox("Validation")
|
||||||
label = QLabel(
|
header_layout = QVBoxLayout()
|
||||||
"Validation functionality will be implemented here.\n\n"
|
header_row = QHBoxLayout()
|
||||||
"Features:\n"
|
|
||||||
"- Model validation\n"
|
|
||||||
"- Metrics visualization\n"
|
|
||||||
"- Confusion matrix\n"
|
|
||||||
"- Precision-Recall curves"
|
|
||||||
)
|
|
||||||
group_layout.addWidget(label)
|
|
||||||
group.setLayout(group_layout)
|
|
||||||
|
|
||||||
layout.addWidget(group)
|
header_row.addWidget(QLabel("Select model:"))
|
||||||
layout.addStretch()
|
|
||||||
self.setLayout(layout)
|
self.model_combo = QComboBox()
|
||||||
|
self.model_combo.setMinimumWidth(420)
|
||||||
|
self.model_combo.currentIndexChanged.connect(self._on_model_selected)
|
||||||
|
header_row.addWidget(self.model_combo, 1)
|
||||||
|
|
||||||
|
self.refresh_btn = QPushButton("Refresh")
|
||||||
|
self.refresh_btn.clicked.connect(self.refresh)
|
||||||
|
header_row.addWidget(self.refresh_btn)
|
||||||
|
header_row.addStretch()
|
||||||
|
|
||||||
|
header_layout.addLayout(header_row)
|
||||||
|
self.header_status = QLabel("No models loaded.")
|
||||||
|
self.header_status.setWordWrap(True)
|
||||||
|
header_layout.addWidget(self.header_status)
|
||||||
|
header.setLayout(header_layout)
|
||||||
|
layout.addWidget(header)
|
||||||
|
|
||||||
|
# ===== Metrics =====
|
||||||
|
metrics_group = QGroupBox("Validation Metrics")
|
||||||
|
metrics_layout = QVBoxLayout()
|
||||||
|
|
||||||
|
self.metrics_form = QFormLayout()
|
||||||
|
self.metric_labels: Dict[str, QLabel] = {}
|
||||||
|
for key in ("mAP50", "mAP50-95", "precision", "recall", "fitness"):
|
||||||
|
value_label = QLabel("–")
|
||||||
|
value_label.setTextInteractionFlags(Qt.TextSelectableByMouse)
|
||||||
|
self.metric_labels[key] = value_label
|
||||||
|
self.metrics_form.addRow(f"{key}:", value_label)
|
||||||
|
metrics_layout.addLayout(self.metrics_form)
|
||||||
|
|
||||||
|
self.per_class_table = QTableWidget(0, 3)
|
||||||
|
self.per_class_table.setHorizontalHeaderLabels(["Class", "AP", "AP50"])
|
||||||
|
self.per_class_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
|
||||||
|
self.per_class_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeToContents)
|
||||||
|
self.per_class_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeToContents)
|
||||||
|
self.per_class_table.setEditTriggers(QTableWidget.NoEditTriggers)
|
||||||
|
self.per_class_table.setMinimumHeight(160)
|
||||||
|
metrics_layout.addWidget(QLabel("Per-class metrics (if available):"))
|
||||||
|
metrics_layout.addWidget(self.per_class_table)
|
||||||
|
|
||||||
|
metrics_group.setLayout(metrics_layout)
|
||||||
|
layout.addWidget(metrics_group)
|
||||||
|
|
||||||
|
# ===== Plots =====
|
||||||
|
plots_group = QGroupBox("Validation Plots")
|
||||||
|
plots_layout = QVBoxLayout()
|
||||||
|
|
||||||
|
self.plots_status = QLabel("Select a model to see validation plots.")
|
||||||
|
self.plots_status.setWordWrap(True)
|
||||||
|
plots_layout.addWidget(self.plots_status)
|
||||||
|
|
||||||
|
self.plots_splitter = QSplitter(Qt.Orientation.Horizontal)
|
||||||
|
|
||||||
|
# Left: selected image viewer
|
||||||
|
left_widget = QWidget()
|
||||||
|
left_layout = QVBoxLayout(left_widget)
|
||||||
|
left_layout.setContentsMargins(0, 0, 0, 0)
|
||||||
|
|
||||||
|
self.selected_plot_title = QLabel("No image selected.")
|
||||||
|
self.selected_plot_title.setWordWrap(True)
|
||||||
|
self.selected_plot_title.setTextInteractionFlags(Qt.TextSelectableByMouse)
|
||||||
|
left_layout.addWidget(self.selected_plot_title)
|
||||||
|
|
||||||
|
self.plot_view = _ZoomableImageView()
|
||||||
|
self.plot_view.setMinimumHeight(360)
|
||||||
|
left_layout.addWidget(self.plot_view, 1)
|
||||||
|
|
||||||
|
self.selected_plot_path = QLabel("")
|
||||||
|
self.selected_plot_path.setWordWrap(True)
|
||||||
|
self.selected_plot_path.setStyleSheet("color: #888;")
|
||||||
|
self.selected_plot_path.setTextInteractionFlags(Qt.TextSelectableByMouse)
|
||||||
|
left_layout.addWidget(self.selected_plot_path)
|
||||||
|
|
||||||
|
# Right: scrollable list
|
||||||
|
right_widget = QWidget()
|
||||||
|
right_layout = QVBoxLayout(right_widget)
|
||||||
|
right_layout.setContentsMargins(0, 0, 0, 0)
|
||||||
|
right_layout.addWidget(QLabel("Images:"))
|
||||||
|
|
||||||
|
self.plots_list = QListWidget()
|
||||||
|
self.plots_list.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection)
|
||||||
|
self.plots_list.setIconSize(QSize(160, 160))
|
||||||
|
self.plots_list.itemSelectionChanged.connect(self._on_plot_item_selected)
|
||||||
|
right_layout.addWidget(self.plots_list, 1)
|
||||||
|
|
||||||
|
self.plots_splitter.addWidget(left_widget)
|
||||||
|
self.plots_splitter.addWidget(right_widget)
|
||||||
|
self.plots_splitter.setStretchFactor(0, 3)
|
||||||
|
self.plots_splitter.setStretchFactor(1, 1)
|
||||||
|
plots_layout.addWidget(self.plots_splitter, 1)
|
||||||
|
|
||||||
|
plots_group.setLayout(plots_layout)
|
||||||
|
layout.addWidget(plots_group, 1)
|
||||||
|
|
||||||
|
layout.addStretch(0)
|
||||||
|
|
||||||
|
self._clear_metrics()
|
||||||
|
self._clear_plots()
|
||||||
|
|
||||||
|
# ==================== Public API ====================
|
||||||
|
|
||||||
def refresh(self):
|
def refresh(self):
|
||||||
"""Refresh the tab."""
|
"""Refresh the tab."""
|
||||||
pass
|
self._load_models()
|
||||||
|
self._populate_model_combo()
|
||||||
|
self._restore_or_select_default_model()
|
||||||
|
|
||||||
|
# ==================== Internal: models ====================
|
||||||
|
|
||||||
|
def _load_models(self) -> None:
|
||||||
|
try:
|
||||||
|
self._models = self.db_manager.get_models() or []
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to load models: %s", exc)
|
||||||
|
self._models = []
|
||||||
|
|
||||||
|
def _populate_model_combo(self) -> None:
|
||||||
|
self.model_combo.blockSignals(True)
|
||||||
|
self.model_combo.clear()
|
||||||
|
self.model_combo.addItem("Select a model…", None)
|
||||||
|
|
||||||
|
for model in self._models:
|
||||||
|
model_id = model.get("id")
|
||||||
|
name = (model.get("model_name") or "").strip()
|
||||||
|
version = (model.get("model_version") or "").strip()
|
||||||
|
created_at = model.get("created_at")
|
||||||
|
label = f"{name} {version}".strip()
|
||||||
|
if created_at:
|
||||||
|
label = f"{label} ({created_at})"
|
||||||
|
self.model_combo.addItem(label, model_id)
|
||||||
|
|
||||||
|
self.model_combo.blockSignals(False)
|
||||||
|
|
||||||
|
if self._models:
|
||||||
|
self.header_status.setText(f"Loaded {len(self._models)} model(s).")
|
||||||
|
else:
|
||||||
|
self.header_status.setText("No models found. Train a model first.")
|
||||||
|
|
||||||
|
def _restore_or_select_default_model(self) -> None:
|
||||||
|
if not self._models:
|
||||||
|
self._selected_model_id = None
|
||||||
|
self._clear_metrics()
|
||||||
|
self._clear_plots()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Keep selection if still present.
|
||||||
|
if self._selected_model_id is not None:
|
||||||
|
for idx in range(1, self.model_combo.count()):
|
||||||
|
if self.model_combo.itemData(idx) == self._selected_model_id:
|
||||||
|
self.model_combo.setCurrentIndex(idx)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Otherwise select the newest model (top of get_models ORDER BY created_at DESC).
|
||||||
|
first_model_id = self.model_combo.itemData(1) if self.model_combo.count() > 1 else None
|
||||||
|
if first_model_id is not None:
|
||||||
|
self.model_combo.setCurrentIndex(1)
|
||||||
|
|
||||||
|
def _on_model_selected(self, index: int) -> None:
|
||||||
|
model_id = self.model_combo.itemData(index)
|
||||||
|
if not model_id:
|
||||||
|
self._selected_model_id = None
|
||||||
|
self._clear_metrics()
|
||||||
|
self._clear_plots()
|
||||||
|
self.plots_status.setText("Select a model to see validation plots.")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._selected_model_id = int(model_id)
|
||||||
|
model = self._get_model_by_id(self._selected_model_id)
|
||||||
|
if not model:
|
||||||
|
self._clear_metrics()
|
||||||
|
self._clear_plots()
|
||||||
|
self.plots_status.setText("Selected model not found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._render_metrics(model)
|
||||||
|
self._render_plots(model)
|
||||||
|
|
||||||
|
def _get_model_by_id(self, model_id: int) -> Optional[Dict[str, Any]]:
|
||||||
|
for model in self._models:
|
||||||
|
if model.get("id") == model_id:
|
||||||
|
return model
|
||||||
|
try:
|
||||||
|
return self.db_manager.get_model_by_id(model_id)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ==================== Internal: metrics ====================
|
||||||
|
|
||||||
|
def _clear_metrics(self) -> None:
|
||||||
|
for label in self.metric_labels.values():
|
||||||
|
label.setText("–")
|
||||||
|
self.per_class_table.setRowCount(0)
|
||||||
|
|
||||||
|
def _render_metrics(self, model: Dict[str, Any]) -> None:
|
||||||
|
self._clear_metrics()
|
||||||
|
|
||||||
|
metrics: Dict[str, Any] = model.get("metrics") or {}
|
||||||
|
# Training tab stores metrics under results['metrics'] in training results payload.
|
||||||
|
if isinstance(metrics, dict) and "metrics" in metrics and isinstance(metrics.get("metrics"), dict):
|
||||||
|
metrics = metrics.get("metrics") or {}
|
||||||
|
|
||||||
|
def set_metric(key: str, value: Any) -> None:
|
||||||
|
if key not in self.metric_labels:
|
||||||
|
return
|
||||||
|
if value is None:
|
||||||
|
self.metric_labels[key].setText("–")
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
self.metric_labels[key].setText(f"{float(value):.4f}")
|
||||||
|
except Exception:
|
||||||
|
self.metric_labels[key].setText(str(value))
|
||||||
|
|
||||||
|
set_metric("mAP50", metrics.get("mAP50"))
|
||||||
|
set_metric("mAP50-95", metrics.get("mAP50-95") or metrics.get("mAP50_95") or metrics.get("mAP50-95"))
|
||||||
|
set_metric("precision", metrics.get("precision"))
|
||||||
|
set_metric("recall", metrics.get("recall"))
|
||||||
|
set_metric("fitness", metrics.get("fitness"))
|
||||||
|
|
||||||
|
# Optional per-class metrics
|
||||||
|
class_metrics = metrics.get("class_metrics") if isinstance(metrics, dict) else None
|
||||||
|
if isinstance(class_metrics, dict) and class_metrics:
|
||||||
|
items = sorted(class_metrics.items(), key=lambda kv: str(kv[0]))
|
||||||
|
self.per_class_table.setRowCount(len(items))
|
||||||
|
for row, (cls_name, cls_stats) in enumerate(items):
|
||||||
|
ap = (cls_stats or {}).get("ap")
|
||||||
|
ap50 = (cls_stats or {}).get("ap50")
|
||||||
|
self.per_class_table.setItem(row, 0, QTableWidgetItem(str(cls_name)))
|
||||||
|
self.per_class_table.setItem(row, 1, QTableWidgetItem(self._format_float(ap)))
|
||||||
|
self.per_class_table.setItem(row, 2, QTableWidgetItem(self._format_float(ap50)))
|
||||||
|
else:
|
||||||
|
self.per_class_table.setRowCount(0)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_float(value: Any) -> str:
|
||||||
|
if value is None:
|
||||||
|
return "–"
|
||||||
|
try:
|
||||||
|
return f"{float(value):.4f}"
|
||||||
|
except Exception:
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
# ==================== Internal: plots ====================
|
||||||
|
|
||||||
|
def _clear_plots(self) -> None:
|
||||||
|
# Remove legacy grid widgets (from the initial implementation).
|
||||||
|
for widget in self._plot_widgets:
|
||||||
|
widget.setParent(None)
|
||||||
|
widget.deleteLater()
|
||||||
|
self._plot_widgets = []
|
||||||
|
|
||||||
|
self._plot_items = []
|
||||||
|
|
||||||
|
if hasattr(self, "plots_list"):
|
||||||
|
self.plots_list.blockSignals(True)
|
||||||
|
self.plots_list.clear()
|
||||||
|
self.plots_list.blockSignals(False)
|
||||||
|
|
||||||
|
if hasattr(self, "plot_view"):
|
||||||
|
self.plot_view.clear()
|
||||||
|
if hasattr(self, "selected_plot_title"):
|
||||||
|
self.selected_plot_title.setText("No image selected.")
|
||||||
|
if hasattr(self, "selected_plot_path"):
|
||||||
|
self.selected_plot_path.setText("")
|
||||||
|
|
||||||
|
def _render_plots(self, model: Dict[str, Any]) -> None:
|
||||||
|
self._clear_plots()
|
||||||
|
|
||||||
|
plot_dirs = self._infer_run_directories(model)
|
||||||
|
plot_items = self._discover_plot_items(plot_dirs)
|
||||||
|
|
||||||
|
if not plot_items:
|
||||||
|
dirs_text = "\n".join(str(p) for p in plot_dirs if p)
|
||||||
|
self.plots_status.setText(
|
||||||
|
"No validation plot images found for this model.\n\n"
|
||||||
|
"Searched directories:\n" + (dirs_text or "(none)")
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
self._plot_items = list(plot_items)
|
||||||
|
self.plots_status.setText(f"Found {len(plot_items)} plot image(s). Select one to view/zoom.")
|
||||||
|
|
||||||
|
self.plots_list.blockSignals(True)
|
||||||
|
self.plots_list.clear()
|
||||||
|
for idx, item in enumerate(self._plot_items):
|
||||||
|
qitem = QListWidgetItem(item.label)
|
||||||
|
qitem.setData(Qt.ItemDataRole.UserRole, idx)
|
||||||
|
|
||||||
|
pix = QPixmap(str(item.path))
|
||||||
|
if not pix.isNull():
|
||||||
|
thumb = pix.scaled(
|
||||||
|
self.plots_list.iconSize(),
|
||||||
|
Qt.AspectRatioMode.KeepAspectRatio,
|
||||||
|
Qt.TransformationMode.SmoothTransformation,
|
||||||
|
)
|
||||||
|
qitem.setIcon(thumb)
|
||||||
|
self.plots_list.addItem(qitem)
|
||||||
|
self.plots_list.blockSignals(False)
|
||||||
|
|
||||||
|
if self.plots_list.count() > 0:
|
||||||
|
self.plots_list.setCurrentRow(0)
|
||||||
|
|
||||||
|
def _on_plot_item_selected(self) -> None:
|
||||||
|
if not self._plot_items:
|
||||||
|
return
|
||||||
|
|
||||||
|
selected = self.plots_list.selectedItems()
|
||||||
|
if not selected:
|
||||||
|
return
|
||||||
|
|
||||||
|
idx = selected[0].data(Qt.ItemDataRole.UserRole)
|
||||||
|
try:
|
||||||
|
idx_int = int(idx)
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
if idx_int < 0 or idx_int >= len(self._plot_items):
|
||||||
|
return
|
||||||
|
|
||||||
|
plot = self._plot_items[idx_int]
|
||||||
|
self.selected_plot_title.setText(plot.label)
|
||||||
|
self.selected_plot_path.setText(str(plot.path))
|
||||||
|
|
||||||
|
pix = QPixmap(str(plot.path))
|
||||||
|
if pix.isNull():
|
||||||
|
self.plot_view.clear()
|
||||||
|
return
|
||||||
|
self.plot_view.set_pixmap(pix, fit=True)
|
||||||
|
|
||||||
|
def _infer_run_directories(self, model: Dict[str, Any]) -> List[Path]:
|
||||||
|
dirs: List[Path] = []
|
||||||
|
|
||||||
|
# 1) Infer from model_path: .../<run>/weights/best.pt -> <run>
|
||||||
|
model_path = model.get("model_path")
|
||||||
|
if model_path:
|
||||||
|
try:
|
||||||
|
p = Path(str(model_path)).expanduser()
|
||||||
|
if p.name.lower().endswith(".pt"):
|
||||||
|
# If it lives under weights/, use parent.parent.
|
||||||
|
if p.parent.name == "weights" and p.parent.parent.exists():
|
||||||
|
dirs.append(p.parent.parent)
|
||||||
|
elif p.parent.exists():
|
||||||
|
dirs.append(p.parent)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 2) Look at training_params.stage_results[].results.save_dir
|
||||||
|
training_params = model.get("training_params") or {}
|
||||||
|
stage_results = None
|
||||||
|
if isinstance(training_params, dict):
|
||||||
|
stage_results = training_params.get("stage_results")
|
||||||
|
if isinstance(stage_results, list):
|
||||||
|
for stage in stage_results:
|
||||||
|
results = (stage or {}).get("results")
|
||||||
|
save_dir = (results or {}).get("save_dir") if isinstance(results, dict) else None
|
||||||
|
if save_dir:
|
||||||
|
try:
|
||||||
|
save_path = Path(str(save_dir)).expanduser()
|
||||||
|
if save_path.exists():
|
||||||
|
dirs.append(save_path)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Deduplicate while preserving order.
|
||||||
|
unique: List[Path] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for d in dirs:
|
||||||
|
try:
|
||||||
|
resolved = str(d.resolve())
|
||||||
|
except Exception:
|
||||||
|
resolved = str(d)
|
||||||
|
if resolved not in seen and d.exists() and d.is_dir():
|
||||||
|
seen.add(resolved)
|
||||||
|
unique.append(d)
|
||||||
|
return unique
|
||||||
|
|
||||||
|
def _discover_plot_items(self, directories: Sequence[Path]) -> List[_PlotItem]:
|
||||||
|
# Prefer canonical Ultralytics filenames first, then fall back to any png/jpg.
|
||||||
|
preferred_names = [
|
||||||
|
"results.png",
|
||||||
|
"results.jpg",
|
||||||
|
"confusion_matrix.png",
|
||||||
|
"confusion_matrix_normalized.png",
|
||||||
|
"labels.jpg",
|
||||||
|
"labels.png",
|
||||||
|
"BoxPR_curve.png",
|
||||||
|
"BoxP_curve.png",
|
||||||
|
"BoxR_curve.png",
|
||||||
|
"BoxF1_curve.png",
|
||||||
|
"MaskPR_curve.png",
|
||||||
|
"MaskP_curve.png",
|
||||||
|
"MaskR_curve.png",
|
||||||
|
"MaskF1_curve.png",
|
||||||
|
"val_batch0_pred.jpg",
|
||||||
|
"val_batch0_labels.jpg",
|
||||||
|
]
|
||||||
|
|
||||||
|
found: List[_PlotItem] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
|
||||||
|
for d in directories:
|
||||||
|
# 1) Preferred
|
||||||
|
for name in preferred_names:
|
||||||
|
p = d / name
|
||||||
|
if p.exists() and p.is_file():
|
||||||
|
key = str(p)
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
found.append(_PlotItem(label=f"{name} (from {d.name})", path=p))
|
||||||
|
|
||||||
|
# 2) Curated globs
|
||||||
|
for pattern in ("train_batch*.jpg", "val_batch*.jpg", "*curve*.png"):
|
||||||
|
for p in sorted(d.glob(pattern)):
|
||||||
|
if not p.is_file():
|
||||||
|
continue
|
||||||
|
key = str(p)
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
found.append(_PlotItem(label=f"{p.name} (from {d.name})", path=p))
|
||||||
|
|
||||||
|
# 3) Fallback: any top-level png/jpg (excluding weights dir contents)
|
||||||
|
for ext in ("*.png", "*.jpg", "*.jpeg", "*.webp"):
|
||||||
|
for p in sorted(d.glob(ext)):
|
||||||
|
if not p.is_file():
|
||||||
|
continue
|
||||||
|
key = str(p)
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
found.append(_PlotItem(label=f"{p.name} (from {d.name})", path=p))
|
||||||
|
|
||||||
|
# Keep list bounded to avoid UI overload for huge runs.
|
||||||
|
return found[:60]
|
||||||
|
|||||||
@@ -84,25 +84,19 @@ class InferenceEngine:
|
|||||||
|
|
||||||
# Save detections to database, replacing any previous results for this image/model
|
# Save detections to database, replacing any previous results for this image/model
|
||||||
if save_to_db:
|
if save_to_db:
|
||||||
deleted_count = self.db_manager.delete_detections_for_image(
|
deleted_count = self.db_manager.delete_detections_for_image(image_id, self.model_id)
|
||||||
image_id, self.model_id
|
|
||||||
)
|
|
||||||
if detections:
|
if detections:
|
||||||
detection_records = []
|
detection_records = []
|
||||||
for det in detections:
|
for det in detections:
|
||||||
# Use normalized bbox from detection
|
# Use normalized bbox from detection
|
||||||
bbox_normalized = det[
|
bbox_normalized = det["bbox_normalized"] # [x_min, y_min, x_max, y_max]
|
||||||
"bbox_normalized"
|
|
||||||
] # [x_min, y_min, x_max, y_max]
|
|
||||||
|
|
||||||
metadata = {
|
metadata = {
|
||||||
"class_id": det["class_id"],
|
"class_id": det["class_id"],
|
||||||
"source_path": str(Path(image_path).resolve()),
|
"source_path": str(Path(image_path).resolve()),
|
||||||
}
|
}
|
||||||
if repository_root:
|
if repository_root:
|
||||||
metadata["repository_root"] = str(
|
metadata["repository_root"] = str(Path(repository_root).resolve())
|
||||||
Path(repository_root).resolve()
|
|
||||||
)
|
|
||||||
|
|
||||||
record = {
|
record = {
|
||||||
"image_id": image_id,
|
"image_id": image_id,
|
||||||
@@ -115,16 +109,27 @@ class InferenceEngine:
|
|||||||
}
|
}
|
||||||
detection_records.append(record)
|
detection_records.append(record)
|
||||||
|
|
||||||
inserted_count = self.db_manager.add_detections_batch(
|
inserted_count = self.db_manager.add_detections_batch(detection_records)
|
||||||
detection_records
|
logger.info(f"Saved {inserted_count} detections to database (replaced {deleted_count})")
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Saved {inserted_count} detections to database (replaced {deleted_count})"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(f"Detection run removed {deleted_count} stale entries but produced no new detections")
|
||||||
f"Detection run removed {deleted_count} stale entries but produced no new detections"
|
|
||||||
|
# Always store a run summary so the Results tab can show zero-detection runs.
|
||||||
|
try:
|
||||||
|
run_metadata = {
|
||||||
|
"source_path": str(Path(image_path).resolve()),
|
||||||
|
}
|
||||||
|
if repository_root:
|
||||||
|
run_metadata["repository_root"] = str(Path(repository_root).resolve())
|
||||||
|
self.db_manager.upsert_detection_run(
|
||||||
|
image_id=image_id,
|
||||||
|
model_id=self.model_id,
|
||||||
|
count=len(detections),
|
||||||
|
metadata=run_metadata,
|
||||||
)
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
# Non-fatal: detection records may still be present.
|
||||||
|
logger.warning(f"Failed to store detection run summary: {exc}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
@@ -232,9 +237,7 @@ class InferenceEngine:
|
|||||||
for det in detections:
|
for det in detections:
|
||||||
# Get color for this class
|
# Get color for this class
|
||||||
class_name = det["class_name"]
|
class_name = det["class_name"]
|
||||||
color_hex = bbox_colors.get(
|
color_hex = bbox_colors.get(class_name, bbox_colors.get("default", "#00FF00"))
|
||||||
class_name, bbox_colors.get("default", "#00FF00")
|
|
||||||
)
|
|
||||||
color = self._hex_to_bgr(color_hex)
|
color = self._hex_to_bgr(color_hex)
|
||||||
|
|
||||||
# Draw segmentation mask if available and requested
|
# Draw segmentation mask if available and requested
|
||||||
@@ -243,10 +246,7 @@ class InferenceEngine:
|
|||||||
if mask_normalized and len(mask_normalized) > 0:
|
if mask_normalized and len(mask_normalized) > 0:
|
||||||
# Convert normalized coordinates to absolute pixels
|
# Convert normalized coordinates to absolute pixels
|
||||||
mask_points = np.array(
|
mask_points = np.array(
|
||||||
[
|
[[int(pt[0] * width), int(pt[1] * height)] for pt in mask_normalized],
|
||||||
[int(pt[0] * width), int(pt[1] * height)]
|
|
||||||
for pt in mask_normalized
|
|
||||||
],
|
|
||||||
dtype=np.int32,
|
dtype=np.int32,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -270,9 +270,7 @@ class InferenceEngine:
|
|||||||
label = f"{class_name} {det['confidence']:.2f}"
|
label = f"{class_name} {det['confidence']:.2f}"
|
||||||
|
|
||||||
# Draw label background
|
# Draw label background
|
||||||
(label_w, label_h), baseline = cv2.getTextSize(
|
(label_w, label_h), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
|
||||||
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
|
|
||||||
)
|
|
||||||
cv2.rectangle(
|
cv2.rectangle(
|
||||||
img,
|
img,
|
||||||
(x1, y1 - label_h - baseline - 5),
|
(x1, y1 - label_h - baseline - 5),
|
||||||
|
|||||||
@@ -186,7 +186,7 @@ class YOLOWrapper:
|
|||||||
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
||||||
|
|
||||||
prepared_source, cleanup_path = self._prepare_source(source)
|
prepared_source, cleanup_path = self._prepare_source(source)
|
||||||
imgsz = 1088
|
imgsz = 640 # 1088
|
||||||
try:
|
try:
|
||||||
logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
|
logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
|
||||||
results = self.model.predict(
|
results = self.model.predict(
|
||||||
|
|||||||
103
src/utils/create_mask_from_detection.py
Normal file
103
src/utils/create_mask_from_detection.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from skimage.draw import polygon
|
||||||
|
from tifffile import TiffFile
|
||||||
|
|
||||||
|
from src.database.db_manager import DatabaseManager
|
||||||
|
|
||||||
|
|
||||||
|
def read_image(image_path: Path) -> np.ndarray:
|
||||||
|
metadata = {}
|
||||||
|
with TiffFile(image_path) as tif:
|
||||||
|
image = tif.asarray()
|
||||||
|
metadata = tif.imagej_metadata
|
||||||
|
return image, metadata
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
polygon_vertices = np.array([[10, 10], [50, 10], [50, 50], [10, 50]])
|
||||||
|
image = np.zeros((100, 100), dtype=np.uint8)
|
||||||
|
rr, cc = polygon(polygon_vertices[:, 0], polygon_vertices[:, 1])
|
||||||
|
image[rr, cc] = 255
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
db = DatabaseManager()
|
||||||
|
model_name = "c17"
|
||||||
|
model_id = db.get_models(filters={"model_name": model_name})[0]["id"]
|
||||||
|
print(f"Model name {model_name}, id {model_id}")
|
||||||
|
detections = db.get_detections(filters={"model_id": model_id})
|
||||||
|
|
||||||
|
file_stems = set()
|
||||||
|
|
||||||
|
for detection in detections:
|
||||||
|
file_stems.add(detection["image_filename"].split("_")[0])
|
||||||
|
|
||||||
|
print("Files:", file_stems)
|
||||||
|
|
||||||
|
for stem in file_stems:
|
||||||
|
print(stem)
|
||||||
|
detections = db.get_detections(filters={"model_id": model_id, "i.filename": f"LIKE %{stem}%"})
|
||||||
|
annotations = []
|
||||||
|
for detection in detections:
|
||||||
|
source_path = Path(detection["metadata"]["source_path"])
|
||||||
|
image, metadata = read_image(source_path)
|
||||||
|
|
||||||
|
offset = np.array(list(map(int, metadata["tile_section"].split(","))))[::-1]
|
||||||
|
scale = np.array(list(map(int, metadata["patch_size"].split(","))))[::-1]
|
||||||
|
# tile_size = np.array(list(map(int, metadata["tile_size"].split(","))))
|
||||||
|
segmentation = np.array(detection["segmentation_mask"]) # * tile_size
|
||||||
|
|
||||||
|
# print(source_path, image, metadata, segmentation.shape)
|
||||||
|
# print(offset)
|
||||||
|
# print(scale)
|
||||||
|
# print(segmentation)
|
||||||
|
|
||||||
|
# segmentation = (segmentation + offset * tile_size) / (tile_size * scale)
|
||||||
|
segmentation = (segmentation + offset) / scale
|
||||||
|
|
||||||
|
yolo_annotation = f"{detection['metadata']['class_id']} " + " ".join(
|
||||||
|
[f"{x:.6f} {y:.6f}" for x, y in segmentation]
|
||||||
|
)
|
||||||
|
annotations.append(yolo_annotation)
|
||||||
|
# print(segmentation)
|
||||||
|
# print(yolo_annotation)
|
||||||
|
|
||||||
|
# aa
|
||||||
|
print(
|
||||||
|
" ",
|
||||||
|
detection["model_name"],
|
||||||
|
detection["image_id"],
|
||||||
|
detection["image_filename"],
|
||||||
|
source_path,
|
||||||
|
metadata["label_path"],
|
||||||
|
)
|
||||||
|
# section_i_section_j = detection["image_filename"].split("_")[1].split(".")[0]
|
||||||
|
# print(" ", section_i_section_j)
|
||||||
|
|
||||||
|
label_path = metadata["label_path"]
|
||||||
|
print(" ", label_path)
|
||||||
|
with open(label_path, "w") as f:
|
||||||
|
f.write("\n".join(annotations))
|
||||||
|
|
||||||
|
exit()
|
||||||
|
|
||||||
|
for detection in detections:
|
||||||
|
print(detection["model_name"], detection["image_id"], detection["image_filename"])
|
||||||
|
|
||||||
|
print(detections[0])
|
||||||
|
# polygon_vertices = np.array([[10, 10], [50, 10], [50, 50], [10, 50]])
|
||||||
|
|
||||||
|
# image = np.zeros((100, 100), dtype=np.uint8)
|
||||||
|
|
||||||
|
# rr, cc = polygon(polygon_vertices[:, 0], polygon_vertices[:, 1])
|
||||||
|
|
||||||
|
# image[rr, cc] = 255
|
||||||
|
|
||||||
|
# import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
# plt.imshow(image, cmap='gray')
|
||||||
|
# plt.show()
|
||||||
@@ -6,6 +6,9 @@ import cv2
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
from skimage.restoration import rolling_ball
|
||||||
|
from skimage.filters import threshold_otsu
|
||||||
|
from scipy.ndimage import median_filter, gaussian_filter
|
||||||
|
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
from src.utils.file_utils import validate_file_path, is_image_file
|
from src.utils.file_utils import validate_file_path, is_image_file
|
||||||
@@ -37,6 +40,8 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
|
|||||||
a1[a1 > p999] = p999
|
a1[a1 > p999] = p999
|
||||||
a1 /= a1.max()
|
a1 /= a1.max()
|
||||||
|
|
||||||
|
# print("Using get_pseudo_rgb")
|
||||||
|
|
||||||
if 1:
|
if 1:
|
||||||
a2 = a1.copy()
|
a2 = a1.copy()
|
||||||
a2 = a2**gamma
|
a2 = a2**gamma
|
||||||
@@ -46,14 +51,75 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
|
|||||||
p9999 = np.percentile(a3, 99.99)
|
p9999 = np.percentile(a3, 99.99)
|
||||||
a3[a3 > p9999] = p9999
|
a3[a3 > p9999] = p9999
|
||||||
a3 /= a3.max()
|
a3 /= a3.max()
|
||||||
|
out = np.stack([a1, a2, a3], axis=0)
|
||||||
|
|
||||||
|
else:
|
||||||
|
out = np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
# return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
|
# return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
|
||||||
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
|
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
|
||||||
out = np.stack([a1, a2, a3], axis=0)
|
|
||||||
# print(any(np.isnan(out).flatten()))
|
# print(any(np.isnan(out).flatten()))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.3) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Convert a grayscale image to a pseudo-RGB image using a gamma correction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arr: Input grayscale image as numpy array
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Pseudo-RGB image as numpy array
|
||||||
|
"""
|
||||||
|
if arr.ndim != 2:
|
||||||
|
raise ValueError("Input array must be a grayscale image with shape (H, W)")
|
||||||
|
|
||||||
|
radius = 80
|
||||||
|
# bg = rolling_ball(arr, radius=radius)
|
||||||
|
a1 = arr.copy().astype(np.float32)
|
||||||
|
a1 = a1.astype(np.float32)
|
||||||
|
# a1 -= bg
|
||||||
|
# a1[a1 < 0] = 0
|
||||||
|
# a1 -= np.percentile(a1, 2)
|
||||||
|
# a1[a1 < 0] = 0
|
||||||
|
p999 = np.percentile(a1, 99.99)
|
||||||
|
a1[a1 > p999] = p999
|
||||||
|
a1 /= a1.max()
|
||||||
|
|
||||||
|
print("Using get_pseudo_rgb")
|
||||||
|
|
||||||
|
if 1:
|
||||||
|
a2 = a1.copy()
|
||||||
|
_a2 = a2**gamma
|
||||||
|
thr = threshold_otsu(_a2)
|
||||||
|
mask = gaussian_filter((_a2 > thr).astype(np.float32), sigma=5)
|
||||||
|
mask[mask > 0.0001] = 1
|
||||||
|
mask[mask <= 0.0001] = 0
|
||||||
|
a2 *= mask
|
||||||
|
|
||||||
|
# bg2 = rolling_ball(a2, radius=radius)
|
||||||
|
# a2 -= bg2
|
||||||
|
a2 -= np.percentile(a2, 2)
|
||||||
|
a2[a2 < 0] = 0
|
||||||
|
a2 /= a2.max()
|
||||||
|
|
||||||
|
a3 = a1.copy()
|
||||||
|
p9999 = np.percentile(a3, 99.99)
|
||||||
|
a3[a3 > p9999] = p9999
|
||||||
|
a3 /= a3.max()
|
||||||
|
out = np.stack([a1, a2, _a2], axis=0)
|
||||||
|
|
||||||
|
else:
|
||||||
|
out = np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
# return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
|
||||||
|
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
|
||||||
|
# print(any(np.isnan(out).flatten()))
|
||||||
|
|
||||||
|
|
||||||
class ImageLoadError(Exception):
|
class ImageLoadError(Exception):
|
||||||
"""Exception raised when an image cannot be loaded."""
|
"""Exception raised when an image cannot be loaded."""
|
||||||
@@ -330,17 +396,20 @@ class Image:
|
|||||||
return self._channels >= 3
|
return self._channels >= 3
|
||||||
|
|
||||||
def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None:
|
def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None:
|
||||||
|
# print("Image.save", self.data.shape)
|
||||||
if self.channels == 1:
|
if self.channels == 1:
|
||||||
|
# print("Image.save grayscale")
|
||||||
if pseudo_rgb:
|
if pseudo_rgb:
|
||||||
img = get_pseudo_rgb(self.data)
|
img = get_pseudo_rgb(self.data)
|
||||||
print("Image.save", img.shape)
|
# print("Image.save", img.shape)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
img = np.repeat(self.data, 3, axis=2)
|
img = np.repeat(self.data, 3, axis=2)
|
||||||
|
# print("Image.save no pseudo", img.shape)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Only grayscale images are supported for now.")
|
raise NotImplementedError("Only grayscale images are supported for now.")
|
||||||
|
|
||||||
|
# print("Image.save imwrite", img.shape)
|
||||||
imwrite(path, data=img)
|
imwrite(path, data=img)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ class UT:
|
|||||||
class_index: int = 0,
|
class_index: int = 0,
|
||||||
):
|
):
|
||||||
"""Export rois to a file"""
|
"""Export rois to a file"""
|
||||||
with open(path / subfolder / f"{self.stem}.txt", "w") as f:
|
with open(path / subfolder / f"{PREFIX}-{self.stem}.txt", "w") as f:
|
||||||
for i, roi in enumerate(self.rois):
|
for i, roi in enumerate(self.rois):
|
||||||
rc = roi.subpixel_coordinates
|
rc = roi.subpixel_coordinates
|
||||||
if rc is None:
|
if rc is None:
|
||||||
@@ -129,8 +129,8 @@ class UT:
|
|||||||
self.image = np.max(self.image[channel], axis=0)
|
self.image = np.max(self.image[channel], axis=0)
|
||||||
print(self.image.shape)
|
print(self.image.shape)
|
||||||
|
|
||||||
print(path / subfolder / f"{self.stem}.tif")
|
print(path / subfolder / f"{PREFIX}_{self.stem}.tif")
|
||||||
with TiffWriter(path / subfolder / f"{self.stem}.tif") as tif:
|
with TiffWriter(path / subfolder / f"{PREFIX}-{self.stem}.tif") as tif:
|
||||||
tif.write(self.image)
|
tif.write(self.image)
|
||||||
|
|
||||||
|
|
||||||
@@ -145,11 +145,19 @@ if __name__ == "__main__":
|
|||||||
action="store_false",
|
action="store_false",
|
||||||
help="Source does not have labels, export only images",
|
help="Source does not have labels, export only images",
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--prefix", help="Prefix for output files")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
PREFIX = args.prefix
|
||||||
# print(args)
|
# print(args)
|
||||||
# aa
|
# aa
|
||||||
|
|
||||||
|
# for path in args.input:
|
||||||
|
# print(path)
|
||||||
|
# ut = UT(path, args.no_labels)
|
||||||
|
# ut.export_image(args.output, plane_mode="max projection", channel=0)
|
||||||
|
# ut.export_rois(args.output, class_index=0)
|
||||||
|
|
||||||
for path in args.input:
|
for path in args.input:
|
||||||
print("Path:", path)
|
print("Path:", path)
|
||||||
if not args.no_labels:
|
if not args.no_labels:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from tifffile import imread, imwrite
|
|||||||
from shapely.geometry import LineString
|
from shapely.geometry import LineString
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from scipy.ndimage import zoom
|
from scipy.ndimage import zoom
|
||||||
|
from skimage.restoration import rolling_ball
|
||||||
|
|
||||||
# debug
|
# debug
|
||||||
from src.utils.image import Image
|
from src.utils.image import Image
|
||||||
@@ -160,8 +160,11 @@ class YoloLabelReader:
|
|||||||
|
|
||||||
|
|
||||||
class ImageSplitter:
|
class ImageSplitter:
|
||||||
def __init__(self, image_path: Path, label_path: Path):
|
def __init__(self, image_path: Path, label_path: Path, subtract_bg: bool = False):
|
||||||
self.image = imread(image_path)
|
self.image = imread(image_path)
|
||||||
|
if subtract_bg:
|
||||||
|
self.image = self.image - rolling_ball(self.image, radius=12)
|
||||||
|
self.image[self.image < 0] = 0
|
||||||
self.image_path = image_path
|
self.image_path = image_path
|
||||||
self.label_path = label_path
|
self.label_path = label_path
|
||||||
if not label_path.exists():
|
if not label_path.exists():
|
||||||
@@ -273,13 +276,14 @@ def main(args):
|
|||||||
if args.output:
|
if args.output:
|
||||||
args.output.mkdir(exist_ok=True, parents=True)
|
args.output.mkdir(exist_ok=True, parents=True)
|
||||||
(args.output / "images").mkdir(exist_ok=True)
|
(args.output / "images").mkdir(exist_ok=True)
|
||||||
(args.output / "images-zoomed").mkdir(exist_ok=True)
|
# (args.output / "images-zoomed").mkdir(exist_ok=True)
|
||||||
(args.output / "labels").mkdir(exist_ok=True)
|
(args.output / "labels").mkdir(exist_ok=True)
|
||||||
|
|
||||||
for image_path in (args.input / "images").glob("*.tif"):
|
for image_path in (args.input / "images").glob("*.tif"):
|
||||||
data = ImageSplitter(
|
data = ImageSplitter(
|
||||||
image_path=image_path,
|
image_path=image_path,
|
||||||
label_path=(args.input / "labels" / image_path.stem).with_suffix(".txt"),
|
label_path=(args.input / "labels" / image_path.stem).with_suffix(".txt"),
|
||||||
|
subtract_bg=args.subtract_bg,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.split_around_label:
|
if args.split_around_label:
|
||||||
@@ -332,10 +336,20 @@ def main(args):
|
|||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f:
|
with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f:
|
||||||
|
print(
|
||||||
|
f"Writing {len(labels)} labels to {args.output / 'labels' / f'{image_path.stem}_{tile_reference}.txt'}"
|
||||||
|
)
|
||||||
for label in labels:
|
for label in labels:
|
||||||
# label.offset_label(tile.shape[1], tile.shape[0])
|
# label.offset_label(tile.shape[1], tile.shape[0])
|
||||||
f.write(label.to_string() + "\n")
|
f.write(label.to_string() + "\n")
|
||||||
|
|
||||||
|
# { debug
|
||||||
|
if debug:
|
||||||
|
print(label.to_string())
|
||||||
|
# } debug
|
||||||
|
# break
|
||||||
|
# break
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import argparse
|
import argparse
|
||||||
@@ -363,6 +377,7 @@ if __name__ == "__main__":
|
|||||||
default=67,
|
default=67,
|
||||||
help="Padding around the label when splitting around the label.",
|
help="Padding around the label when splitting around the label.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument("-bg", "--subtract-bg", action="store_true", help="Subtract background from the image.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
@@ -189,33 +189,38 @@ def main():
|
|||||||
# continue and just show image
|
# continue and just show image
|
||||||
out = draw_annotations(img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox))
|
out = draw_annotations(img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox))
|
||||||
|
|
||||||
lclass, coords = labels[0]
|
|
||||||
print(lclass, coords)
|
|
||||||
bbox = coords[:4]
|
|
||||||
print("bbox", bbox)
|
|
||||||
bbox = np.array(bbox) * np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
|
|
||||||
yc, xc, h, w = bbox
|
|
||||||
print("bbox", bbox)
|
|
||||||
|
|
||||||
# polyline = np.array(coords[4:]).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
|
|
||||||
polyline = np.array(coords).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
|
|
||||||
print("pl", coords[4:])
|
|
||||||
print("pl", polyline)
|
|
||||||
|
|
||||||
# Convert BGR -> RGB for matplotlib display
|
|
||||||
# out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
|
|
||||||
out_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
out_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
# out_rgb = Image()
|
|
||||||
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
|
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
|
||||||
plt.imshow(out_rgb)
|
|
||||||
plt.plot(polyline[:, 0], polyline[:, 1], "y", linewidth=2)
|
|
||||||
if 0:
|
if 0:
|
||||||
plt.plot(
|
plt.imshow(out_rgb.transpose(1, 0, 2))
|
||||||
[yc - h / 2, yc - h / 2, yc + h / 2, yc + h / 2, yc - h / 2],
|
else:
|
||||||
[xc - w / 2, xc + w / 2, xc + w / 2, xc - w / 2, xc - w / 2],
|
plt.imshow(out_rgb)
|
||||||
"r",
|
|
||||||
linewidth=2,
|
for label in labels:
|
||||||
)
|
lclass, coords = label
|
||||||
|
# print(lclass, coords)
|
||||||
|
bbox = coords[:4]
|
||||||
|
# print("bbox", bbox)
|
||||||
|
bbox = np.array(bbox) * np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
|
||||||
|
yc, xc, h, w = bbox
|
||||||
|
# print("bbox", bbox)
|
||||||
|
|
||||||
|
# polyline = np.array(coords[4:]).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
|
||||||
|
polyline = np.array(coords).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
|
||||||
|
# print("pl", coords[4:])
|
||||||
|
# print("pl", polyline)
|
||||||
|
|
||||||
|
# Convert BGR -> RGB for matplotlib display
|
||||||
|
# out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
|
||||||
|
# out_rgb = Image()
|
||||||
|
plt.plot(polyline[:, 0], polyline[:, 1], "y", linewidth=2)
|
||||||
|
if 0:
|
||||||
|
plt.plot(
|
||||||
|
[yc - h / 2, yc - h / 2, yc + h / 2, yc + h / 2, yc - h / 2],
|
||||||
|
[xc - w / 2, xc + w / 2, xc + w / 2, xc - w / 2, xc - w / 2],
|
||||||
|
"r",
|
||||||
|
linewidth=2,
|
||||||
|
)
|
||||||
|
|
||||||
# plt.axis("off")
|
# plt.axis("off")
|
||||||
plt.title(f"{img_path.name} ({lbl_path.name})")
|
plt.title(f"{img_path.name} ({lbl_path.name})")
|
||||||
|
|||||||
Reference in New Issue
Block a user