Compare commits
31 Commits
bb26d43dd7
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| e5036c10cf | |||
| c7e388d9ae | |||
| 6b995e7325 | |||
| 0e0741d323 | |||
| dd99a0677c | |||
| 9c4c39fb39 | |||
| 20a87c9040 | |||
| 9f7d2be1ac | |||
| dbde07c0e8 | |||
| b3c5a51dbb | |||
| 9a221acb63 | |||
| 32a6a122bd | |||
| 9ba44043ef | |||
| 8eb1cc8c86 | |||
| e4ce882a18 | |||
| 6b6d6fad03 | |||
| c0684a9c14 | |||
| 221c80aa8c | |||
| 833b222fad | |||
| 5370d31dce | |||
| 5d196c3a4a | |||
| f719c7ec40 | |||
| e6a5e74fa1 | |||
| 35e2398e95 | |||
| c3d44ac945 | |||
| dad5c2bf74 | |||
| 73cb698488 | |||
| 12f2bf94d5 | |||
| 710b684456 | |||
| fc22479621 | |||
| f84dea0bff |
@@ -12,12 +12,28 @@ image_repository:
|
||||
models:
|
||||
default_base_model: yolov8s-seg.pt
|
||||
models_directory: data/models
|
||||
base_model_choices:
|
||||
- yolov8s-seg.pt
|
||||
- yolo11s-seg.pt
|
||||
training:
|
||||
default_epochs: 100
|
||||
default_batch_size: 16
|
||||
default_imgsz: 640
|
||||
default_imgsz: 1024
|
||||
default_patience: 50
|
||||
default_lr0: 0.01
|
||||
two_stage:
|
||||
enabled: false
|
||||
stage1:
|
||||
epochs: 20
|
||||
lr0: 0.0005
|
||||
patience: 10
|
||||
freeze: 10
|
||||
stage2:
|
||||
epochs: 150
|
||||
lr0: 0.0003
|
||||
patience: 30
|
||||
last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml
|
||||
last_dataset_dir: /home/martin/code/object_detection/data/datasets
|
||||
detection:
|
||||
default_confidence: 0.25
|
||||
default_iou: 0.45
|
||||
|
||||
@@ -10,6 +10,14 @@ from typing import List, Dict, Optional, Tuple, Any, Union
|
||||
from pathlib import Path
|
||||
import csv
|
||||
import hashlib
|
||||
import yaml
|
||||
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.image import Image
|
||||
|
||||
IMAGE_EXTENSIONS = tuple(Image.SUPPORTED_EXTENSIONS)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
@@ -30,18 +38,48 @@ class DatabaseManager:
|
||||
# Create directory if it doesn't exist
|
||||
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Read schema file and execute
|
||||
schema_path = Path(__file__).parent / "schema.sql"
|
||||
with open(schema_path, "r") as f:
|
||||
schema_sql = f.read()
|
||||
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
# Check if annotations table needs migration
|
||||
self._migrate_annotations_table(conn)
|
||||
|
||||
# Read schema file and execute
|
||||
schema_path = Path(__file__).parent / "schema.sql"
|
||||
with open(schema_path, "r") as f:
|
||||
schema_sql = f.read()
|
||||
|
||||
conn.executescript(schema_sql)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _migrate_annotations_table(self, conn: sqlite3.Connection) -> None:
    """
    Migrate annotations table from old schema (class_name) to new schema (class_id).

    The migration is destructive: when the legacy layout is detected the
    table is dropped (and the drop committed) so that it can be recreated
    with the new columns. Existing annotation rows are NOT carried over.

    Args:
        conn: Open SQLite connection on which the check and the drop run.
    """
    cursor = conn.cursor()

    # Check if annotations table exists
    cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'"
    )
    if cursor.fetchone() is None:
        # Table doesn't exist yet, no migration needed
        return

    # PRAGMA table_info yields one row per column; index 1 is the column name.
    cursor.execute("PRAGMA table_info(annotations)")
    column_names = {row[1] for row in cursor.fetchall()}

    # Only the legacy layout (class_name present, class_id absent) is migrated.
    if "class_name" not in column_names or "class_id" in column_names:
        return

    # Use the module logger rather than print() so the data-dropping step is
    # visible in the application log.
    logger.warning("Migrating annotations table to new schema with class_id...")

    # Drop old annotations table (assuming no critical data since this is a new feature)
    cursor.execute("DROP TABLE IF EXISTS annotations")
    conn.commit()
    logger.info("Old annotations table dropped, will be recreated with new schema")
|
||||
|
||||
def get_connection(self) -> sqlite3.Connection:
|
||||
"""Get database connection with proper settings."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
@@ -413,6 +451,25 @@ class DatabaseManager:
|
||||
filters["model_id"] = model_id
|
||||
return self.get_detections(filters)
|
||||
|
||||
def delete_detections_for_image(
    self, image_id: int, model_id: Optional[int] = None
) -> int:
    """
    Remove the detections of one image, optionally limited to a single model.

    Args:
        image_id: ID of the image whose detections are removed.
        model_id: When given, only detections of this model are removed.

    Returns:
        Number of rows deleted.
    """
    # Pick the statement up front so the database section stays linear.
    if model_id is None:
        statement = "DELETE FROM detections WHERE image_id = ?"
        arguments = (image_id,)
    else:
        statement = "DELETE FROM detections WHERE image_id = ? AND model_id = ?"
        arguments = (image_id, model_id)

    connection = self.get_connection()
    try:
        cur = connection.cursor()
        cur.execute(statement, arguments)
        connection.commit()
        return cur.rowcount
    finally:
        connection.close()
|
||||
|
||||
def delete_detections_for_model(self, model_id: int) -> int:
|
||||
"""Delete all detections for a specific model."""
|
||||
conn = self.get_connection()
|
||||
@@ -593,25 +650,38 @@ class DatabaseManager:
|
||||
def add_annotation(
|
||||
self,
|
||||
image_id: int,
|
||||
class_name: str,
|
||||
class_id: int,
|
||||
bbox: Tuple[float, float, float, float],
|
||||
annotator: str,
|
||||
segmentation_mask: Optional[List[List[float]]] = None,
|
||||
verified: bool = False,
|
||||
) -> int:
|
||||
"""Add manual annotation."""
|
||||
"""
|
||||
Add manual annotation.
|
||||
|
||||
Args:
|
||||
image_id: ID of the image
|
||||
class_id: ID of the object class (foreign key to object_classes)
|
||||
bbox: Bounding box coordinates (normalized 0-1)
|
||||
annotator: Name of person/tool creating annotation
|
||||
segmentation_mask: Polygon coordinates for segmentation
|
||||
verified: Whether annotation has been verified
|
||||
|
||||
Returns:
|
||||
ID of the inserted annotation
|
||||
"""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
x_min, y_min, x_max, y_max = bbox
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO annotations (image_id, class_name, x_min, y_min, x_max, y_max, segmentation_mask, annotator, verified)
|
||||
INSERT INTO annotations (image_id, class_id, x_min, y_min, x_max, y_max, segmentation_mask, annotator, verified)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
image_id,
|
||||
class_name,
|
||||
class_id,
|
||||
x_min,
|
||||
y_min,
|
||||
x_max,
|
||||
@@ -627,15 +697,378 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
|
||||
def get_annotations_for_image(self, image_id: int) -> List[Dict]:
    """
    Get all annotations for an image with class information.

    Args:
        image_id: ID of the image

    Returns:
        List of annotation dictionaries, newest first, each extended with
        the joined class_name, class_color and class_description. A stored
        segmentation_mask is decoded from its JSON text.
    """
    conn = self.get_connection()
    try:
        cursor = conn.cursor()
        # Single joined query; the earlier plain "SELECT *" is not needed
        # because its result was never read.
        cursor.execute(
            """
            SELECT
                a.*,
                c.class_name,
                c.color as class_color,
                c.description as class_description
            FROM annotations a
            JOIN object_classes c ON a.class_id = c.id
            WHERE a.image_id = ?
            ORDER BY a.created_at DESC
            """,
            (image_id,),
        )
        annotations = []
        for row in cursor.fetchall():
            # assumes the connection's row factory yields mapping-like rows
            # (dict(row) needs column names) - TODO confirm in get_connection
            ann = dict(row)
            if ann.get("segmentation_mask"):
                ann["segmentation_mask"] = json.loads(ann["segmentation_mask"])
            annotations.append(ann)
        return annotations
    finally:
        conn.close()
|
||||
|
||||
def delete_annotation(self, annotation_id: int) -> bool:
    """
    Remove one manual annotation, addressed by its ID.

    Args:
        annotation_id: ID of the annotation to delete

    Returns:
        True when a row was removed, False when no annotation had that ID.
    """
    connection = self.get_connection()
    try:
        cur = connection.cursor()
        cur.execute("DELETE FROM annotations WHERE id = ?", (annotation_id,))
        connection.commit()
        removed_rows = cur.rowcount
    finally:
        connection.close()
    return removed_rows > 0
|
||||
|
||||
# ==================== Object Class Operations ====================
|
||||
|
||||
def get_object_classes(self) -> List[Dict]:
    """
    Fetch every row of the object_classes table, sorted by class name.

    Returns:
        List with one dictionary per object class.
    """
    connection = self.get_connection()
    try:
        cur = connection.cursor()
        cur.execute("SELECT * FROM object_classes ORDER BY class_name")
        records = cur.fetchall()
        return list(map(dict, records))
    finally:
        connection.close()
|
||||
|
||||
def get_object_class_by_id(self, class_id: int) -> Optional[Dict]:
    """
    Look up a single object class by its ID.

    Args:
        class_id: ID of the object class

    Returns:
        The class as a dictionary, or None when no row has that ID.
    """
    connection = self.get_connection()
    try:
        cur = connection.cursor()
        cur.execute("SELECT * FROM object_classes WHERE id = ?", (class_id,))
        record = cur.fetchone()
        if not record:
            return None
        return dict(record)
    finally:
        connection.close()
|
||||
|
||||
def get_object_class_by_name(self, class_name: str) -> Optional[Dict]:
    """
    Look up a single object class by its name.

    Args:
        class_name: Name of the object class

    Returns:
        The class as a dictionary, or None when no row has that name.
    """
    lookup_sql = "SELECT * FROM object_classes WHERE class_name = ?"
    connection = self.get_connection()
    try:
        cur = connection.cursor()
        cur.execute(lookup_sql, (class_name,))
        record = cur.fetchone()
        if not record:
            return None
        return dict(record)
    finally:
        connection.close()
|
||||
|
||||
def add_object_class(
    self, class_name: str, color: str, description: Optional[str] = None
) -> int:
    """
    Add a new object class, or return the existing one with the same name.

    Args:
        class_name: Name of the object class
        color: Hex color code (e.g., '#FF0000')
        description: Optional description

    Returns:
        ID of the inserted object class, or the ID of the already existing
        class when class_name is taken.

    Raises:
        sqlite3.IntegrityError: When the insert violates a constraint other
            than a duplicate class name (no row with that name exists).
    """
    conn = self.get_connection()
    try:
        cursor = conn.cursor()
        cursor.execute(
            """
            INSERT INTO object_classes (class_name, color, description)
            VALUES (?, ?, ?)
            """,
            (class_name, color, description),
        )
        conn.commit()
        return cursor.lastrowid
    except sqlite3.IntegrityError:
        # Discard the failed insert before reading through a second connection.
        conn.rollback()
        existing = self.get_object_class_by_name(class_name)
        if existing is None:
            # Not a duplicate-name conflict: re-raise instead of returning
            # None from a function declared to return an int.
            raise
        return existing["id"]
    finally:
        conn.close()
|
||||
|
||||
def update_object_class(
    self,
    class_id: int,
    class_name: Optional[str] = None,
    color: Optional[str] = None,
    description: Optional[str] = None,
) -> bool:
    """
    Update an object class.

    Only the fields passed as non-None are written; None means "leave
    unchanged", so a field cannot be set to NULL through this method.

    Args:
        class_id: ID of the class to update
        class_name: New class name (optional)
        color: New color (optional)
        description: New description (optional)

    Returns:
        True if a row was updated, False when nothing was requested or no
        row has that ID.
    """
    requested = {
        "class_name": class_name,
        "color": color,
        "description": description,
    }
    updates = {column: value for column, value in requested.items() if value is not None}

    # Nothing to write: answer before a database connection is opened.
    if not updates:
        return False

    # Column names come from the fixed mapping above, never from caller
    # input, so interpolating them into the statement is safe; values are bound.
    set_clauses = ", ".join(f"{column} = ?" for column in updates)
    query = f"UPDATE object_classes SET {set_clauses} WHERE id = ?"
    params = [*updates.values(), class_id]

    conn = self.get_connection()
    try:
        cursor = conn.cursor()
        cursor.execute(query, params)
        conn.commit()
        return cursor.rowcount > 0
    finally:
        conn.close()
|
||||
|
||||
def delete_object_class(self, class_id: int) -> bool:
    """
    Remove one object class, addressed by its ID.

    Args:
        class_id: ID of the class to delete

    Returns:
        True when a row was removed, False when no class had that ID.
    """
    # NOTE(review): whether annotations referencing this class are removed
    # too depends on foreign-key enforcement on the connection - confirm
    # in get_connection.
    connection = self.get_connection()
    try:
        cur = connection.cursor()
        cur.execute("DELETE FROM object_classes WHERE id = ?", (class_id,))
        connection.commit()
        removed_rows = cur.rowcount
    finally:
        connection.close()
    return removed_rows > 0
|
||||
|
||||
# ==================== Dataset Utilities ====================
|
||||
|
||||
def compose_data_yaml(
    self,
    dataset_root: str,
    output_path: Optional[str] = None,
    splits: Optional[Dict[str, str]] = None,
) -> str:
    """
    Write a YOLO data.yaml describing the dataset folders and the known classes.

    Split directories come from the 'splits' overrides where given and are
    otherwise inferred from the folder layout. Class names are those
    referenced by annotations, falling back to all object classes.

    Args:
        dataset_root: Base directory containing the dataset structure.
        output_path: Optional output path; defaults to <dataset_root>/data.yaml.
        splits: Optional mapping overriding train/val/test image directories
            (relative to dataset_root or absolute paths).

    Returns:
        Path to the generated YAML file.

    Raises:
        ValueError: If the dataset root is missing, the train or val
            directory cannot be determined, a split directory is invalid,
            or no object classes are available.
    """
    root = Path(dataset_root).expanduser()
    if not root.exists():
        raise ValueError(f"Dataset root does not exist: {root}")
    root = root.resolve()

    # Explicit overrides win; anything left empty is filled from the layout.
    overrides = splits if splits else {}
    guessed = self._infer_split_dirs(root)
    chosen: Dict[str, str] = {}
    for split_name in ("train", "val", "test"):
        chosen[split_name] = overrides.get(split_name) or guessed.get(split_name, "")

    # train and val are mandatory, test is optional.
    for required in ("train", "val"):
        if chosen[required]:
            continue
        raise ValueError(
            "Unable to determine %s image directory under %s. Provide it "
            "explicitly via the 'splits' argument."
            % (required, root)
        )

    yaml_splits: Dict[str, str] = {
        split_name: self._normalize_split_value(location, root)
        for split_name, location in chosen.items()
        if location
    }

    class_names = self._fetch_annotation_class_names()
    if not class_names:
        class_names = [entry["class_name"] for entry in self.get_object_classes()]
    if not class_names:
        raise ValueError("No object classes available to populate data.yaml")

    payload: Dict[str, Any] = {
        "path": root.as_posix(),
        "train": yaml_splits["train"],
        "val": yaml_splits["val"],
        "names": dict(enumerate(class_names)),
        "nc": len(class_names),
    }
    if yaml_splits.get("test"):
        payload["test"] = yaml_splits["test"]

    if output_path:
        target = Path(output_path).expanduser()
    else:
        target = root / "data.yaml"
    target.parent.mkdir(parents=True, exist_ok=True)

    with open(target, "w", encoding="utf-8") as handle:
        yaml.safe_dump(payload, handle, sort_keys=False)

    logger.info(f"Generated data.yaml at {target}")
    return target.as_posix()
|
||||
|
||||
def _fetch_annotation_class_names(self) -> List[str]:
    """Collect the names of classes used by at least one annotation, ordered by class ID."""
    connection = self.get_connection()
    try:
        cur = connection.cursor()
        cur.execute(
            """
            SELECT DISTINCT c.id, c.class_name
            FROM annotations a
            JOIN object_classes c ON a.class_id = c.id
            ORDER BY c.id
            """
        )
        names: List[str] = []
        for record in cur.fetchall():
            names.append(record["class_name"])
        return names
    finally:
        connection.close()
|
||||
|
||||
def _infer_split_dirs(self, dataset_root: Path) -> Dict[str, str]:
    """
    Guess the train/val/test image directories below dataset_root.

    For each split the candidates are tried in a fixed order
    (<name>/images, images/<name>, then <name>, each with the short and the
    long spelling) and the first existing directory that contains images wins.

    Args:
        dataset_root: Directory the candidates are resolved against.

    Returns:
        Mapping of split name to directory, relative to dataset_root when
        possible and absolute otherwise; "" for a split that was not found.
    """
    spellings = {
        "train": ("train", "training"),
        "val": ("val", "validation"),
        "test": ("test", "testing"),
    }
    layouts = ("{}/images", "images/{}", "{}")

    found: Dict[str, str] = {}
    for split_name, names in spellings.items():
        found[split_name] = ""
        # Layout-major order: both spellings of one layout before the next layout.
        options = [layout.format(name) for layout in layouts for name in names]
        for option in options:
            location = (dataset_root / option).resolve()
            if not location.exists() or not location.is_dir():
                continue
            if not self._directory_has_images(location):
                continue
            try:
                found[split_name] = location.relative_to(dataset_root).as_posix()
            except ValueError:
                # Resolved outside dataset_root (e.g. via a symlink): keep it absolute.
                found[split_name] = location.as_posix()
            break
    return found
|
||||
|
||||
def _normalize_split_value(self, split_value: str, dataset_root: Path) -> str:
    """
    Check a split directory and turn it into the string written to the YAML.

    Args:
        split_value: Directory, absolute or relative to dataset_root.
        dataset_root: Base directory used for relative values.

    Returns:
        POSIX path relative to dataset_root when the directory lies below
        it, otherwise the absolute POSIX path.

    Raises:
        ValueError: If the directory does not exist or holds no images.
    """
    given = Path(split_value).expanduser()
    anchored = given if given.is_absolute() else dataset_root / given
    location = anchored.resolve()

    if not (location.exists() and location.is_dir()):
        raise ValueError(f"Split directory not found: {location}")
    if not self._directory_has_images(location):
        raise ValueError(f"No images found under {location}")

    try:
        relative = location.relative_to(dataset_root)
    except ValueError:
        return location.as_posix()
    return relative.as_posix()
|
||||
|
||||
@staticmethod
def _directory_has_images(directory: Path, max_checks: int = 2000) -> bool:
    """
    Return True if an image file is found while scanning the directory tree.

    The scan is bounded: it gives up and returns False once max_checks
    non-image files have been inspected, so a tree whose images only show
    up after that many other files is reported as having none.

    Args:
        directory: Root of the tree that is scanned recursively.
        max_checks: Maximum number of non-image files to inspect.

    Returns:
        True as soon as a file whose lower-cased suffix is in
        IMAGE_EXTENSIONS is seen; False otherwise, including when the tree
        cannot be read.
    """
    checked = 0
    try:
        for file_path in directory.rglob("*"):
            if not file_path.is_file():
                continue
            if file_path.suffix.lower() in IMAGE_EXTENSIONS:
                return True
            checked += 1
            if checked >= max_checks:
                break
    except OSError as exc:
        # Best effort: an unreadable tree counts as "no images", but only
        # filesystem errors are absorbed so programming errors still surface.
        logger.debug("Could not scan %s for images: %s", directory, exc)
        return False
    return False
|
||||
|
||||
@staticmethod
|
||||
def calculate_checksum(file_path: str) -> str:
|
||||
"""Calculate MD5 checksum of a file."""
|
||||
|
||||
@@ -44,11 +44,27 @@ CREATE TABLE IF NOT EXISTS detections (
|
||||
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Annotations table: stores manual annotations (future feature)
|
||||
-- Object classes table: stores annotation class definitions with colors
|
||||
CREATE TABLE IF NOT EXISTS object_classes (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
class_name TEXT NOT NULL UNIQUE,
|
||||
color TEXT NOT NULL, -- Hex color code (e.g., '#FF0000')
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
description TEXT
|
||||
);
|
||||
|
||||
-- Insert default object classes
|
||||
INSERT OR IGNORE INTO object_classes (class_name, color, description) VALUES
|
||||
('cell', '#FF0000', 'Cell object'),
|
||||
('nucleus', '#00FF00', 'Cell nucleus'),
|
||||
('mitochondria', '#0000FF', 'Mitochondria'),
|
||||
('vesicle', '#FFFF00', 'Vesicle');
|
||||
|
||||
-- Annotations table: stores manual annotations
|
||||
CREATE TABLE IF NOT EXISTS annotations (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
image_id INTEGER NOT NULL,
|
||||
class_name TEXT NOT NULL,
|
||||
class_id INTEGER NOT NULL,
|
||||
x_min REAL NOT NULL CHECK(x_min >= 0 AND x_min <= 1),
|
||||
y_min REAL NOT NULL CHECK(y_min >= 0 AND y_min <= 1),
|
||||
x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1),
|
||||
@@ -57,7 +73,8 @@ CREATE TABLE IF NOT EXISTS annotations (
|
||||
annotator TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
verified BOOLEAN DEFAULT 0,
|
||||
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE
|
||||
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (class_id) REFERENCES object_classes (id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Create indexes for performance optimization
|
||||
@@ -69,4 +86,6 @@ CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence);
|
||||
CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path);
|
||||
CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_annotations_class_id ON annotations(class_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_object_classes_class_name ON object_classes(class_name);
|
||||
@@ -13,7 +13,7 @@ from PySide6.QtWidgets import (
|
||||
QVBoxLayout,
|
||||
QLabel,
|
||||
)
|
||||
from PySide6.QtCore import Qt, QTimer
|
||||
from PySide6.QtCore import Qt, QTimer, QSettings
|
||||
from PySide6.QtGui import QAction, QKeySequence
|
||||
|
||||
from src.database.db_manager import DatabaseManager
|
||||
@@ -52,8 +52,8 @@ class MainWindow(QMainWindow):
|
||||
self._create_tab_widget()
|
||||
self._create_status_bar()
|
||||
|
||||
# Center window on screen
|
||||
self._center_window()
|
||||
# Restore window geometry or center window on screen
|
||||
self._restore_window_state()
|
||||
|
||||
logger.info("Main window initialized")
|
||||
|
||||
@@ -156,6 +156,24 @@ class MainWindow(QMainWindow):
|
||||
(screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2
|
||||
)
|
||||
|
||||
def _restore_window_state(self):
    """
    Restore window geometry from settings or center window.

    The window is centered when no geometry has been stored, and also when
    the stored value is rejected by restoreGeometry().
    """
    settings = QSettings("microscopy_app", "object_detection")
    geometry = settings.value("main_window/geometry")

    # restoreGeometry() reports failure through its boolean result; fall
    # through to centering instead of leaving the window unplaced.
    if geometry and self.restoreGeometry(geometry):
        logger.debug("Restored window geometry from settings")
        return

    self._center_window()
    logger.debug("Centered window on screen")
|
||||
|
||||
def _save_window_state(self):
    """Store the current window geometry under the "main_window/geometry" settings key."""
    geometry_blob = self.saveGeometry()
    store = QSettings("microscopy_app", "object_detection")
    store.setValue("main_window/geometry", geometry_blob)
    logger.debug("Saved window geometry to settings")
|
||||
|
||||
def _show_settings(self):
|
||||
"""Show settings dialog."""
|
||||
logger.info("Opening settings dialog")
|
||||
@@ -276,6 +294,15 @@ class MainWindow(QMainWindow):
|
||||
)
|
||||
|
||||
if reply == QMessageBox.Yes:
|
||||
# Save window state before closing
|
||||
self._save_window_state()
|
||||
|
||||
# Persist tab state and stop background work before exit
|
||||
if hasattr(self, "training_tab"):
|
||||
self.training_tab.shutdown()
|
||||
if hasattr(self, "annotation_tab"):
|
||||
self.annotation_tab.save_state()
|
||||
|
||||
logger.info("Application closing")
|
||||
event.accept()
|
||||
else:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Annotation tab for the microscopy object detection application.
|
||||
Future feature for manual annotation.
|
||||
Manual annotation with pen tool and object class management.
|
||||
"""
|
||||
|
||||
from PySide6.QtWidgets import (
|
||||
@@ -12,6 +12,7 @@ from PySide6.QtWidgets import (
|
||||
QPushButton,
|
||||
QFileDialog,
|
||||
QMessageBox,
|
||||
QSplitter,
|
||||
)
|
||||
from PySide6.QtCore import Qt, QSettings
|
||||
from pathlib import Path
|
||||
@@ -20,13 +21,13 @@ from src.database.db_manager import DatabaseManager
|
||||
from src.utils.config_manager import ConfigManager
|
||||
from src.utils.image import Image, ImageLoadError
|
||||
from src.utils.logger import get_logger
|
||||
from src.gui.widgets import ImageDisplayWidget
|
||||
from src.gui.widgets import AnnotationCanvasWidget, AnnotationToolsWidget
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnnotationTab(QWidget):
|
||||
"""Annotation tab placeholder (future feature)."""
|
||||
"""Annotation tab for manual image annotation."""
|
||||
|
||||
def __init__(
|
||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
||||
@@ -36,6 +37,10 @@ class AnnotationTab(QWidget):
|
||||
self.config_manager = config_manager
|
||||
self.current_image = None
|
||||
self.current_image_path = None
|
||||
self.current_image_id = None
|
||||
self.current_annotations = []
|
||||
# IDs of annotations currently selected on the canvas (multi-select)
|
||||
self.selected_annotation_ids = []
|
||||
|
||||
self._setup_ui()
|
||||
|
||||
@@ -43,6 +48,75 @@ class AnnotationTab(QWidget):
|
||||
"""Setup user interface."""
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# Main horizontal splitter to divide left (image) and right (controls)
|
||||
self.main_splitter = QSplitter(Qt.Horizontal)
|
||||
self.main_splitter.setHandleWidth(10)
|
||||
|
||||
# { Left splitter for image display and zoom info
|
||||
self.left_splitter = QSplitter(Qt.Vertical)
|
||||
self.left_splitter.setHandleWidth(10)
|
||||
|
||||
# Annotation canvas section
|
||||
canvas_group = QGroupBox("Annotation Canvas")
|
||||
canvas_layout = QVBoxLayout()
|
||||
|
||||
# Use the AnnotationCanvasWidget
|
||||
self.annotation_canvas = AnnotationCanvasWidget()
|
||||
self.annotation_canvas.zoom_changed.connect(self._on_zoom_changed)
|
||||
self.annotation_canvas.annotation_drawn.connect(self._on_annotation_drawn)
|
||||
# Selection of existing polylines (when tool is not in drawing mode)
|
||||
self.annotation_canvas.annotation_selected.connect(self._on_annotation_selected)
|
||||
canvas_layout.addWidget(self.annotation_canvas)
|
||||
|
||||
canvas_group.setLayout(canvas_layout)
|
||||
self.left_splitter.addWidget(canvas_group)
|
||||
|
||||
# Controls info
|
||||
controls_info = QLabel(
|
||||
"Zoom: Mouse wheel or +/- keys | Drawing: Enable pen and drag mouse"
|
||||
)
|
||||
controls_info.setStyleSheet("QLabel { color: #888; font-style: italic; }")
|
||||
self.left_splitter.addWidget(controls_info)
|
||||
# }
|
||||
|
||||
# { Right splitter for annotation tools and controls
|
||||
self.right_splitter = QSplitter(Qt.Vertical)
|
||||
self.right_splitter.setHandleWidth(10)
|
||||
|
||||
# Annotation tools section
|
||||
self.annotation_tools = AnnotationToolsWidget(self.db_manager)
|
||||
self.annotation_tools.polyline_enabled_changed.connect(
|
||||
self.annotation_canvas.set_polyline_enabled
|
||||
)
|
||||
self.annotation_tools.polyline_pen_color_changed.connect(
|
||||
self.annotation_canvas.set_polyline_pen_color
|
||||
)
|
||||
self.annotation_tools.polyline_pen_width_changed.connect(
|
||||
self.annotation_canvas.set_polyline_pen_width
|
||||
)
|
||||
# Show / hide bounding boxes
|
||||
self.annotation_tools.show_bboxes_changed.connect(
|
||||
self.annotation_canvas.set_show_bboxes
|
||||
)
|
||||
# RDP simplification controls
|
||||
self.annotation_tools.simplify_on_finish_changed.connect(
|
||||
self._on_simplify_on_finish_changed
|
||||
)
|
||||
self.annotation_tools.simplify_epsilon_changed.connect(
|
||||
self._on_simplify_epsilon_changed
|
||||
)
|
||||
# Class selection and class-color changes
|
||||
self.annotation_tools.class_selected.connect(self._on_class_selected)
|
||||
self.annotation_tools.class_color_changed.connect(self._on_class_color_changed)
|
||||
self.annotation_tools.clear_annotations_requested.connect(
|
||||
self._on_clear_annotations
|
||||
)
|
||||
# Delete selected annotation on canvas
|
||||
self.annotation_tools.delete_selected_annotation_requested.connect(
|
||||
self._on_delete_selected_annotation
|
||||
)
|
||||
self.right_splitter.addWidget(self.annotation_tools)
|
||||
|
||||
# Image loading section
|
||||
load_group = QGroupBox("Image Loading")
|
||||
load_layout = QVBoxLayout()
|
||||
@@ -53,7 +127,6 @@ class AnnotationTab(QWidget):
|
||||
self.load_image_btn.clicked.connect(self._load_image)
|
||||
button_layout.addWidget(self.load_image_btn)
|
||||
button_layout.addStretch()
|
||||
|
||||
load_layout.addLayout(button_layout)
|
||||
|
||||
# Image info label
|
||||
@@ -61,43 +134,22 @@ class AnnotationTab(QWidget):
|
||||
load_layout.addWidget(self.image_info_label)
|
||||
|
||||
load_group.setLayout(load_layout)
|
||||
layout.addWidget(load_group)
|
||||
self.right_splitter.addWidget(load_group)
|
||||
# }
|
||||
|
||||
# Image display section
|
||||
display_group = QGroupBox("Image Display")
|
||||
display_layout = QVBoxLayout()
|
||||
# Add both splitters to the main horizontal splitter
|
||||
self.main_splitter.addWidget(self.left_splitter)
|
||||
self.main_splitter.addWidget(self.right_splitter)
|
||||
|
||||
# Use the reusable ImageDisplayWidget
|
||||
self.image_display_widget = ImageDisplayWidget()
|
||||
self.image_display_widget.zoom_changed.connect(self._on_zoom_changed)
|
||||
display_layout.addWidget(self.image_display_widget)
|
||||
|
||||
display_group.setLayout(display_layout)
|
||||
layout.addWidget(display_group)
|
||||
|
||||
# Future features info
|
||||
info_group = QGroupBox("Annotation Tool (Future Feature)")
|
||||
info_layout = QVBoxLayout()
|
||||
info_label = QLabel(
|
||||
"Full annotation functionality will be implemented in future version.\n\n"
|
||||
"Planned Features:\n"
|
||||
"- Drawing tools for bounding boxes\n"
|
||||
"- Class label assignment\n"
|
||||
"- Export annotations to YOLO format\n"
|
||||
"- Annotation verification"
|
||||
)
|
||||
info_layout.addWidget(info_label)
|
||||
info_group.setLayout(info_layout)
|
||||
|
||||
layout.addWidget(info_group)
|
||||
|
||||
# Zoom controls info
|
||||
zoom_info = QLabel("Zoom: Mouse wheel or +/- keys to zoom in/out")
|
||||
zoom_info.setStyleSheet("QLabel { color: #888; font-style: italic; }")
|
||||
layout.addWidget(zoom_info)
|
||||
# Set initial sizes: 75% for left (image), 25% for right (controls)
|
||||
self.main_splitter.setSizes([750, 250])
|
||||
|
||||
layout.addWidget(self.main_splitter)
|
||||
self.setLayout(layout)
|
||||
|
||||
# Restore splitter positions from settings
|
||||
self._restore_state()
|
||||
|
||||
def _load_image(self):
|
||||
"""Load and display an image file."""
|
||||
# Get last opened directory from QSettings
|
||||
@@ -116,7 +168,7 @@ class AnnotationTab(QWidget):
|
||||
self,
|
||||
"Select Image",
|
||||
start_dir,
|
||||
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
|
||||
"Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
|
||||
)
|
||||
|
||||
if not file_path:
|
||||
@@ -132,13 +184,25 @@ class AnnotationTab(QWidget):
|
||||
"annotation_tab/last_directory", str(Path(file_path).parent)
|
||||
)
|
||||
|
||||
# Display image using the ImageDisplayWidget
|
||||
self.image_display_widget.load_image(self.current_image)
|
||||
# Get or create image in database
|
||||
relative_path = str(Path(file_path).name) # Simplified for now
|
||||
self.current_image_id = self.db_manager.get_or_create_image(
|
||||
relative_path,
|
||||
Path(file_path).name,
|
||||
self.current_image.width,
|
||||
self.current_image.height,
|
||||
)
|
||||
|
||||
# Display image using the AnnotationCanvasWidget
|
||||
self.annotation_canvas.load_image(self.current_image)
|
||||
|
||||
# Load and display any existing annotations for this image
|
||||
self._load_annotations_for_current_image()
|
||||
|
||||
# Update info label
|
||||
self._update_image_info()
|
||||
|
||||
logger.info(f"Loaded image: {file_path}")
|
||||
logger.info(f"Loaded image: {file_path} (DB ID: {self.current_image_id})")
|
||||
|
||||
except ImageLoadError as e:
|
||||
logger.error(f"Failed to load image: {e}")
|
||||
@@ -155,7 +219,7 @@ class AnnotationTab(QWidget):
|
||||
self.image_info_label.setText("No image loaded")
|
||||
return
|
||||
|
||||
zoom_percentage = self.image_display_widget.get_zoom_percentage()
|
||||
zoom_percentage = self.annotation_canvas.get_zoom_percentage()
|
||||
info_text = (
|
||||
f"File: {Path(self.current_image_path).name}\n"
|
||||
f"Size: {self.current_image.width}x{self.current_image.height} pixels\n"
|
||||
@@ -168,9 +232,335 @@ class AnnotationTab(QWidget):
|
||||
self.image_info_label.setText(info_text)
|
||||
|
||||
def _on_zoom_changed(self, zoom_scale: float):
|
||||
"""Handle zoom level changes from the image display widget."""
|
||||
"""Handle zoom level changes from the annotation canvas."""
|
||||
self._update_image_info()
|
||||
|
||||
def _on_annotation_drawn(self, points: list):
    """
    Persist a freshly drawn stroke as an annotation of the active class.

    The stroke and its axis-aligned bounding box are written to the
    database, after which the annotations of the current image are
    reloaded so the canvas mirrors what is stored.

    Args:
        points: Stroke vertices as (x_norm, y_norm) pairs.
    """
    # Saving requires an image that is both loaded and registered in the DB
    if not (self.current_image and self.current_image_id):
        logger.warning("Annotation drawn but no image loaded")
        QMessageBox.warning(
            self,
            "No Image",
            "Please load an image before drawing annotations.",
        )
        return

    active_class = self.annotation_tools.get_current_class()
    if not active_class:
        logger.warning("Annotation drawn but no object class selected")
        QMessageBox.warning(
            self,
            "No Class Selected",
            "Please select an object class before drawing annotations.",
        )
        return

    if not points:
        logger.warning("Annotation drawn with no points, ignoring")
        return

    # The bounding box is the extent of the stroke along each axis
    x_coords = [pt[0] for pt in points]
    y_coords = [pt[1] for pt in points]
    x_min, y_min = min(x_coords), min(y_coords)
    x_max, y_max = max(x_coords), max(y_coords)

    # The database keeps vertices as [y_norm, x_norm], so swap each pair
    stored_polyline = [[float(py), float(px)] for (px, py) in points]

    try:
        annotation_id = self.db_manager.add_annotation(
            image_id=self.current_image_id,
            class_id=active_class["id"],
            bbox=(x_min, y_min, x_max, y_max),
            annotator="manual",
            segmentation_mask=stored_polyline,
            verified=False,
        )

        logger.info(
            f"Saved annotation (ID: {annotation_id}) for class "
            f"'{active_class['class_name']}' "
            f"Bounding box: ({x_min:.3f}, {y_min:.3f}) to ({x_max:.3f}, {y_max:.3f})\n"
            f"with {len(points)} polyline points"
        )

        # Re-read from the DB so the canvas honors the current class filter
        self._load_annotations_for_current_image()
    except Exception as e:
        logger.error(f"Failed to save annotation: {e}")
        QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}")
|
||||
|
||||
def _on_annotation_selected(self, annotation_ids):
    """
    Track which stored annotations are selected on the canvas.

    Args:
        annotation_ids: Selected annotation IDs; None or an empty value
            means the selection was cleared.
    """
    if annotation_ids:
        # Keep only integer IDs, de-duplicated and in ascending order
        unique_ids = {int(candidate) for candidate in annotation_ids if isinstance(candidate, int)}
        ordered_ids = sorted(unique_ids)
        self.selected_annotation_ids = ordered_ids
        self.annotation_tools.set_has_selected_annotation(bool(ordered_ids))
        logger.debug(f"Annotations selected on canvas: IDs={ordered_ids}")
        return

    self.selected_annotation_ids = []
    self.annotation_tools.set_has_selected_annotation(False)
    logger.debug("Annotation selection cleared on canvas")
|
||||
|
||||
def _on_simplify_on_finish_changed(self, enabled: bool):
    """Copy the tools widget's simplify-on-finish flag onto the canvas."""
    canvas = self.annotation_canvas
    canvas.simplify_on_finish = enabled
    logger.debug(f"Annotation simplification on finish set to {enabled}")
|
||||
|
||||
def _on_simplify_epsilon_changed(self, epsilon: float):
    """Copy the tools widget's RDP epsilon onto the canvas as a float."""
    canvas = self.annotation_canvas
    canvas.simplify_epsilon = float(epsilon)
    logger.debug(f"Annotation simplification epsilon set to {epsilon}")
|
||||
|
||||
def _on_class_color_changed(self):
    """
    React to a class color being edited in the tools widget.

    Annotations of the current image are reloaded so they are redrawn
    with the per-class colors now stored. Nothing happens when no image
    is registered.
    """
    image_id = self.current_image_id
    if image_id:
        logger.debug(
            f"Class color changed; reloading annotations for image ID {image_id}"
        )
        self._load_annotations_for_current_image()
|
||||
|
||||
def _on_class_selected(self, class_data):
    """
    React to the object class selection changing or being cleared.

    Args:
        class_data: The selected class record, or a falsy value when the
            "-- Select Class --" placeholder is active (all annotations
            are shown in that case).
    """
    if not class_data:
        logger.debug(
            'No class selected ("-- Select Class --"), showing all annotations'
        )
    else:
        logger.debug(f"Object class selected: {class_data['class_name']}")

    # A new class filter makes any earlier annotation selection meaningless
    self.selected_annotation_ids = []
    self.annotation_tools.set_has_selected_annotation(False)

    # Redraw so only annotations matching the new filter are visible
    self._redraw_annotations_for_current_filter()
|
||||
|
||||
def _on_clear_annotations(self):
    """
    Remove all annotations from the canvas and forget the selection.

    Only the on-canvas drawing and in-memory state are reset; no database
    call is made here.
    """
    self.annotation_canvas.clear_annotations()

    # Reset in-memory bookkeeping to match the now-empty canvas
    self.current_annotations = []
    self.selected_annotation_ids = []
    self.annotation_tools.set_has_selected_annotation(False)

    logger.info("Cleared all annotations")
|
||||
|
||||
def _on_delete_selected_annotation(self):
    """
    Delete the selected annotation(s) from the database after confirmation.

    Each ID is deleted individually, so one failure does not stop the rest.
    IDs that could not be deleted are reported in a warning dialog.
    Afterwards the selection is cleared and the annotations are reloaded.
    """
    if not self.selected_annotation_ids:
        QMessageBox.information(
            self,
            "No Selection",
            "No annotation is currently selected.",
        )
        return

    count = len(self.selected_annotation_ids)
    several = count != 1
    title = "Delete Annotations" if several else "Delete Annotation"
    question = (
        f"Are you sure you want to delete the {count} selected annotations?"
        if several
        else "Are you sure you want to delete the selected annotation?"
    )

    answer = QMessageBox.question(
        self,
        title,
        question,
        QMessageBox.Yes | QMessageBox.No,
        QMessageBox.No,
    )
    if answer != QMessageBox.Yes:
        return

    failed_ids = []
    try:
        for ann_id in self.selected_annotation_ids:
            try:
                removed = self.db_manager.delete_annotation(ann_id)
            except Exception as e:
                logger.error(f"Failed to delete annotation ID {ann_id}: {e}")
                removed = False
            # A falsy result and a raised error are both treated as failure
            if not removed:
                failed_ids.append(ann_id)

        if failed_ids:
            QMessageBox.warning(
                self,
                "Partial Failure",
                "Some annotations could not be deleted:\n"
                + ", ".join(map(str, failed_ids)),
            )
        else:
            logger.info(
                f"Deleted {count} annotation(s): "
                + ", ".join(map(str, self.selected_annotation_ids))
            )

        # Drop the selection and re-read the image's annotations from the DB
        self.selected_annotation_ids = []
        self.annotation_tools.set_has_selected_annotation(False)
        self._load_annotations_for_current_image()
    except Exception as e:
        logger.error(f"Failed to delete annotations: {e}")
        QMessageBox.critical(
            self,
            "Error",
            f"Failed to delete annotations:\n{str(e)}",
        )
|
||||
|
||||
def _load_annotations_for_current_image(self):
    """
    Fetch the current image's annotations from the database and redraw.

    The redraw honors the class filter currently selected in the tools
    widget. Without a registered image the canvas and in-memory state are
    simply emptied. Any selection is reset in both cases.
    """
    if self.current_image_id:
        try:
            fetched = self.db_manager.get_annotations_for_image(
                self.current_image_id
            )
            self.current_annotations = fetched
            # Freshly loaded annotations invalidate any earlier selection
            self.selected_annotation_ids = []
            self.annotation_tools.set_has_selected_annotation(False)
            self._redraw_annotations_for_current_filter()
        except Exception as e:
            logger.error(
                f"Failed to load annotations for image {self.current_image_id}: {e}"
            )
            QMessageBox.critical(
                self,
                "Error",
                f"Failed to load annotations for this image:\n{str(e)}",
            )
        return

    # No image registered: nothing to show
    self.current_annotations = []
    self.annotation_canvas.clear_annotations()
    self.selected_annotation_ids = []
    self.annotation_tools.set_has_selected_annotation(False)
|
||||
|
||||
def _redraw_annotations_for_current_filter(self):
    """
    Redraw the in-memory annotations, restricted to the selected class.

    When no class is selected in the tools widget every annotation is
    eligible. Only annotations carrying a segmentation mask are drawn.
    """
    # Wipe the overlays; the image itself stays on the canvas
    self.annotation_canvas.clear_annotations()

    if not self.current_annotations:
        return

    active_class = self.annotation_tools.get_current_class()
    selected_class_id = active_class["id"] if active_class else None
    filtering = selected_class_id is not None

    drawn_count = 0
    for ann in self.current_annotations:
        if filtering and ann.get("class_id") != selected_class_id:
            continue

        polyline = ann.get("segmentation_mask")
        if not polyline:
            continue

        # NOTE(review): nesting was reconstructed from flattened source;
        # confirm the bbox draw and the counter belong under the mask check.
        color = ann.get("class_color", "#FF0000")
        self.annotation_canvas.draw_saved_polyline(
            polyline,
            color,
            width=3,
            annotation_id=ann["id"],
        )
        self.annotation_canvas.draw_saved_bbox(
            [ann["x_min"], ann["y_min"], ann["x_max"], ann["y_max"]],
            color,
            width=3,
        )
        drawn_count += 1

    filter_text = f"class_id={selected_class_id}" if filtering else "no class filter"
    logger.info(
        f"Displayed {drawn_count} annotation(s) for current image with "
        f"{filter_text}"
    )
|
||||
|
||||
def _restore_state(self):
    """Restore the main, left and right splitter layouts from QSettings."""
    settings = QSettings("microscopy_app", "object_detection")

    # (settings key, splitter attribute, debug message), in restore order
    splitters = (
        (
            "annotation_tab/main_splitter_state",
            "main_splitter",
            "Restored main splitter state",
        ),
        (
            "annotation_tab/left_splitter_state",
            "left_splitter",
            "Restored left splitter state",
        ),
        (
            "annotation_tab/right_splitter_state",
            "right_splitter",
            "Restored right splitter state",
        ),
    )

    for key, attr_name, message in splitters:
        saved_state = settings.value(key)
        if not saved_state:
            # Nothing stored for this splitter; keep its current layout
            continue
        getattr(self, attr_name).restoreState(saved_state)
        logger.debug(message)
|
||||
|
||||
def save_state(self):
    """Store the main, left and right splitter layouts in QSettings."""
    settings = QSettings("microscopy_app", "object_detection")

    # (settings key, splitter attribute), in save order
    splitters = (
        ("annotation_tab/main_splitter_state", "main_splitter"),
        ("annotation_tab/left_splitter_state", "left_splitter"),
        ("annotation_tab/right_splitter_state", "right_splitter"),
    )
    for key, attr_name in splitters:
        settings.setValue(key, getattr(self, attr_name).saveState())

    logger.debug("Saved annotation tab splitter states")
|
||||
|
||||
def refresh(self):
    """Refresh the tab (currently does nothing)."""
    return None
|
||||
|
||||
@@ -20,12 +20,14 @@ from PySide6.QtWidgets import (
|
||||
)
|
||||
from PySide6.QtCore import Qt, QThread, Signal
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.config_manager import ConfigManager
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.file_utils import get_image_files
|
||||
from src.model.inference import InferenceEngine
|
||||
from src.utils.image import Image
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -147,30 +149,66 @@ class DetectionTab(QWidget):
|
||||
self.model_combo.currentIndexChanged.connect(self._on_model_changed)
|
||||
|
||||
def _load_models(self):
|
||||
"""Load available models from database."""
|
||||
"""Load available models from database and local storage."""
|
||||
try:
|
||||
models = self.db_manager.get_models()
|
||||
self.model_combo.clear()
|
||||
models = self.db_manager.get_models()
|
||||
has_models = False
|
||||
|
||||
if not models:
|
||||
self.model_combo.addItem("No models available", None)
|
||||
self._set_buttons_enabled(False)
|
||||
return
|
||||
known_paths = set()
|
||||
|
||||
# Add base model option
|
||||
# Add base model option first (always available)
|
||||
base_model = self.config_manager.get(
|
||||
"models.default_base_model", "yolov8s-seg.pt"
|
||||
)
|
||||
self.model_combo.addItem(
|
||||
f"Base Model ({base_model})", {"id": 0, "path": base_model}
|
||||
)
|
||||
if base_model:
|
||||
base_data = {
|
||||
"id": 0,
|
||||
"path": base_model,
|
||||
"model_name": Path(base_model).stem or "Base Model",
|
||||
"model_version": "pretrained",
|
||||
"base_model": base_model,
|
||||
"source": "base",
|
||||
}
|
||||
self.model_combo.addItem(f"Base Model ({base_model})", base_data)
|
||||
known_paths.add(self._normalize_model_path(base_model))
|
||||
has_models = True
|
||||
|
||||
# Add trained models
|
||||
# Add trained models from database
|
||||
for model in models:
|
||||
display_name = f"{model['model_name']} v{model['model_version']}"
|
||||
self.model_combo.addItem(display_name, model)
|
||||
model_data = {**model, "path": model.get("model_path")}
|
||||
normalized = self._normalize_model_path(model_data.get("path"))
|
||||
if normalized:
|
||||
known_paths.add(normalized)
|
||||
self.model_combo.addItem(display_name, model_data)
|
||||
has_models = True
|
||||
|
||||
self._set_buttons_enabled(True)
|
||||
# Discover local model files not yet in the database
|
||||
local_models = self._discover_local_models()
|
||||
for model_path in local_models:
|
||||
normalized = self._normalize_model_path(model_path)
|
||||
if normalized in known_paths:
|
||||
continue
|
||||
|
||||
display_name = f"Local Model ({Path(model_path).stem})"
|
||||
model_data = {
|
||||
"id": None,
|
||||
"path": str(model_path),
|
||||
"model_name": Path(model_path).stem,
|
||||
"model_version": "local",
|
||||
"base_model": Path(model_path).stem,
|
||||
"source": "local",
|
||||
}
|
||||
self.model_combo.addItem(display_name, model_data)
|
||||
known_paths.add(normalized)
|
||||
has_models = True
|
||||
|
||||
if not has_models:
|
||||
self.model_combo.addItem("No models available", None)
|
||||
self._set_buttons_enabled(False)
|
||||
else:
|
||||
self._set_buttons_enabled(True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading models: {e}")
|
||||
@@ -199,7 +237,7 @@ class DetectionTab(QWidget):
|
||||
self,
|
||||
"Select Image",
|
||||
start_dir,
|
||||
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
|
||||
"Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
|
||||
)
|
||||
|
||||
if not file_path:
|
||||
@@ -249,25 +287,39 @@ class DetectionTab(QWidget):
|
||||
QMessageBox.warning(self, "No Model", "Please select a model first.")
|
||||
return
|
||||
|
||||
model_path = model_data["path"]
|
||||
model_id = model_data["id"]
|
||||
model_path = model_data.get("path")
|
||||
if not model_path:
|
||||
QMessageBox.warning(
|
||||
self, "Invalid Model", "Selected model is missing a file path."
|
||||
)
|
||||
return
|
||||
|
||||
# Ensure we have a valid model ID (create entry for base model if needed)
|
||||
if model_id == 0:
|
||||
# Create database entry for base model
|
||||
base_model = self.config_manager.get(
|
||||
"models.default_base_model", "yolov8s-seg.pt"
|
||||
)
|
||||
model_id = self.db_manager.add_model(
|
||||
model_name="Base Model",
|
||||
model_version="pretrained",
|
||||
model_path=base_model,
|
||||
base_model=base_model,
|
||||
if not Path(model_path).exists():
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"Model Not Found",
|
||||
f"The selected model file could not be found:\n{model_path}",
|
||||
)
|
||||
return
|
||||
|
||||
model_id = model_data.get("id")
|
||||
|
||||
# Ensure we have a database entry for the selected model
|
||||
if model_id in (None, 0):
|
||||
model_id = self._ensure_model_record(model_data)
|
||||
if not model_id:
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"Model Registration Failed",
|
||||
"Unable to register the selected model in the database.",
|
||||
)
|
||||
return
|
||||
|
||||
normalized_model_path = self._normalize_model_path(model_path) or model_path
|
||||
|
||||
# Create inference engine
|
||||
self.inference_engine = InferenceEngine(
|
||||
model_path, self.db_manager, model_id
|
||||
normalized_model_path, self.db_manager, model_id
|
||||
)
|
||||
|
||||
# Get confidence threshold
|
||||
@@ -338,6 +390,76 @@ class DetectionTab(QWidget):
|
||||
self.batch_btn.setEnabled(enabled)
|
||||
self.model_combo.setEnabled(enabled)
|
||||
|
||||
def _discover_local_models(self) -> list:
|
||||
"""Scan the models directory for standalone .pt files."""
|
||||
models_dir = self.config_manager.get_models_directory()
|
||||
if not models_dir:
|
||||
return []
|
||||
|
||||
models_path = Path(models_dir)
|
||||
if not models_path.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
return sorted(
|
||||
[p for p in models_path.rglob("*.pt") if p.is_file()],
|
||||
key=lambda p: str(p).lower(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error discovering local models: {e}")
|
||||
return []
|
||||
|
||||
def _normalize_model_path(self, path_value) -> str:
|
||||
"""Return a normalized absolute path string for comparison."""
|
||||
if not path_value:
|
||||
return ""
|
||||
try:
|
||||
return str(Path(path_value).resolve())
|
||||
except Exception:
|
||||
return str(path_value)
|
||||
|
||||
def _ensure_model_record(self, model_data: dict) -> Optional[int]:
|
||||
"""Ensure a database record exists for the selected model."""
|
||||
model_path = model_data.get("path")
|
||||
if not model_path:
|
||||
return None
|
||||
|
||||
normalized_target = self._normalize_model_path(model_path)
|
||||
|
||||
try:
|
||||
existing_models = self.db_manager.get_models()
|
||||
for model in existing_models:
|
||||
existing_path = model.get("model_path")
|
||||
if not existing_path:
|
||||
continue
|
||||
normalized_existing = self._normalize_model_path(existing_path)
|
||||
if (
|
||||
normalized_existing == normalized_target
|
||||
or existing_path == model_path
|
||||
):
|
||||
return model["id"]
|
||||
|
||||
model_name = (
|
||||
model_data.get("model_name") or Path(model_path).stem or "Custom Model"
|
||||
)
|
||||
model_version = (
|
||||
model_data.get("model_version") or model_data.get("source") or "local"
|
||||
)
|
||||
base_model = model_data.get(
|
||||
"base_model",
|
||||
self.config_manager.get("models.default_base_model", "yolov8s-seg.pt"),
|
||||
)
|
||||
|
||||
return self.db_manager.add_model(
|
||||
model_name=model_name,
|
||||
model_version=model_version,
|
||||
model_path=normalized_target,
|
||||
base_model=base_model,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ensure model record for {model_path}: {e}")
|
||||
return None
|
||||
|
||||
def refresh(self):
    """Refresh the tab by reloading the model list."""
    reload_models = self._load_models
    reload_models()
|
||||
|
||||
@@ -1,15 +1,39 @@
|
||||
"""
|
||||
Results tab for the microscopy object detection application.
|
||||
Results tab for browsing stored detections and visualizing overlays.
|
||||
"""
|
||||
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from PySide6.QtWidgets import (
|
||||
QWidget,
|
||||
QVBoxLayout,
|
||||
QHBoxLayout,
|
||||
QLabel,
|
||||
QGroupBox,
|
||||
QPushButton,
|
||||
QSplitter,
|
||||
QTableWidget,
|
||||
QTableWidgetItem,
|
||||
QHeaderView,
|
||||
QAbstractItemView,
|
||||
QMessageBox,
|
||||
QCheckBox,
|
||||
)
|
||||
from PySide6.QtCore import Qt
|
||||
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.config_manager import ConfigManager
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.image import Image, ImageLoadError
|
||||
from src.gui.widgets import AnnotationCanvasWidget
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ResultsTab(QWidget):
|
||||
"""Results tab placeholder."""
|
||||
"""Results tab showing detection history and preview overlays."""
|
||||
|
||||
def __init__(
|
||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
||||
@@ -18,29 +42,398 @@ class ResultsTab(QWidget):
|
||||
self.db_manager = db_manager
|
||||
self.config_manager = config_manager
|
||||
|
||||
self.detection_summary: List[Dict] = []
|
||||
self.current_selection: Optional[Dict] = None
|
||||
self.current_image: Optional[Image] = None
|
||||
self.current_detections: List[Dict] = []
|
||||
self._image_path_cache: Dict[str, str] = {}
|
||||
|
||||
self._setup_ui()
|
||||
self.refresh()
|
||||
|
||||
def _setup_ui(self):
|
||||
"""Setup user interface."""
|
||||
layout = QVBoxLayout()
|
||||
|
||||
group = QGroupBox("Results")
|
||||
group_layout = QVBoxLayout()
|
||||
label = QLabel(
|
||||
"Results viewer will be implemented here.\n\n"
|
||||
"Features:\n"
|
||||
"- Detection history browser\n"
|
||||
"- Advanced filtering\n"
|
||||
"- Statistics dashboard\n"
|
||||
"- Export functionality"
|
||||
)
|
||||
group_layout.addWidget(label)
|
||||
group.setLayout(group_layout)
|
||||
# Splitter for list + preview
|
||||
splitter = QSplitter(Qt.Horizontal)
|
||||
|
||||
layout.addWidget(group)
|
||||
layout.addStretch()
|
||||
# Left pane: detection list
|
||||
left_container = QWidget()
|
||||
left_layout = QVBoxLayout()
|
||||
left_layout.setContentsMargins(0, 0, 0, 0)
|
||||
|
||||
controls_layout = QHBoxLayout()
|
||||
self.refresh_btn = QPushButton("Refresh")
|
||||
self.refresh_btn.clicked.connect(self.refresh)
|
||||
controls_layout.addWidget(self.refresh_btn)
|
||||
controls_layout.addStretch()
|
||||
left_layout.addLayout(controls_layout)
|
||||
|
||||
self.results_table = QTableWidget(0, 5)
|
||||
self.results_table.setHorizontalHeaderLabels(
|
||||
["Image", "Model", "Detections", "Classes", "Last Updated"]
|
||||
)
|
||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||
0, QHeaderView.Stretch
|
||||
)
|
||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||
1, QHeaderView.Stretch
|
||||
)
|
||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||
2, QHeaderView.ResizeToContents
|
||||
)
|
||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||
3, QHeaderView.Stretch
|
||||
)
|
||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||
4, QHeaderView.ResizeToContents
|
||||
)
|
||||
self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows)
|
||||
self.results_table.setSelectionMode(QAbstractItemView.SingleSelection)
|
||||
self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
|
||||
self.results_table.itemSelectionChanged.connect(self._on_result_selected)
|
||||
|
||||
left_layout.addWidget(self.results_table)
|
||||
left_container.setLayout(left_layout)
|
||||
|
||||
# Right pane: preview canvas and controls
|
||||
right_container = QWidget()
|
||||
right_layout = QVBoxLayout()
|
||||
right_layout.setContentsMargins(0, 0, 0, 0)
|
||||
|
||||
preview_group = QGroupBox("Detection Preview")
|
||||
preview_layout = QVBoxLayout()
|
||||
|
||||
self.preview_canvas = AnnotationCanvasWidget()
|
||||
self.preview_canvas.set_polyline_enabled(False)
|
||||
self.preview_canvas.set_show_bboxes(True)
|
||||
preview_layout.addWidget(self.preview_canvas)
|
||||
|
||||
toggles_layout = QHBoxLayout()
|
||||
self.show_masks_checkbox = QCheckBox("Show Masks")
|
||||
self.show_masks_checkbox.setChecked(False)
|
||||
self.show_masks_checkbox.stateChanged.connect(self._apply_detection_overlays)
|
||||
self.show_bboxes_checkbox = QCheckBox("Show Bounding Boxes")
|
||||
self.show_bboxes_checkbox.setChecked(True)
|
||||
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
|
||||
self.show_confidence_checkbox = QCheckBox("Show Confidence")
|
||||
self.show_confidence_checkbox.setChecked(False)
|
||||
self.show_confidence_checkbox.stateChanged.connect(
|
||||
self._apply_detection_overlays
|
||||
)
|
||||
toggles_layout.addWidget(self.show_masks_checkbox)
|
||||
toggles_layout.addWidget(self.show_bboxes_checkbox)
|
||||
toggles_layout.addWidget(self.show_confidence_checkbox)
|
||||
toggles_layout.addStretch()
|
||||
preview_layout.addLayout(toggles_layout)
|
||||
|
||||
self.summary_label = QLabel("Select a detection result to preview.")
|
||||
self.summary_label.setWordWrap(True)
|
||||
preview_layout.addWidget(self.summary_label)
|
||||
|
||||
preview_group.setLayout(preview_layout)
|
||||
right_layout.addWidget(preview_group)
|
||||
right_container.setLayout(right_layout)
|
||||
|
||||
splitter.addWidget(left_container)
|
||||
splitter.addWidget(right_container)
|
||||
splitter.setStretchFactor(0, 1)
|
||||
splitter.setStretchFactor(1, 2)
|
||||
|
||||
layout.addWidget(splitter)
|
||||
self.setLayout(layout)
|
||||
|
||||
def refresh(self):
|
||||
"""Refresh the tab."""
|
||||
pass
|
||||
"""Refresh the detection list and preview."""
|
||||
self._load_detection_summary()
|
||||
self._populate_results_table()
|
||||
self.current_selection = None
|
||||
self.current_image = None
|
||||
self.current_detections = []
|
||||
self.preview_canvas.clear()
|
||||
self.summary_label.setText("Select a detection result to preview.")
|
||||
|
||||
def _load_detection_summary(self):
|
||||
"""Load latest detection summaries grouped by image + model."""
|
||||
try:
|
||||
detections = self.db_manager.get_detections(limit=500)
|
||||
summary_map: Dict[tuple, Dict] = {}
|
||||
|
||||
for det in detections:
|
||||
key = (det["image_id"], det["model_id"])
|
||||
metadata = det.get("metadata") or {}
|
||||
entry = summary_map.setdefault(
|
||||
key,
|
||||
{
|
||||
"image_id": det["image_id"],
|
||||
"model_id": det["model_id"],
|
||||
"image_path": det.get("image_path"),
|
||||
"image_filename": det.get("image_filename")
|
||||
or det.get("image_path"),
|
||||
"model_name": det.get("model_name", ""),
|
||||
"model_version": det.get("model_version", ""),
|
||||
"last_detected": det.get("detected_at"),
|
||||
"count": 0,
|
||||
"classes": set(),
|
||||
"source_path": metadata.get("source_path"),
|
||||
"repository_root": metadata.get("repository_root"),
|
||||
},
|
||||
)
|
||||
|
||||
entry["count"] += 1
|
||||
if det.get("detected_at") and (
|
||||
not entry.get("last_detected")
|
||||
or str(det.get("detected_at")) > str(entry.get("last_detected"))
|
||||
):
|
||||
entry["last_detected"] = det.get("detected_at")
|
||||
if det.get("class_name"):
|
||||
entry["classes"].add(det["class_name"])
|
||||
if metadata.get("source_path") and not entry.get("source_path"):
|
||||
entry["source_path"] = metadata.get("source_path")
|
||||
if metadata.get("repository_root") and not entry.get("repository_root"):
|
||||
entry["repository_root"] = metadata.get("repository_root")
|
||||
|
||||
self.detection_summary = sorted(
|
||||
summary_map.values(),
|
||||
key=lambda x: str(x.get("last_detected") or ""),
|
||||
reverse=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load detection summary: {e}")
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"Error",
|
||||
f"Failed to load detection results:\n{str(e)}",
|
||||
)
|
||||
self.detection_summary = []
|
||||
|
||||
def _populate_results_table(self):
    """
    Fill the results table with one row per detection summary entry.

    Columns are image, model, detection count, classes and last-updated
    timestamp. Every cell stores its row index under ``Qt.UserRole`` so a
    selection can be mapped back to ``self.detection_summary``. The table
    selection is cleared afterwards.
    """
    self.results_table.setRowCount(len(self.detection_summary))

    for row, entry in enumerate(self.detection_summary):
        model_label = f"{entry['model_name']} {entry['model_version']}".strip()
        class_list = ", ".join(sorted(entry["classes"])) if entry["classes"] else "-"

        cell_texts = [
            # The key is always present, so a .get() default never applies;
            # coerce a None filename to "" like the timestamp column does.
            str(entry.get("image_filename") or ""),
            model_label,
            str(entry.get("count", 0)),
            class_list,
            str(entry.get("last_detected") or ""),
        ]

        for col, text in enumerate(cell_texts):
            item = QTableWidgetItem(text)
            item.setData(Qt.UserRole, row)
            self.results_table.setItem(row, col, item)

    self.results_table.clearSelection()
|
||||
|
||||
def _on_result_selected(self):
    """
    Load the image and detections for the row selected in the table.

    Does nothing when the selection is empty, the row index is unusable,
    or the row refers to the image/model pair that is already the current
    selection. A dialog is shown when the image cannot be located or
    loaded.
    """
    picked = self.results_table.selectedItems()
    if not picked:
        return

    # Each cell carries its summary row index under Qt.UserRole
    row = picked[0].data(Qt.UserRole)
    if row is None or row >= len(self.detection_summary):
        return

    entry = self.detection_summary[row]
    previous = self.current_selection
    if (
        previous
        and previous.get("image_id") == entry["image_id"]
        and previous.get("model_id") == entry["model_id"]
    ):
        return

    self.current_selection = entry

    image_path = self._resolve_image_path(entry)
    if not image_path:
        QMessageBox.warning(
            self,
            "Image Not Found",
            "Unable to locate the image file for this detection.",
        )
        return

    try:
        self.current_image = Image(image_path)
        self.preview_canvas.load_image(self.current_image)
    except ImageLoadError as e:
        logger.error(f"Failed to load image '{image_path}': {e}")
        QMessageBox.critical(
            self,
            "Image Error",
            f"Failed to load image for preview:\n{str(e)}",
        )
        return
    else:
        self._load_detections_for_selection(entry)
        self._apply_detection_overlays()
        self._update_summary_label(entry)
|
||||
|
||||
def _load_detections_for_selection(self, entry: Dict):
|
||||
"""Load detection records for the selected image/model pair."""
|
||||
self.current_detections = []
|
||||
if not entry:
|
||||
return
|
||||
|
||||
try:
|
||||
filters = {"image_id": entry["image_id"], "model_id": entry["model_id"]}
|
||||
self.current_detections = self.db_manager.get_detections(filters)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load detections for preview: {e}")
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"Error",
|
||||
f"Failed to load detections for this image:\n{str(e)}",
|
||||
)
|
||||
self.current_detections = []
|
||||
|
||||
def _apply_detection_overlays(self):
|
||||
"""Draw detections onto the preview canvas based on current toggles."""
|
||||
self.preview_canvas.clear_annotations()
|
||||
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
|
||||
|
||||
if not self.current_detections or not self.current_image:
|
||||
return
|
||||
|
||||
for det in self.current_detections:
|
||||
color = self._get_class_color(det.get("class_name"))
|
||||
|
||||
if self.show_masks_checkbox.isChecked() and det.get("segmentation_mask"):
|
||||
mask_points = self._convert_mask(det["segmentation_mask"])
|
||||
if mask_points:
|
||||
self.preview_canvas.draw_saved_polyline(mask_points, color)
|
||||
|
||||
bbox = [
|
||||
det.get("x_min"),
|
||||
det.get("y_min"),
|
||||
det.get("x_max"),
|
||||
det.get("y_max"),
|
||||
]
|
||||
if all(v is not None for v in bbox):
|
||||
label = None
|
||||
if self.show_confidence_checkbox.isChecked():
|
||||
confidence = det.get("confidence")
|
||||
if confidence is not None:
|
||||
label = f"{confidence:.2f}"
|
||||
self.preview_canvas.draw_saved_bbox(bbox, color, label=label)
|
||||
|
||||
def _convert_mask(self, mask_points: List[List[float]]) -> List[List[float]]:
|
||||
"""Convert stored [x, y] masks to [y, x] format for the canvas."""
|
||||
converted = []
|
||||
for point in mask_points:
|
||||
if len(point) >= 2:
|
||||
x, y = point[0], point[1]
|
||||
converted.append([y, x])
|
||||
return converted
|
||||
|
||||
def _toggle_bboxes(self):
|
||||
"""Update bounding box visibility on the canvas."""
|
||||
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
|
||||
# Re-render to respect show/hide when toggled
|
||||
self._apply_detection_overlays()
|
||||
|
||||
def _update_summary_label(self, entry: Dict):
|
||||
"""Display textual summary for the selected detection run."""
|
||||
classes = ", ".join(sorted(entry.get("classes", []))) or "-"
|
||||
summary_text = (
|
||||
f"Image: {entry.get('image_filename', 'unknown')}\n"
|
||||
f"Model: {entry.get('model_name', '')} {entry.get('model_version', '')}\n"
|
||||
f"Detections: {entry.get('count', 0)}\n"
|
||||
f"Classes: {classes}\n"
|
||||
f"Last Updated: {entry.get('last_detected', 'n/a')}"
|
||||
)
|
||||
self.summary_label.setText(summary_text)
|
||||
|
||||
def _resolve_image_path(self, entry: Dict) -> Optional[str]:
|
||||
"""Resolve an image path using metadata, cache, and repository hints."""
|
||||
relative_path = entry.get("image_path") if entry else None
|
||||
cache_key = relative_path or entry.get("source_path")
|
||||
if cache_key and cache_key in self._image_path_cache:
|
||||
cached = Path(self._image_path_cache[cache_key])
|
||||
if cached.exists():
|
||||
return self._image_path_cache[cache_key]
|
||||
del self._image_path_cache[cache_key]
|
||||
|
||||
candidates = []
|
||||
source_path = entry.get("source_path") if entry else None
|
||||
if source_path:
|
||||
candidates.append(Path(source_path))
|
||||
|
||||
repo_roots = []
|
||||
if entry.get("repository_root"):
|
||||
repo_roots.append(entry["repository_root"])
|
||||
config_repo = self.config_manager.get_image_repository_path()
|
||||
if config_repo:
|
||||
repo_roots.append(config_repo)
|
||||
|
||||
for root in repo_roots:
|
||||
if relative_path:
|
||||
candidates.append(Path(root) / relative_path)
|
||||
|
||||
if relative_path:
|
||||
candidates.append(Path(relative_path))
|
||||
|
||||
for candidate in candidates:
|
||||
try:
|
||||
if candidate and candidate.exists():
|
||||
resolved = str(candidate.resolve())
|
||||
if cache_key:
|
||||
self._image_path_cache[cache_key] = resolved
|
||||
return resolved
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Fallback: search by filename in known roots
|
||||
filename = Path(relative_path).name if relative_path else None
|
||||
if filename:
|
||||
search_roots = [Path(root) for root in repo_roots if root]
|
||||
if not search_roots:
|
||||
search_roots = [Path("data")]
|
||||
match = self._search_in_roots(filename, search_roots)
|
||||
if match:
|
||||
resolved = str(match.resolve())
|
||||
if cache_key:
|
||||
self._image_path_cache[cache_key] = resolved
|
||||
return resolved
|
||||
|
||||
return None
|
||||
|
||||
def _search_in_roots(self, filename: str, roots: List[Path]) -> Optional[Path]:
|
||||
"""Search for a file name within a list of root directories."""
|
||||
for root in roots:
|
||||
try:
|
||||
if not root.exists():
|
||||
continue
|
||||
for candidate in root.rglob(filename):
|
||||
return candidate
|
||||
except Exception as e:
|
||||
logger.debug(f"Error searching for {filename} in {root}: {e}")
|
||||
return None
|
||||
|
||||
def _get_class_color(self, class_name: Optional[str]) -> str:
|
||||
"""Return consistent color hex for a class name."""
|
||||
if not class_name:
|
||||
return "#FF6B6B"
|
||||
|
||||
color_map = self.config_manager.get_bbox_colors()
|
||||
if class_name in color_map:
|
||||
return color_map[class_name]
|
||||
|
||||
# Deterministic fallback color based on hash
|
||||
palette = [
|
||||
"#FF6B6B",
|
||||
"#4ECDC4",
|
||||
"#FFD166",
|
||||
"#1D3557",
|
||||
"#F4A261",
|
||||
"#E76F51",
|
||||
]
|
||||
return palette[hash(class_name) % len(palette)]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,7 @@
|
||||
"""GUI widgets for the microscopy object detection application."""
|
||||
|
||||
from src.gui.widgets.image_display_widget import ImageDisplayWidget
|
||||
from src.gui.widgets.annotation_canvas_widget import AnnotationCanvasWidget
|
||||
from src.gui.widgets.annotation_tools_widget import AnnotationToolsWidget
|
||||
|
||||
__all__ = ["ImageDisplayWidget"]
|
||||
__all__ = ["ImageDisplayWidget", "AnnotationCanvasWidget", "AnnotationToolsWidget"]
|
||||
|
||||
931
src/gui/widgets/annotation_canvas_widget.py
Normal file
931
src/gui/widgets/annotation_canvas_widget.py
Normal file
@@ -0,0 +1,931 @@
|
||||
"""
|
||||
Annotation canvas widget for drawing annotations on images.
|
||||
Currently supports polyline drawing tool with color selection for manual annotation.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea
|
||||
from PySide6.QtGui import (
|
||||
QPixmap,
|
||||
QImage,
|
||||
QPainter,
|
||||
QPen,
|
||||
QColor,
|
||||
QKeyEvent,
|
||||
QMouseEvent,
|
||||
QPaintEvent,
|
||||
QPolygonF,
|
||||
)
|
||||
from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.utils.image import Image, ImageLoadError
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def perpendicular_distance(
    point: Tuple[float, float],
    start: Tuple[float, float],
    end: Tuple[float, float],
) -> float:
    """
    Distance from `point` to the infinite line through `start` and `end`.

    When `start` and `end` coincide, the plain point-to-point distance to
    `start` is returned instead.
    """
    px, py = point
    ax, ay = start
    bx, by = end

    run = bx - ax
    rise = by - ay
    if run == 0.0 and rise == 0.0:
        # Degenerate line: fall back to the distance to the single point.
        return math.hypot(px - ax, py - ay)

    # |cross product| / |direction| is the height of the point above the line.
    cross = abs(rise * px - run * py + bx * ay - by * ax)
    return cross / math.hypot(run, rise)
|
||||
|
||||
|
||||
def rdp(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float, float]]:
    """
    Ramer-Douglas-Peucker (RDP) polyline simplification.

    Implemented with an explicit work stack instead of recursion, so very
    long strokes cannot hit Python's recursion limit. The selected points are
    the same as for the classic recursive formulation.

    Args:
        points: List of (x, y) points.
        epsilon: Maximum allowed perpendicular distance in pixels.

    Returns:
        Simplified list of (x, y) points including first and last.
    """
    total = len(points)
    if total <= 2:
        return list(points)

    keep = [False] * total
    keep[0] = keep[-1] = True
    pending = [(0, total - 1)]  # index ranges still to be examined

    while pending:
        first, last = pending.pop()
        if last - first < 2:
            continue  # no interior points between the two anchors

        x1, y1 = points[first]
        x2, y2 = points[last]
        dx = x2 - x1
        dy = y2 - y1
        degenerate = dx == 0.0 and dy == 0.0
        if not degenerate:
            # Line terms are invariant for the whole range: compute them once.
            den = math.hypot(dx, dy)
            a = x2 * y1
            b = y2 * x1

        max_dist = -1.0
        index = -1
        for i in range(first + 1, last):
            x, y = points[i]
            if degenerate:
                d = math.hypot(x - x1, y - y1)
            else:
                d = abs(dy * x - dx * y + a - b) / den
            if d > max_dist:
                max_dist = d
                index = i

        if max_dist > epsilon:
            # The farthest point is significant: keep it and examine both halves.
            keep[index] = True
            pending.append((first, index))
            pending.append((index, last))

    return [pt for pt, wanted in zip(points, keep) if wanted]
|
||||
|
||||
|
||||
def simplify_polyline(
    points: List[Tuple[float, float]], epsilon: float
) -> List[Tuple[float, float]]:
    """
    Run RDP simplification on a polyline, keeping closed shapes closed.

    A closed polyline (first point equals last point) is simplified without
    its duplicated end point; the closing point is appended again afterwards.
    All coordinates are converted to float.
    """
    if not points:
        return []

    pts = [(float(x), float(y)) for x, y in points]

    closed = len(pts) >= 2 and pts[0] == pts[-1]
    if closed:
        # Drop the duplicated end point while simplifying.
        pts = pts[:-1]

    simplified = list(pts) if len(pts) <= 2 else rdp(pts, epsilon)

    # Restore the closing point unless the result already ends where it starts.
    if closed and simplified and simplified[0] != simplified[-1]:
        simplified.append(simplified[0])

    return simplified
|
||||
|
||||
|
||||
class AnnotationCanvasWidget(QWidget):
|
||||
"""
|
||||
Widget for displaying images and drawing annotations with zoom and drawing tools.
|
||||
|
||||
Features:
|
||||
- Display images with zoom functionality
|
||||
- Polyline tool for drawing annotations
|
||||
- Configurable pen color and width
|
||||
- Mouse-based drawing interface
|
||||
- Zoom in/out with mouse wheel and keyboard
|
||||
|
||||
Signals:
|
||||
zoom_changed: Emitted when zoom level changes (float zoom_scale)
|
||||
annotation_drawn: Emitted when a new stroke is completed (list of points)
|
||||
"""
|
||||
|
||||
zoom_changed = Signal(float)
|
||||
annotation_drawn = Signal(list) # List of (x, y) points in normalized coordinates
|
||||
# Emitted when the user selects an existing polyline on the canvas.
|
||||
# Carries a list of the selected annotation IDs (List[int]), or None when the selection is cleared or contains no saved annotations
|
||||
annotation_selected = Signal(object)
|
||||
|
||||
def __init__(self, parent=None):
    """Set up default zoom, drawing and storage state, then build the UI."""
    super().__init__(parent)

    # --- Image and overlay -------------------------------------------------
    self.current_image = None
    self.original_pixmap = None
    self.annotation_pixmap = None  # transparent layer the annotations are painted on

    # --- Zoom configuration ------------------------------------------------
    self.zoom_scale = 1.0
    self.zoom_min, self.zoom_max = 0.1, 10.0
    self.zoom_step = 0.1
    self.zoom_wheel_step = 0.15

    # --- Drawing / interaction state ---------------------------------------
    self.is_drawing = False
    self.polyline_enabled = False
    self.polyline_pen_color = QColor(255, 0, 0, 128)  # red, alpha 128
    self.polyline_pen_width = 3
    self.show_bboxes: bool = True  # whether bounding boxes are rendered

    # --- Polylines (image pixel coordinates) -------------------------------
    # The three lists below are parallel: one entry per stored polyline.
    self.current_stroke: List[Tuple[float, float]] = []
    self.polylines: List[List[Tuple[float, float]]] = []
    self.stroke_meta: List[Dict[str, Any]] = []  # color / width per polyline
    # DB annotation_id per polyline; None for temporary / unsaved ones
    self.polyline_annotation_ids: List[Optional[int]] = []
    # Indices into self.polylines of the current (multi-)selection
    self.selected_polyline_indices: List[int] = []

    # --- Bounding boxes (normalized x_min, y_min, x_max, y_max) -------------
    self.bboxes: List[List[float]] = []
    self.bbox_meta: List[Dict[str, Any]] = []  # color / width per bbox

    # Legacy stroke list in normalized coordinates, kept for API compatibility
    self.all_strokes: List[dict] = []

    # --- Stroke sampling and RDP simplification (pixels) --------------------
    self.simplify_on_finish: bool = True
    self.simplify_epsilon: float = 2.0
    self.sample_threshold: float = 2.0  # minimum movement before a new point is sampled

    self._setup_ui()
|
||||
|
||||
def _setup_ui(self):
    """Build the layout: a scroll area wrapping the label that shows the image."""
    # Scroll area hosting the canvas
    self.scroll_area = QScrollArea()
    scroller = self.scroll_area
    scroller.setWidgetResizable(True)
    scroller.setMinimumHeight(400)

    # Label that displays the composited image + annotation pixmap
    self.canvas_label = QLabel("No image loaded")
    canvas = self.canvas_label
    canvas.setAlignment(Qt.AlignCenter)
    canvas.setStyleSheet("QLabel { background-color: #2b2b2b; color: #888; }")
    canvas.setScaledContents(False)
    canvas.setMouseTracking(True)

    scroller.setWidget(canvas)
    # This widget filters the viewport's events.
    scroller.viewport().installEventFilter(self)

    root_layout = QVBoxLayout()
    root_layout.setContentsMargins(0, 0, 0, 0)
    root_layout.addWidget(scroller)
    self.setLayout(root_layout)

    self.setFocusPolicy(Qt.StrongFocus)
|
||||
|
||||
def load_image(self, image: Image):
    """
    Show a new image on the canvas.

    Zoom is reset to 100 % and all existing annotations are discarded before
    the image is rendered.

    Args:
        image: Image object to display
    """
    self.current_image = image
    self.zoom_scale = 1.0

    # Drop annotations of the previous image before rendering the new one.
    self.clear_annotations()
    self._display_image()

    logger.debug(f"Loaded image into annotation canvas: {image.width}x{image.height}")
|
||||
|
||||
def clear(self):
    """Remove the image and every annotation, returning to the empty placeholder."""
    self.current_image = self.original_pixmap = self.annotation_pixmap = None
    self.zoom_scale = 1.0
    self.clear_annotations()

    placeholder = self.canvas_label
    placeholder.setText("No image loaded")
    placeholder.setPixmap(QPixmap())
|
||||
|
||||
def clear_annotations(self):
    """Forget all strokes, polylines, selections and bounding boxes, then refresh."""
    # Each container gets its own fresh list (they must not share one object).
    for container in (
        "all_strokes",
        "current_stroke",
        "polylines",
        "stroke_meta",
        "polyline_annotation_ids",
        "selected_polyline_indices",
        "bboxes",
        "bbox_meta",
    ):
        setattr(self, container, [])
    self.is_drawing = False

    if self.annotation_pixmap:
        # Wipe the overlay and show the bare image again.
        self.annotation_pixmap.fill(Qt.transparent)
        self._update_display()
|
||||
|
||||
def _display_image(self):
|
||||
"""Display the current image in the canvas."""
|
||||
if self.current_image is None:
|
||||
return
|
||||
|
||||
try:
|
||||
# Get image data in a format compatible with Qt
|
||||
if self.current_image.channels in (3, 4):
|
||||
image_data = self.current_image.get_rgb()
|
||||
height, width = image_data.shape[:2]
|
||||
else:
|
||||
image_data = self.current_image.get_grayscale()
|
||||
height, width = image_data.shape
|
||||
|
||||
image_data = np.ascontiguousarray(image_data)
|
||||
bytes_per_line = image_data.strides[0]
|
||||
|
||||
qimage = QImage(
|
||||
image_data.data,
|
||||
width,
|
||||
height,
|
||||
bytes_per_line,
|
||||
self.current_image.qtimage_format,
|
||||
).copy() # Copy so Qt owns the buffer even after numpy array goes out of scope
|
||||
|
||||
self.original_pixmap = QPixmap.fromImage(qimage)
|
||||
|
||||
# Create transparent overlay for annotations
|
||||
self.annotation_pixmap = QPixmap(self.original_pixmap.size())
|
||||
self.annotation_pixmap.fill(Qt.transparent)
|
||||
|
||||
self._apply_zoom()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error displaying image: {e}")
|
||||
raise ImageLoadError(f"Failed to display image: {str(e)}")
|
||||
|
||||
def _apply_zoom(self):
|
||||
"""Apply current zoom level to the displayed image."""
|
||||
if self.original_pixmap is None:
|
||||
return
|
||||
|
||||
scaled_width = int(self.original_pixmap.width() * self.zoom_scale)
|
||||
scaled_height = int(self.original_pixmap.height() * self.zoom_scale)
|
||||
|
||||
# Scale both image and annotations
|
||||
scaled_image = self.original_pixmap.scaled(
|
||||
scaled_width,
|
||||
scaled_height,
|
||||
Qt.KeepAspectRatio,
|
||||
(
|
||||
Qt.SmoothTransformation
|
||||
if self.zoom_scale >= 1.0
|
||||
else Qt.FastTransformation
|
||||
),
|
||||
)
|
||||
|
||||
scaled_annotations = self.annotation_pixmap.scaled(
|
||||
scaled_width,
|
||||
scaled_height,
|
||||
Qt.KeepAspectRatio,
|
||||
(
|
||||
Qt.SmoothTransformation
|
||||
if self.zoom_scale >= 1.0
|
||||
else Qt.FastTransformation
|
||||
),
|
||||
)
|
||||
|
||||
# Composite image and annotations
|
||||
combined = QPixmap(scaled_image.size())
|
||||
painter = QPainter(combined)
|
||||
painter.drawPixmap(0, 0, scaled_image)
|
||||
painter.drawPixmap(0, 0, scaled_annotations)
|
||||
painter.end()
|
||||
|
||||
self.canvas_label.setPixmap(combined)
|
||||
self.canvas_label.setScaledContents(False)
|
||||
self.canvas_label.adjustSize()
|
||||
|
||||
self.zoom_changed.emit(self.zoom_scale)
|
||||
|
||||
def _update_display(self):
|
||||
"""Update display after drawing."""
|
||||
self._apply_zoom()
|
||||
|
||||
def set_polyline_enabled(self, enabled: bool):
    """Switch the polyline tool on or off and update the canvas cursor to match."""
    self.polyline_enabled = enabled
    # Cross-hair while drawing is possible, regular arrow otherwise.
    cursor = Qt.CrossCursor if enabled else Qt.ArrowCursor
    self.canvas_label.setCursor(cursor)
|
||||
|
||||
def set_polyline_pen_color(self, color: QColor):
    """Set the pen color used for polylines drawn from now on."""
    self.polyline_pen_color = color
|
||||
|
||||
def set_polyline_pen_width(self, width: int):
    """Set the polyline pen width; values below 1 are raised to 1."""
    self.polyline_pen_width = width if width > 1 else 1
|
||||
|
||||
def get_zoom_percentage(self) -> int:
    """
    Get current zoom level as percentage.

    The value is rounded rather than truncated: binary floating point makes
    e.g. ``1.15 * 100`` evaluate to 114.999..., which must read as 115 %.
    """
    return int(round(self.zoom_scale * 100))
|
||||
|
||||
def zoom_in(self):
    """
    Zoom in on the image by one ``zoom_step``.

    The result is clamped to ``zoom_max``, so the maximum stays reachable
    even when accumulated floating-point steps (or a wheel zoom) left the
    scale off the step grid. Does nothing without an image or at maximum.
    """
    if self.original_pixmap is None:
        return

    target = min(self.zoom_max, self.zoom_scale + self.zoom_step)
    if target > self.zoom_scale:
        self.zoom_scale = target
        self._apply_zoom()
|
||||
|
||||
def zoom_out(self):
    """
    Zoom out from the image by one ``zoom_step``.

    The result is clamped to ``zoom_min``, so the minimum stays reachable
    even when accumulated floating-point steps (or a wheel zoom) left the
    scale off the step grid. Does nothing without an image or at minimum.
    """
    if self.original_pixmap is None:
        return

    target = max(self.zoom_min, self.zoom_scale - self.zoom_step)
    if target < self.zoom_scale:
        self.zoom_scale = target
        self._apply_zoom()
|
||||
|
||||
def reset_zoom(self):
    """Return to 100 % zoom; ignored while no image is loaded."""
    if self.original_pixmap is not None:
        self.zoom_scale = 1.0
        self._apply_zoom()
|
||||
|
||||
def _canvas_to_image_coords(self, pos: QPoint) -> Optional[Tuple[int, int]]:
    """
    Map a position on the canvas label to integer image pixel coordinates.

    Accounts for the zoom factor and for the margin that appears when the
    label is larger than the displayed pixmap (the pixmap is centered).

    Returns:
        ``(x, y)`` in image pixels, or ``None`` when no image is shown or the
        position lies outside the image.
    """
    source = self.original_pixmap
    if source is None:
        return None
    shown = self.canvas_label.pixmap()
    if shown is None:
        return None

    # Margin between label edge and the centered pixmap (0 if the pixmap fills it)
    margin_x = max(0, (self.canvas_label.width() - shown.width()) // 2)
    margin_y = max(0, (self.canvas_label.height() - shown.height()) // 2)

    # Undo margin and zoom to get back to original-image pixels
    img_x = (pos.x() - margin_x) / self.zoom_scale
    img_y = (pos.y() - margin_y) / self.zoom_scale

    inside = 0 <= img_x < source.width() and 0 <= img_y < source.height()
    return (int(img_x), int(img_y)) if inside else None
|
||||
|
||||
def _find_polyline_at(
|
||||
self, img_x: float, img_y: float, threshold_px: float = 5.0
|
||||
) -> Optional[int]:
|
||||
"""
|
||||
Find index of polyline whose geometry is within threshold_px of (img_x, img_y).
|
||||
Returns the index in self.polylines, or None if none is close enough.
|
||||
"""
|
||||
best_index: Optional[int] = None
|
||||
best_dist: float = float("inf")
|
||||
|
||||
for idx, polyline in enumerate(self.polylines):
|
||||
if len(polyline) < 2:
|
||||
continue
|
||||
|
||||
# Quick bounding-box check to skip obviously distant polylines
|
||||
xs = [p[0] for p in polyline]
|
||||
ys = [p[1] for p in polyline]
|
||||
if img_x < min(xs) - threshold_px or img_x > max(xs) + threshold_px:
|
||||
continue
|
||||
if img_y < min(ys) - threshold_px or img_y > max(ys) + threshold_px:
|
||||
continue
|
||||
|
||||
# Precise distance to all segments
|
||||
for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]):
|
||||
d = perpendicular_distance(
|
||||
(img_x, img_y), (float(x1), float(y1)), (float(x2), float(y2))
|
||||
)
|
||||
if d < best_dist:
|
||||
best_dist = d
|
||||
best_index = idx
|
||||
|
||||
if best_index is not None and best_dist <= threshold_px:
|
||||
return best_index
|
||||
return None
|
||||
|
||||
def _image_to_normalized_coords(self, x: int, y: int) -> Tuple[float, float]:
|
||||
"""Convert image coordinates to normalized coordinates (0-1)."""
|
||||
if self.original_pixmap is None:
|
||||
return (0.0, 0.0)
|
||||
|
||||
norm_x = x / self.original_pixmap.width()
|
||||
norm_y = y / self.original_pixmap.height()
|
||||
return (norm_x, norm_y)
|
||||
|
||||
def _add_polyline(
    self,
    img_points: List[Tuple[float, float]],
    color: QColor,
    width: int,
    annotation_id: Optional[int] = None,
):
    """
    Register a polyline (image pixel coordinates) and repaint the overlay.

    Polylines with fewer than two points are ignored. Geometry, style and
    annotation id are appended to their parallel lists in lock-step.
    """
    if img_points and len(img_points) >= 2:
        # Store coordinates as float tuples regardless of the input types.
        as_floats = [(float(px), float(py)) for px, py in img_points]
        self.polylines.append(as_floats)
        self.stroke_meta.append({"color": QColor(color), "width": int(width)})
        self.polyline_annotation_ids.append(annotation_id)

        self._redraw_annotations()
|
||||
|
||||
def _redraw_annotations(self):
|
||||
"""Redraw all stored polylines and (optionally) bounding boxes onto the annotation pixmap."""
|
||||
if self.annotation_pixmap is None:
|
||||
return
|
||||
|
||||
# Clear existing overlay
|
||||
self.annotation_pixmap.fill(Qt.transparent)
|
||||
|
||||
painter = QPainter(self.annotation_pixmap)
|
||||
|
||||
# Draw polylines
|
||||
for idx, (polyline, meta) in enumerate(zip(self.polylines, self.stroke_meta)):
|
||||
pen_color: QColor = meta.get("color", self.polyline_pen_color)
|
||||
width: int = meta.get("width", self.polyline_pen_width)
|
||||
|
||||
if idx in self.selected_polyline_indices:
|
||||
# Highlight selected polylines in a distinct color / width
|
||||
highlight_color = QColor(255, 255, 0, 200) # yellow, semi-opaque
|
||||
pen = QPen(
|
||||
highlight_color,
|
||||
width + 1,
|
||||
Qt.SolidLine,
|
||||
Qt.RoundCap,
|
||||
Qt.RoundJoin,
|
||||
)
|
||||
else:
|
||||
pen = QPen(
|
||||
pen_color,
|
||||
width,
|
||||
Qt.SolidLine,
|
||||
Qt.RoundCap,
|
||||
Qt.RoundJoin,
|
||||
)
|
||||
|
||||
painter.setPen(pen)
|
||||
# Use QPolygonF for efficient polygon rendering (single call vs N-1 calls)
|
||||
# drawPolygon() automatically closes the shape, ensuring proper visual closure
|
||||
polygon = QPolygonF([QPointF(x, y) for x, y in polyline])
|
||||
painter.drawPolygon(polygon)
|
||||
|
||||
# Draw bounding boxes (dashed) if enabled
|
||||
if self.show_bboxes and self.original_pixmap is not None and self.bboxes:
|
||||
img_width = float(self.original_pixmap.width())
|
||||
img_height = float(self.original_pixmap.height())
|
||||
|
||||
for bbox, meta in zip(self.bboxes, self.bbox_meta):
|
||||
if len(bbox) != 4:
|
||||
continue
|
||||
|
||||
x_min_norm, y_min_norm, x_max_norm, y_max_norm = bbox
|
||||
x_min = int(x_min_norm * img_width)
|
||||
y_min = int(y_min_norm * img_height)
|
||||
x_max = int(x_max_norm * img_width)
|
||||
y_max = int(y_max_norm * img_height)
|
||||
|
||||
rect_width = x_max - x_min
|
||||
rect_height = y_max - y_min
|
||||
|
||||
pen_color: QColor = meta.get("color", QColor(255, 0, 0, 128))
|
||||
width: int = meta.get("width", self.polyline_pen_width)
|
||||
pen = QPen(
|
||||
pen_color,
|
||||
width,
|
||||
Qt.DashLine,
|
||||
Qt.SquareCap,
|
||||
Qt.MiterJoin,
|
||||
)
|
||||
painter.setPen(pen)
|
||||
painter.drawRect(x_min, y_min, rect_width, rect_height)
|
||||
|
||||
label_text = meta.get("label")
|
||||
if label_text:
|
||||
painter.save()
|
||||
font = painter.font()
|
||||
font.setPointSizeF(max(10.0, width + 4))
|
||||
painter.setFont(font)
|
||||
metrics = painter.fontMetrics()
|
||||
text_width = metrics.horizontalAdvance(label_text)
|
||||
text_height = metrics.height()
|
||||
padding = 4
|
||||
bg_width = text_width + padding * 2
|
||||
bg_height = text_height + padding * 2
|
||||
canvas_width = self.original_pixmap.width()
|
||||
canvas_height = self.original_pixmap.height()
|
||||
bg_x = max(0, min(x_min, canvas_width - bg_width))
|
||||
bg_y = y_min - bg_height
|
||||
if bg_y < 0:
|
||||
bg_y = min(y_min, canvas_height - bg_height)
|
||||
bg_y = max(0, bg_y)
|
||||
background_rect = QRect(bg_x, bg_y, bg_width, bg_height)
|
||||
background_color = QColor(pen_color)
|
||||
background_color.setAlpha(220)
|
||||
painter.fillRect(background_rect, background_color)
|
||||
text_color = QColor(0, 0, 0)
|
||||
if background_color.lightness() < 128:
|
||||
text_color = QColor(255, 255, 255)
|
||||
painter.setPen(text_color)
|
||||
painter.drawText(
|
||||
background_rect.adjusted(padding, padding, -padding, -padding),
|
||||
Qt.AlignLeft | Qt.AlignVCenter,
|
||||
label_text,
|
||||
)
|
||||
painter.restore()
|
||||
|
||||
painter.end()
|
||||
|
||||
self._update_display()
|
||||
|
||||
def mousePressEvent(self, event: QMouseEvent):
    """
    Handle mouse press events for drawing and selecting polylines.

    Left button with the polyline tool enabled starts a new stroke. Left
    button with the tool disabled selects the polyline under the cursor
    (Shift adds to the selection) or clears the selection when empty space
    inside the image is clicked; ``annotation_selected`` is emitted with the
    list of selected annotation ids, or ``None`` when there are none. All
    other presses are forwarded to the base class.
    """
    if self.annotation_pixmap is None or event.button() != Qt.LeftButton:
        super().mousePressEvent(event)
        return

    # Map click to image coordinates. globalPosition() replaces the
    # globalPos() accessor that is deprecated in Qt 6.
    label_pos = self.canvas_label.mapFromGlobal(event.globalPosition().toPoint())
    img_coords = self._canvas_to_image_coords(label_pos)

    # Drawing tool enabled -> start a new stroke
    if self.polyline_enabled:
        if img_coords:
            self.is_drawing = True
            self.current_stroke = [(float(img_coords[0]), float(img_coords[1]))]
        return

    # Drawing tool disabled -> attempt selection of existing polyline
    if img_coords:
        idx = self._find_polyline_at(float(img_coords[0]), float(img_coords[1]))
        if idx is not None:
            if event.modifiers() & Qt.ShiftModifier:
                # Multi-select mode: add to current selection (if not already selected)
                if idx not in self.selected_polyline_indices:
                    self.selected_polyline_indices.append(idx)
            else:
                # Single-select mode: replace current selection
                self.selected_polyline_indices = [idx]

            # Collect DB-backed annotation ids of the selection (unsaved
            # polylines carry None and are skipped).
            selected_ids: List[int] = []
            for sel_idx in self.selected_polyline_indices:
                if 0 <= sel_idx < len(self.polyline_annotation_ids):
                    ann_id = self.polyline_annotation_ids[sel_idx]
                    if isinstance(ann_id, int):
                        selected_ids.append(ann_id)

            # None signals "no valid DB-backed annotations in selection".
            self.annotation_selected.emit(selected_ids or None)
        else:
            # Clicked on empty space -> clear selection
            self.selected_polyline_indices = []
            self.annotation_selected.emit(None)

    self._redraw_annotations()
|
||||
|
||||
def mouseMoveEvent(self, event: QMouseEvent):
    """
    Handle mouse move events for drawing.

    While a stroke is in progress, samples a new point once the cursor has
    moved at least ``sample_threshold`` image pixels, paints the new segment
    onto the overlay for immediate feedback and refreshes the display.
    Moves outside an active stroke are forwarded to the base class.
    """
    if (
        not self.is_drawing
        or not self.polyline_enabled
        or self.annotation_pixmap is None
    ):
        super().mouseMoveEvent(event)
        return

    # Get accurate position using global coordinates. globalPosition()
    # replaces the globalPos() accessor that is deprecated in Qt 6.
    label_pos = self.canvas_label.mapFromGlobal(event.globalPosition().toPoint())
    img_coords = self._canvas_to_image_coords(label_pos)

    if not img_coords or not self.current_stroke:
        return

    last_point = self.current_stroke[-1]
    dx = img_coords[0] - last_point[0]
    dy = img_coords[1] - last_point[1]

    # Only sample a new point if we moved enough pixels
    if math.hypot(dx, dy) < self.sample_threshold:
        return

    # Draw line from last point to current point for interactive feedback
    painter = QPainter(self.annotation_pixmap)
    painter.setPen(
        QPen(
            self.polyline_pen_color,
            self.polyline_pen_width,
            Qt.SolidLine,
            Qt.RoundCap,
            Qt.RoundJoin,
        )
    )
    painter.drawLine(
        int(last_point[0]),
        int(last_point[1]),
        int(img_coords[0]),
        int(img_coords[1]),
    )
    painter.end()

    self.current_stroke.append((float(img_coords[0]), float(img_coords[1])))
    self._update_display()
|
||||
|
||||
def mouseReleaseEvent(self, event: QMouseEvent):
    """
    Finish the stroke in progress when the left button is released.

    The stroke is closed (end joined to start), optionally simplified with
    RDP, stored as a polyline, recorded in ``all_strokes`` in normalized
    coordinates and announced through ``annotation_drawn``. Releases that do
    not end a stroke are forwarded to the base class.
    """
    finishing_stroke = self.is_drawing and event.button() == Qt.LeftButton
    if not finishing_stroke:
        super().mouseReleaseEvent(event)
        return

    self.is_drawing = False

    if len(self.current_stroke) > 1 and self.original_pixmap is not None:
        # Close the outline by joining the end back to the start
        outline = list(self.current_stroke)
        if outline[-1] != outline[0]:
            outline.append(outline[0])

        # Optional RDP simplification (in image pixel space)
        simplified = (
            simplify_polyline(outline, self.simplify_epsilon)
            if self.simplify_on_finish
            else outline
        )

        if len(simplified) >= 2:
            pen_color = self.polyline_pen_color
            pen_width = self.polyline_pen_width

            # Store polyline and redraw all annotations
            self._add_polyline(simplified, pen_color, pen_width)

            # Normalized copy for the legacy stroke list and the signal
            normalized_stroke = [
                self._image_to_normalized_coords(int(px), int(py))
                for (px, py) in simplified
            ]
            stroke_record = {
                "points": normalized_stroke,
                "color": pen_color.name(),
                "alpha": pen_color.alpha(),
                "width": pen_width,
            }
            self.all_strokes.append(stroke_record)

            self.annotation_drawn.emit(normalized_stroke)
            logger.debug(
                f"Completed stroke with {len(simplified)} points "
                f"(normalized len={len(normalized_stroke)})"
            )

    self.current_stroke = []
|
||||
|
||||
def get_all_strokes(self) -> List[dict]:
    """Return the stroke records (the internal list itself, not a copy)."""
    strokes = self.all_strokes
    return strokes
|
||||
|
||||
def get_annotation_parameters(self) -> Optional[List[Dict[str, Any]]]:
    """
    Describe every stored polyline by its bounding box and normalized outline.

    Returns:
        ``None`` when no image is loaded or no usable polyline exists,
        otherwise a list of dictionaries, each containing:
        - 'bbox': [x_min, y_min, x_max, y_max] in normalized image coordinates
        - 'polyline': List of [y_norm, x_norm] points describing the polygon
    """
    if self.original_pixmap is None or not self.polylines:
        return None

    img_width = float(self.original_pixmap.width())
    img_height = float(self.original_pixmap.height())

    params: List[Dict[str, Any]] = []
    for idx, polyline in enumerate(self.polylines):
        if len(polyline) < 2:
            continue

        xs = [pt[0] for pt in polyline]
        ys = [pt[1] for pt in polyline]
        bbox = [
            min(xs) / img_width,
            min(ys) / img_height,
            max(xs) / img_width,
            max(ys) / img_height,
        ]

        # Points are emitted as [y_norm, x_norm], the order that
        # draw_saved_polyline() unpacks.
        outline = [[py / img_height, px / img_width] for (px, py) in polyline]

        logger.debug(
            f"Polyline {idx}: {len(polyline)} points, "
            f"bbox=({bbox[0]:.3f}, {bbox[1]:.3f})-({bbox[2]:.3f}, {bbox[3]:.3f})"
        )

        params.append({"bbox": bbox, "polyline": outline})

    return params or None
|
||||
|
||||
def draw_saved_polyline(
    self,
    polyline: List[List[float]],
    color: str,
    width: int = 3,
    annotation_id: Optional[int] = None,
):
    """
    Render a polyline loaded from the database on the annotation canvas.

    Args:
        polyline: Points as [y_norm, x_norm] pairs (row first, then column),
            each scaled by the loaded image's height/width to get pixels
        color: Color hex string (e.g., '#FF0000')
        width: Line width in pixels
        annotation_id: Optional identifier forwarded unchanged to
            _add_polyline()
    """
    image_missing = not self.annotation_pixmap or not self.original_pixmap
    if image_missing:
        logger.warning("Cannot draw polyline: no image loaded")
        return

    if len(polyline) < 2:
        logger.warning("Polyline has less than 2 points, cannot draw")
        return

    # Scale factors for turning normalized values back into pixels
    img_width = self.original_pixmap.width()
    img_height = self.original_pixmap.height()

    logger.debug(f"Loading polyline with {len(polyline)} points")
    logger.debug(f" Image size: {img_width}x{img_height}")
    logger.debug(f" First 3 normalized points from DB: {polyline[:3]}")

    # DB order is (row, col) == (y, x); the canvas wants (x, y) pixel tuples
    img_coords: List[Tuple[float, float]] = [
        (float(x_norm * img_width), float(y_norm * img_height))
        for y_norm, x_norm in polyline
    ]

    logger.debug(f" First 3 pixel coords: {img_coords[:3]}")

    # Hand the pixel-space points to the shared polyline pipeline using a
    # half-transparent pen
    translucent_pen = QColor(color)
    translucent_pen.setAlpha(128)
    self._add_polyline(
        img_coords, translucent_pen, width, annotation_id=annotation_id
    )

    # Keep a stroke record as well; it holds the normalized points as given
    stroke_record = {
        "points": polyline,
        "color": color,
        "alpha": 128,
        "width": width,
    }
    self.all_strokes.append(stroke_record)

    logger.debug(
        f"Drew saved polyline with {len(polyline)} points in color {color}"
    )
|
||||
|
||||
def draw_saved_bbox(
    self,
    bbox: List[float],
    color: str,
    width: int = 3,
    label: Optional[str] = None,
):
    """
    Register a bounding box loaded from the database and redraw the overlay.

    Args:
        bbox: [x_min_norm, y_min_norm, x_max_norm, y_max_norm] in normalized
            coordinates (0-1)
        color: Color hex string (e.g., '#FF0000')
        width: Line width in pixels
        label: Optional text label stored with the box style
    """
    image_missing = not self.annotation_pixmap or not self.original_pixmap
    if image_missing:
        logger.warning("Cannot draw bounding box: no image loaded")
        return

    if len(bbox) != 4:
        logger.warning(
            f"Invalid bounding box format: expected 4 values, got {len(bbox)}"
        )
        return

    img_width = self.original_pixmap.width()
    img_height = self.original_pixmap.height()

    x_min_norm, y_min_norm, x_max_norm, y_max_norm = bbox

    # Pixel values are computed for the debug output only; the box itself is
    # stored in normalized form below.
    x_min, x_max = int(x_min_norm * img_width), int(x_max_norm * img_width)
    y_min, y_max = int(y_min_norm * img_height), int(y_max_norm * img_height)

    logger.debug(f"Drawing bounding box: {bbox}")
    logger.debug(f" Image size: {img_width}x{img_height}")
    logger.debug(f" Pixel coords: ({x_min}, {y_min}) to ({x_max}, {y_max})")

    # Record the normalized box and its style; _redraw_annotations() does the
    # actual painting.
    translucent_pen = QColor(color)
    translucent_pen.setAlpha(128)

    normalized_box = [
        float(value)
        for value in (x_min_norm, y_min_norm, x_max_norm, y_max_norm)
    ]
    self.bboxes.append(normalized_box)

    style = {"color": translucent_pen, "width": int(width), "label": label}
    self.bbox_meta.append(style)

    # Mirror the entry in all_strokes alongside the polyline records
    stroke_record = {
        "bbox": bbox,
        "color": color,
        "alpha": 128,
        "width": width,
        "label": label,
    }
    self.all_strokes.append(stroke_record)

    self._redraw_annotations()
    logger.debug(f"Drew saved bounding box in color {color}")
|
||||
|
||||
def set_show_bboxes(self, show: bool):
    """
    Turn bounding-box rendering on or off and refresh the overlay.

    Args:
        show: Truthy to draw bounding boxes, falsy to hide them.
    """
    visible = bool(show)
    self.show_bboxes = visible
    logger.debug(f"Set show_bboxes to {self.show_bboxes}")

    # Repaint so the new visibility takes effect
    self._redraw_annotations()
|
||||
|
||||
def keyPressEvent(self, event: QKeyEvent):
    """
    Map zoom shortcuts to zoom actions.

    '+' or '=' zooms in, '-' zooms out, and '0' with exactly the Control
    modifier resets the zoom. Any other key is passed to the base class.
    """
    pressed = event.key()

    if pressed in (Qt.Key_Plus, Qt.Key_Equal):
        zoom_action = self.zoom_in
    elif pressed == Qt.Key_Minus:
        zoom_action = self.zoom_out
    elif pressed == Qt.Key_0 and event.modifiers() == Qt.ControlModifier:
        zoom_action = self.reset_zoom
    else:
        # Not a zoom shortcut: default handling
        super().keyPressEvent(event)
        return

    zoom_action()
    event.accept()
|
||||
|
||||
def eventFilter(self, obj, event: QEvent) -> bool:
    """
    Intercept mouse-wheel events and turn them into zoom steps.

    Returns:
        True when a wheel event was handled while an image is loaded;
        otherwise whatever the base-class filter returns.
    """
    # NOTE(review): the source's indentation was lost; this assumes the
    # original `return True` sits inside the "image loaded" branch, so wheel
    # events without an image fall through to the base class. Confirm.
    wheel_with_image = (
        event.type() == QEvent.Wheel and self.original_pixmap is not None
    )
    if not wheel_with_image:
        return super().eventFilter(obj, event)

    if event.angleDelta().y() > 0:
        # Scrolling up: zoom in unless that would exceed the maximum
        candidate = self.zoom_scale + self.zoom_wheel_step
        within_limits = candidate <= self.zoom_max
    else:
        # Scrolling down (or zero delta): zoom out unless below the minimum
        candidate = self.zoom_scale - self.zoom_wheel_step
        within_limits = candidate >= self.zoom_min

    if within_limits:
        self.zoom_scale = candidate
        self._apply_zoom()

    # The event is consumed even when the zoom limit blocked the change
    return True
|
||||
478
src/gui/widgets/annotation_tools_widget.py
Normal file
478
src/gui/widgets/annotation_tools_widget.py
Normal file
@@ -0,0 +1,478 @@
|
||||
"""
|
||||
Annotation tools widget for controlling annotation parameters.
|
||||
Includes polyline tool, color picker, class selection, and annotation management.
|
||||
"""
|
||||
|
||||
from PySide6.QtWidgets import (
|
||||
QWidget,
|
||||
QVBoxLayout,
|
||||
QHBoxLayout,
|
||||
QLabel,
|
||||
QGroupBox,
|
||||
QPushButton,
|
||||
QComboBox,
|
||||
QSpinBox,
|
||||
QDoubleSpinBox,
|
||||
QCheckBox,
|
||||
QColorDialog,
|
||||
QInputDialog,
|
||||
QMessageBox,
|
||||
)
|
||||
from PySide6.QtGui import QColor, QIcon, QPixmap, QPainter
|
||||
from PySide6.QtCore import Qt, Signal
|
||||
from typing import Optional, Dict
|
||||
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnnotationToolsWidget(QWidget):
    """
    Widget for annotation tool controls.

    Features:
    - Enable/disable polyline tool
    - Pen width control
    - Simplify-on-finish toggle and epsilon control
    - Object class selection
    - Add new object classes
    - Color selection for the selected object class
    - Show/hide bounding boxes
    - Clear annotations / delete the selected annotation

    Signals:
        polyline_enabled_changed: Emitted when polyline tool is enabled/disabled (bool)
        polyline_pen_color_changed: Emitted when polyline pen color changes (QColor)
        polyline_pen_width_changed: Emitted when polyline pen width changes (int)
        simplify_on_finish_changed: Emitted when "Simplify on finish" is toggled (bool)
        simplify_epsilon_changed: Emitted when the epsilon spin box changes (float)
        show_bboxes_changed: Emitted when "Show bounding boxes" is toggled (bool)
        class_selected: Emitted with the class dict when a class is selected,
            or with None when the placeholder entry is chosen
        class_color_changed: Emitted after the selected class color was updated
        clear_annotations_requested: Emitted when clearing was confirmed
        delete_selected_annotation_requested: Emitted when the delete button is pressed
    """

    polyline_enabled_changed = Signal(bool)
    polyline_pen_color_changed = Signal(QColor)
    polyline_pen_width_changed = Signal(int)
    simplify_on_finish_changed = Signal(bool)
    simplify_epsilon_changed = Signal(float)
    # Toggle visibility of bounding boxes on the canvas
    show_bboxes_changed = Signal(bool)
    class_selected = Signal(dict)
    class_color_changed = Signal()
    clear_annotations_requested = Signal()
    # Request deletion of the currently selected annotation on the canvas
    delete_selected_annotation_requested = Signal()

    def __init__(self, db_manager: DatabaseManager, parent=None):
        """
        Initialize annotation tools widget.

        Args:
            db_manager: Database manager instance
            parent: Parent widget
        """
        super().__init__(parent)
        self.db_manager = db_manager
        self.polyline_enabled = False
        self.current_color = QColor(255, 0, 0, 128)  # Red with 50% alpha
        self.current_class = None

        self._setup_ui()
        self._load_object_classes()

    def _setup_ui(self):
        """Build the three control groups: polyline tool, object class, actions."""
        layout = QVBoxLayout()

        # Polyline Tool Group
        polyline_group = QGroupBox("Polyline Tool")
        polyline_layout = QVBoxLayout()

        # Enable/Disable polyline tool
        button_layout = QHBoxLayout()
        self.polyline_toggle_btn = QPushButton("Start Drawing Polyline")
        self.polyline_toggle_btn.setCheckable(True)
        self.polyline_toggle_btn.clicked.connect(self._on_polyline_toggle)
        button_layout.addWidget(self.polyline_toggle_btn)
        polyline_layout.addLayout(button_layout)

        # Polyline pen width control
        width_layout = QHBoxLayout()
        width_layout.addWidget(QLabel("Pen Width:"))
        self.polyline_pen_width_spin = QSpinBox()
        self.polyline_pen_width_spin.setMinimum(1)
        self.polyline_pen_width_spin.setMaximum(20)
        self.polyline_pen_width_spin.setValue(3)
        self.polyline_pen_width_spin.valueChanged.connect(
            self._on_polyline_pen_width_changed
        )
        width_layout.addWidget(self.polyline_pen_width_spin)
        width_layout.addStretch()
        polyline_layout.addLayout(width_layout)

        # Simplification controls (RDP)
        simplify_layout = QHBoxLayout()
        self.simplify_checkbox = QCheckBox("Simplify on finish")
        self.simplify_checkbox.setChecked(True)
        self.simplify_checkbox.stateChanged.connect(self._on_simplify_toggle)
        simplify_layout.addWidget(self.simplify_checkbox)

        simplify_layout.addWidget(QLabel("epsilon (px):"))
        self.eps_spin = QDoubleSpinBox()
        self.eps_spin.setRange(0.0, 1000.0)
        self.eps_spin.setSingleStep(0.5)
        self.eps_spin.setValue(2.0)
        self.eps_spin.valueChanged.connect(self._on_eps_change)
        simplify_layout.addWidget(self.eps_spin)
        simplify_layout.addStretch()
        polyline_layout.addLayout(simplify_layout)

        polyline_group.setLayout(polyline_layout)
        layout.addWidget(polyline_group)

        # Object Class Group
        class_group = QGroupBox("Object Class")
        class_layout = QVBoxLayout()

        # Class selection dropdown
        self.class_combo = QComboBox()
        self.class_combo.currentIndexChanged.connect(self._on_class_selected)
        class_layout.addWidget(self.class_combo)

        # Add / manage classes
        class_button_layout = QHBoxLayout()
        self.add_class_btn = QPushButton("Add New Class")
        self.add_class_btn.clicked.connect(self._on_add_class)
        class_button_layout.addWidget(self.add_class_btn)

        self.refresh_classes_btn = QPushButton("Refresh")
        self.refresh_classes_btn.clicked.connect(self._load_object_classes)
        class_button_layout.addWidget(self.refresh_classes_btn)
        class_layout.addLayout(class_button_layout)

        # Class color (associated with selected object class)
        color_layout = QHBoxLayout()
        color_layout.addWidget(QLabel("Class Color:"))
        self.color_btn = QPushButton()
        self.color_btn.setFixedSize(40, 30)
        self.color_btn.clicked.connect(self._on_color_picker)
        self._update_color_button()
        color_layout.addWidget(self.color_btn)
        color_layout.addStretch()
        class_layout.addLayout(color_layout)

        # Selected class info
        self.class_info_label = QLabel("No class selected")
        self.class_info_label.setWordWrap(True)
        self.class_info_label.setStyleSheet(
            "QLabel { color: #888; font-style: italic; }"
        )
        class_layout.addWidget(self.class_info_label)

        class_group.setLayout(class_layout)
        layout.addWidget(class_group)

        # Actions Group
        actions_group = QGroupBox("Actions")
        actions_layout = QVBoxLayout()

        # Show / hide bounding boxes
        self.show_bboxes_checkbox = QCheckBox("Show bounding boxes")
        self.show_bboxes_checkbox.setChecked(True)
        self.show_bboxes_checkbox.stateChanged.connect(self._on_show_bboxes_toggle)
        actions_layout.addWidget(self.show_bboxes_checkbox)

        self.clear_btn = QPushButton("Clear All Annotations")
        self.clear_btn.clicked.connect(self._on_clear_annotations)
        actions_layout.addWidget(self.clear_btn)

        # Delete currently selected annotation (enabled when a selection exists)
        self.delete_selected_btn = QPushButton("Delete Selected Annotation")
        self.delete_selected_btn.clicked.connect(self._on_delete_selected_annotation)
        self.delete_selected_btn.setEnabled(False)
        actions_layout.addWidget(self.delete_selected_btn)

        actions_group.setLayout(actions_layout)
        layout.addWidget(actions_group)

        layout.addStretch()
        self.setLayout(layout)

    @staticmethod
    def _format_class_info(class_data: Dict) -> str:
        """Build the info-label text (name, color, optional description) for a class."""
        info_text = f"Class: {class_data['class_name']}\nColor: {class_data['color']}"
        if class_data.get("description"):
            info_text += f"\nDescription: {class_data['description']}"
        return info_text

    def _update_color_button(self):
        """Update the color button appearance with current color."""
        pixmap = QPixmap(40, 30)
        pixmap.fill(self.current_color)

        # Add border
        painter = QPainter(pixmap)
        painter.setPen(Qt.black)
        painter.drawRect(0, 0, pixmap.width() - 1, pixmap.height() - 1)
        painter.end()

        self.color_btn.setIcon(QIcon(pixmap))
        self.color_btn.setStyleSheet(f"background-color: {self.current_color.name()};")

    def _load_object_classes(self):
        """Load object classes from database and populate combo box."""
        try:
            classes = self.db_manager.get_object_classes()

            # Clear and repopulate combo box; the first entry is a placeholder
            # whose item data is None.
            self.class_combo.clear()
            self.class_combo.addItem("-- Select Class / Show All --", None)

            for cls in classes:
                self.class_combo.addItem(cls["class_name"], cls)

            logger.debug(f"Loaded {len(classes)} object classes")

        except Exception as e:
            logger.error(f"Error loading object classes: {e}")
            QMessageBox.warning(
                self, "Error", f"Failed to load object classes:\n{str(e)}"
            )

    def _on_polyline_toggle(self, checked: bool):
        """Handle polyline tool enable/disable."""
        self.polyline_enabled = checked

        if checked:
            self.polyline_toggle_btn.setText("Stop Drawing Polyline")
            self.polyline_toggle_btn.setStyleSheet(
                "QPushButton { background-color: #4CAF50; }"
            )
        else:
            self.polyline_toggle_btn.setText("Start Drawing Polyline")
            self.polyline_toggle_btn.setStyleSheet("")

        self.polyline_enabled_changed.emit(self.polyline_enabled)
        logger.debug(f"Polyline tool {'enabled' if checked else 'disabled'}")

    def _on_polyline_pen_width_changed(self, width: int):
        """Handle polyline pen width changes."""
        self.polyline_pen_width_changed.emit(width)
        logger.debug(f"Polyline pen width changed to {width}")

    def _on_simplify_toggle(self, state: int):
        """Handle simplify-on-finish checkbox toggle."""
        enabled = bool(state)
        self.simplify_on_finish_changed.emit(enabled)
        logger.debug(f"Simplify on finish set to {enabled}")

    def _on_eps_change(self, val: float):
        """Handle epsilon (RDP tolerance) value changes."""
        epsilon = float(val)
        self.simplify_epsilon_changed.emit(epsilon)
        logger.debug(f"Simplification epsilon changed to {epsilon}")

    def _on_show_bboxes_toggle(self, state: int):
        """Handle 'Show bounding boxes' checkbox toggle."""
        show = bool(state)
        self.show_bboxes_changed.emit(show)
        logger.debug(f"Show bounding boxes set to {show}")

    def _on_color_picker(self):
        """Open color picker dialog and update the selected object's class color."""
        if not self.current_class:
            QMessageBox.warning(
                self,
                "No Class Selected",
                "Please select an object class before changing its color.",
            )
            return

        # Use current class color (without alpha) as the base
        base_color = QColor(self.current_class.get("color", self.current_color.name()))
        color = QColorDialog.getColor(
            base_color,
            self,
            "Select Class Color",
            QColorDialog.ShowAlphaChannel,  # Allow alpha in UI, but store RGB in DB
        )

        if not color.isValid():
            # Dialog was cancelled
            return

        # Normalize to opaque RGB for storage
        new_color = QColor(color)
        new_color.setAlpha(255)
        hex_color = new_color.name()

        try:
            # Update in database
            self.db_manager.update_object_class(
                class_id=self.current_class["id"], color=hex_color
            )
        except Exception as e:
            logger.error(f"Failed to update class color in database: {e}")
            QMessageBox.critical(
                self,
                "Error",
                f"Failed to update class color in database:\n{str(e)}",
            )
            return

        # Update local class data and combo box item data
        self.current_class["color"] = hex_color
        current_index = self.class_combo.currentIndex()
        if current_index >= 0:
            self.class_combo.setItemData(current_index, dict(self.current_class))

        # Update info label text
        self.class_info_label.setText(self._format_class_info(self.current_class))

        # Use semi-transparent version for polyline pen / button preview
        class_color = QColor(hex_color)
        class_color.setAlpha(128)
        self.current_color = class_color
        self._update_color_button()
        self.polyline_pen_color_changed.emit(class_color)

        logger.debug(
            f"Updated class '{self.current_class['class_name']}' color to "
            f"{hex_color} (polyline pen alpha={class_color.alpha()})"
        )

        # Notify listeners (e.g., AnnotationTab) so they can reload/redraw
        self.class_color_changed.emit()

    def _on_class_selected(self, index: int):
        """Handle object class selection (including '-- Select Class --')."""
        class_data = self.class_combo.currentData()

        if class_data:
            self.current_class = class_data

            # Update info label
            self.class_info_label.setText(self._format_class_info(class_data))

            # Update polyline pen color to match class color with semi-transparency
            class_color = QColor(class_data["color"])
            if class_color.isValid():
                # Add 50% alpha for semi-transparency
                class_color.setAlpha(128)
                self.current_color = class_color
                self._update_color_button()
                self.polyline_pen_color_changed.emit(class_color)

            self.class_selected.emit(class_data)
            logger.debug(f"Selected class: {class_data['class_name']}")
        else:
            # "-- Select Class --" chosen: clear current class and show all annotations
            self.current_class = None
            self.class_info_label.setText("No class selected")
            self.class_selected.emit(None)
            logger.debug("Class selection cleared: showing annotations for all classes")

    def _on_add_class(self):
        """
        Prompt for a new object class and store it in the database.

        Asks for a name, rejects duplicates, asks for a color and an optional
        description, then reloads the combo box and selects the new class.
        Cancelling the name or color dialog aborts without changes.
        """
        # Get class name
        class_name, ok = QInputDialog.getText(
            self, "Add Object Class", "Enter class name:"
        )

        if not ok or not class_name.strip():
            return

        class_name = class_name.strip()

        # Check if class already exists; a failing lookup is reported like
        # the other database errors in this widget.
        try:
            existing = self.db_manager.get_object_class_by_name(class_name)
        except Exception as e:
            logger.error(f"Error adding object class: {e}")
            QMessageBox.critical(
                self, "Error", f"Failed to add object class:\n{str(e)}"
            )
            return

        if existing:
            QMessageBox.warning(
                self, "Class Exists", f"A class named '{class_name}' already exists."
            )
            return

        # Get color
        color = QColorDialog.getColor(self.current_color, self, "Select Class Color")

        if not color.isValid():
            return

        # Get optional description
        description, ok = QInputDialog.getText(
            self, "Class Description", "Enter class description (optional):"
        )

        # A cancelled dialog and a blank/whitespace-only entry both mean
        # "no description", stored as None rather than an empty string.
        description = description.strip() if ok and description else ""
        description = description or None

        # Add to database
        try:
            class_id = self.db_manager.add_object_class(
                class_name, color.name(), description
            )

            logger.info(f"Added new object class: {class_name} (ID: {class_id})")

            # Reload classes and select the new one
            self._load_object_classes()

            # Find and select the newly added class
            for i in range(self.class_combo.count()):
                class_data = self.class_combo.itemData(i)
                if class_data and class_data.get("id") == class_id:
                    self.class_combo.setCurrentIndex(i)
                    break

            QMessageBox.information(
                self, "Success", f"Class '{class_name}' added successfully!"
            )

        except Exception as e:
            logger.error(f"Error adding object class: {e}")
            QMessageBox.critical(
                self, "Error", f"Failed to add object class:\n{str(e)}"
            )

    def _on_clear_annotations(self):
        """Ask for confirmation, then request clearing of all annotations."""
        reply = QMessageBox.question(
            self,
            "Clear Annotations",
            "Are you sure you want to clear all annotations?",
            QMessageBox.Yes | QMessageBox.No,
            QMessageBox.No,
        )

        if reply == QMessageBox.Yes:
            self.clear_annotations_requested.emit()
            logger.debug("Clear annotations requested")

    def _on_delete_selected_annotation(self):
        """Handle delete selected annotation button."""
        self.delete_selected_annotation_requested.emit()
        logger.debug("Delete selected annotation requested")

    def set_has_selected_annotation(self, has_selection: bool):
        """
        Enable/disable actions that require a selected annotation.

        Args:
            has_selection: True if an annotation is currently selected on the canvas.
        """
        self.delete_selected_btn.setEnabled(bool(has_selection))

    def get_current_class(self) -> Optional[Dict]:
        """Get currently selected object class (None when no class is selected)."""
        return self.current_class

    def get_polyline_pen_color(self) -> QColor:
        """Get current polyline pen color."""
        return self.current_color

    def get_polyline_pen_width(self) -> int:
        """Get current polyline pen width."""
        return self.polyline_pen_width_spin.value()

    def is_polyline_enabled(self) -> bool:
        """Check if polyline tool is enabled."""
        return self.polyline_enabled
|
||||
@@ -137,7 +137,7 @@ class ImageDisplayWidget(QWidget):
|
||||
height,
|
||||
bytes_per_line,
|
||||
self.current_image.qtimage_format,
|
||||
)
|
||||
).copy() # Copy to ensure Qt owns its memory after this scope
|
||||
|
||||
# Convert to pixmap
|
||||
pixmap = QPixmap.fromImage(qimage)
|
||||
|
||||
@@ -5,12 +5,12 @@ Handles detection inference and result storage.
|
||||
|
||||
from typing import List, Dict, Optional, Callable
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from src.model.yolo_wrapper import YOLOWrapper
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.image import Image
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.file_utils import get_relative_path
|
||||
|
||||
@@ -42,6 +42,7 @@ class InferenceEngine:
|
||||
relative_path: str,
|
||||
conf: float = 0.25,
|
||||
save_to_db: bool = True,
|
||||
repository_root: Optional[str] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
Detect objects in a single image.
|
||||
@@ -51,49 +52,79 @@ class InferenceEngine:
|
||||
relative_path: Relative path from repository root
|
||||
conf: Confidence threshold
|
||||
save_to_db: Whether to save results to database
|
||||
repository_root: Base directory used to compute relative_path (if known)
|
||||
|
||||
Returns:
|
||||
Dictionary with detection results
|
||||
"""
|
||||
try:
|
||||
# Normalize storage path (fall back to absolute path when repo root is unknown)
|
||||
stored_relative_path = relative_path
|
||||
if not repository_root:
|
||||
stored_relative_path = str(Path(image_path).resolve())
|
||||
|
||||
# Get image dimensions
|
||||
img = Image.open(image_path)
|
||||
width, height = img.size
|
||||
img.close()
|
||||
img = Image(image_path)
|
||||
width = img.width
|
||||
height = img.height
|
||||
|
||||
# Perform detection
|
||||
detections = self.yolo.predict(image_path, conf=conf)
|
||||
|
||||
# Add/get image in database
|
||||
image_id = self.db_manager.get_or_create_image(
|
||||
relative_path=relative_path,
|
||||
relative_path=stored_relative_path,
|
||||
filename=Path(image_path).name,
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
|
||||
# Save detections to database
|
||||
if save_to_db and detections:
|
||||
detection_records = []
|
||||
for det in detections:
|
||||
# Use normalized bbox from detection
|
||||
bbox_normalized = det[
|
||||
"bbox_normalized"
|
||||
] # [x_min, y_min, x_max, y_max]
|
||||
inserted_count = 0
|
||||
deleted_count = 0
|
||||
|
||||
record = {
|
||||
"image_id": image_id,
|
||||
"model_id": self.model_id,
|
||||
"class_name": det["class_name"],
|
||||
"bbox": tuple(bbox_normalized),
|
||||
"confidence": det["confidence"],
|
||||
"segmentation_mask": det.get("segmentation_mask"),
|
||||
"metadata": {"class_id": det["class_id"]},
|
||||
}
|
||||
detection_records.append(record)
|
||||
# Save detections to database, replacing any previous results for this image/model
|
||||
if save_to_db:
|
||||
deleted_count = self.db_manager.delete_detections_for_image(
|
||||
image_id, self.model_id
|
||||
)
|
||||
if detections:
|
||||
detection_records = []
|
||||
for det in detections:
|
||||
# Use normalized bbox from detection
|
||||
bbox_normalized = det[
|
||||
"bbox_normalized"
|
||||
] # [x_min, y_min, x_max, y_max]
|
||||
|
||||
self.db_manager.add_detections_batch(detection_records)
|
||||
logger.info(f"Saved {len(detection_records)} detections to database")
|
||||
metadata = {
|
||||
"class_id": det["class_id"],
|
||||
"source_path": str(Path(image_path).resolve()),
|
||||
}
|
||||
if repository_root:
|
||||
metadata["repository_root"] = str(
|
||||
Path(repository_root).resolve()
|
||||
)
|
||||
|
||||
record = {
|
||||
"image_id": image_id,
|
||||
"model_id": self.model_id,
|
||||
"class_name": det["class_name"],
|
||||
"bbox": tuple(bbox_normalized),
|
||||
"confidence": det["confidence"],
|
||||
"segmentation_mask": det.get("segmentation_mask"),
|
||||
"metadata": metadata,
|
||||
}
|
||||
detection_records.append(record)
|
||||
|
||||
inserted_count = self.db_manager.add_detections_batch(
|
||||
detection_records
|
||||
)
|
||||
logger.info(
|
||||
f"Saved {inserted_count} detections to database (replaced {deleted_count})"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Detection run removed {deleted_count} stale entries but produced no new detections"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
@@ -142,7 +173,12 @@ class InferenceEngine:
|
||||
rel_path = get_relative_path(image_path, repository_root)
|
||||
|
||||
# Perform detection
|
||||
result = self.detect_single(image_path, rel_path, conf)
|
||||
result = self.detect_single(
|
||||
image_path,
|
||||
rel_path,
|
||||
conf=conf,
|
||||
repository_root=repository_root,
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
# Update progress
|
||||
|
||||
@@ -7,6 +7,9 @@ from ultralytics import YOLO
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Callable, Any
|
||||
import torch
|
||||
import tempfile
|
||||
import os
|
||||
from src.utils.image import Image
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
|
||||
@@ -55,6 +58,7 @@ class YOLOWrapper:
|
||||
save_dir: str = "data/models",
|
||||
name: str = "custom_model",
|
||||
resume: bool = False,
|
||||
callbacks: Optional[Dict[str, Callable]] = None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -69,13 +73,15 @@ class YOLOWrapper:
|
||||
save_dir: Directory to save trained model
|
||||
name: Name for the training run
|
||||
resume: Resume training from last checkpoint
|
||||
callbacks: Optional Ultralytics callback dictionary
|
||||
**kwargs: Additional training arguments
|
||||
|
||||
Returns:
|
||||
Dictionary with training results
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
if not self.load_model():
|
||||
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
||||
|
||||
try:
|
||||
logger.info(f"Starting training: {name}")
|
||||
@@ -117,7 +123,8 @@ class YOLOWrapper:
|
||||
Dictionary with validation metrics
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
if not self.load_model():
|
||||
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
||||
|
||||
try:
|
||||
logger.info(f"Starting validation on {split} split")
|
||||
@@ -158,12 +165,15 @@ class YOLOWrapper:
|
||||
List of detection dictionaries
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
if not self.load_model():
|
||||
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
||||
|
||||
prepared_source, cleanup_path = self._prepare_source(source)
|
||||
|
||||
try:
|
||||
logger.info(f"Running inference on {source}")
|
||||
results = self.model.predict(
|
||||
source=source,
|
||||
source=prepared_source,
|
||||
conf=conf,
|
||||
iou=iou,
|
||||
save=save,
|
||||
@@ -180,6 +190,14 @@ class YOLOWrapper:
|
||||
except Exception as e:
|
||||
logger.error(f"Error during inference: {e}")
|
||||
raise
|
||||
finally:
|
||||
if 0: # cleanup_path:
|
||||
try:
|
||||
os.remove(cleanup_path)
|
||||
except OSError as cleanup_error:
|
||||
logger.warning(
|
||||
f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
|
||||
)
|
||||
|
||||
def export(
|
||||
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
|
||||
@@ -196,7 +214,8 @@ class YOLOWrapper:
|
||||
Path to exported model
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
if not self.load_model():
|
||||
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
||||
|
||||
try:
|
||||
logger.info(f"Exporting model to {format} format")
|
||||
@@ -208,6 +227,38 @@ class YOLOWrapper:
|
||||
logger.error(f"Error exporting model: {e}")
|
||||
raise
|
||||
|
||||
def _prepare_source(self, source):
|
||||
"""Convert single-channel images to RGB temporarily for inference."""
|
||||
cleanup_path = None
|
||||
|
||||
if isinstance(source, (str, Path)):
|
||||
source_path = Path(source)
|
||||
if source_path.is_file():
|
||||
try:
|
||||
img_obj = Image(source_path)
|
||||
pil_img = img_obj.pil_image
|
||||
if len(pil_img.getbands()) == 1:
|
||||
rgb_img = img_obj.convert_grayscale_to_rgb_preserve_range()
|
||||
else:
|
||||
rgb_img = pil_img.convert("RGB")
|
||||
|
||||
suffix = source_path.suffix or ".png"
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
rgb_img.save(tmp_path)
|
||||
cleanup_path = tmp_path
|
||||
logger.info(
|
||||
f"Converted image {source_path} to RGB for inference at {tmp_path}"
|
||||
)
|
||||
return tmp_path, cleanup_path
|
||||
except Exception as convert_error:
|
||||
logger.warning(
|
||||
f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"
|
||||
)
|
||||
|
||||
return source, cleanup_path
|
||||
|
||||
def _format_training_results(self, results) -> Dict[str, Any]:
|
||||
"""Format training results into dictionary."""
|
||||
try:
|
||||
|
||||
@@ -7,6 +7,7 @@ import yaml
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.image import Image
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -46,18 +47,15 @@ class ConfigManager:
|
||||
"database": {"path": "data/detections.db"},
|
||||
"image_repository": {
|
||||
"base_path": "",
|
||||
"allowed_extensions": [
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".png",
|
||||
".tif",
|
||||
".tiff",
|
||||
".bmp",
|
||||
],
|
||||
"allowed_extensions": Image.SUPPORTED_EXTENSIONS,
|
||||
},
|
||||
"models": {
|
||||
"default_base_model": "yolov8s-seg.pt",
|
||||
"models_directory": "data/models",
|
||||
"base_model_choices": [
|
||||
"yolov8s-seg.pt",
|
||||
"yolov11s-seg.pt",
|
||||
],
|
||||
},
|
||||
"training": {
|
||||
"default_epochs": 100,
|
||||
@@ -65,6 +63,20 @@ class ConfigManager:
|
||||
"default_imgsz": 640,
|
||||
"default_patience": 50,
|
||||
"default_lr0": 0.01,
|
||||
"two_stage": {
|
||||
"enabled": False,
|
||||
"stage1": {
|
||||
"epochs": 20,
|
||||
"lr0": 0.0005,
|
||||
"patience": 10,
|
||||
"freeze": 10,
|
||||
},
|
||||
"stage2": {
|
||||
"epochs": 150,
|
||||
"lr0": 0.0003,
|
||||
"patience": 30,
|
||||
},
|
||||
},
|
||||
},
|
||||
"detection": {
|
||||
"default_confidence": 0.25,
|
||||
@@ -214,5 +226,5 @@ class ConfigManager:
|
||||
def get_allowed_extensions(self) -> list:
|
||||
"""Get list of allowed image file extensions."""
|
||||
return self.get(
|
||||
"image_repository.allowed_extensions", [".jpg", ".jpeg", ".png"]
|
||||
"image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS
|
||||
)
|
||||
|
||||
@@ -28,7 +28,9 @@ def get_image_files(
|
||||
List of absolute paths to image files
|
||||
"""
|
||||
if allowed_extensions is None:
|
||||
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
|
||||
from src.utils.image import Image
|
||||
|
||||
allowed_extensions = Image.SUPPORTED_EXTENSIONS
|
||||
|
||||
# Normalize extensions to lowercase
|
||||
allowed_extensions = [ext.lower() for ext in allowed_extensions]
|
||||
@@ -204,7 +206,9 @@ def is_image_file(
|
||||
True if file is an image
|
||||
"""
|
||||
if allowed_extensions is None:
|
||||
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
|
||||
from src.utils.image import Image
|
||||
|
||||
allowed_extensions = Image.SUPPORTED_EXTENSIONS
|
||||
|
||||
extension = Path(file_path).suffix.lower()
|
||||
return extension in [ext.lower() for ext in allowed_extensions]
|
||||
|
||||
@@ -277,6 +277,38 @@ class Image:
|
||||
"""
|
||||
return self._channels >= 3
|
||||
|
||||
def convert_grayscale_to_rgb_preserve_range(
    self,
) -> PILImage.Image:
    """Convert a single-channel image to 8-bit RGB without per-image stretching.

    Integer data is divided by the largest value its dtype can represent
    (e.g. 65535 for uint16) rather than by the image's own maximum.
    Floating-point data is divided by its largest finite value, or left
    unscaled when that value is below 1.0. After scaling, values are clipped
    to 0..1, so negative inputs become 0. Non-finite float values are mapped
    into range first (NaN and -inf to 0, +inf to the maximum).

    Returns:
        PIL Image in RGB mode with the grey value replicated in all three
        channels. When the image reports exactly 3 channels, ``self.pil_image``
        is returned unchanged.
    """
    if self._channels == 3:
        return self.pil_image

    grayscale = self.data
    if grayscale.ndim == 3:
        # Array still carries a channel axis: only the first plane is used.
        grayscale = grayscale[:, :, 0]

    original_dtype = grayscale.dtype
    grayscale = grayscale.astype(np.float32)

    if grayscale.size == 0:
        # NOTE(review): PIL expects a (width, height) tuple here -- confirm
        # that self.shape has that form for empty images.
        return PILImage.new("RGB", self.shape, color=(0, 0, 0))

    if np.issubdtype(original_dtype, np.integer):
        denom = float(max(np.iinfo(original_dtype).max, 1))
    else:
        # Ignore NaN/inf when choosing the scale; a NaN maximum would turn
        # the whole image into NaN and make the uint8 cast undefined.
        finite_values = grayscale[np.isfinite(grayscale)]
        max_val = float(finite_values.max()) if finite_values.size else 0.0
        denom = max(max_val, 1.0)
        grayscale = np.nan_to_num(grayscale, nan=0.0, posinf=denom, neginf=0.0)

    grayscale = np.clip(grayscale / denom, 0.0, 1.0)
    grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8)
    rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
    # The mode is inferred from the (H, W, 3) uint8 array; the explicit
    # ``mode`` argument is deprecated in recent Pillow releases.
    return PILImage.fromarray(rgb_arr)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation of the Image object."""
|
||||
return (
|
||||
|
||||
160
src/utils/image_converters.py
Normal file
160
src/utils/image_converters.py
Normal file
@@ -0,0 +1,160 @@
|
||||
import numpy as np
|
||||
|
||||
from roifile import ImagejRoi
|
||||
from tifffile import TiffFile, TiffWriter
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class UT:
    """Operetta TIFF series together with optional ROIs drawn in ImageJ.

    All ``<stem>*.tif*`` files next to the ROI file are loaded into a single
    array indexed as (channel, plane, row, column). The array and the ROIs
    can then be exported as an image file and a text label file.
    """

    def __init__(self, roifile_fn: Path, no_labels: bool):
        """Load the ROIs (if any) and the image series they belong to.

        Args:
            roifile_fn: With labels, the ImageJ ROI file; the series stem is
                the part of its file stem after ``"Roi-"``. Without labels,
                a path whose last component is the series stem; the images
                are searched inside that path.
            no_labels: Despite its name, ``True`` means an ROI file IS
                present and gets loaded, ``False`` means images only. The
                command-line entry point passes the value of its
                ``store_false`` ``--no-labels`` flag straight through, which
                is where the inverted meaning comes from.
        """
        self.roifile_fn = roifile_fn
        print("is file", self.roifile_fn.is_file())
        self.rois = None
        if no_labels:
            rois = ImagejRoi.fromfile(self.roifile_fn)
            # fromfile() may return a single ROI instead of a list.
            self.rois = [rois] if isinstance(rois, ImagejRoi) else rois
            self.stem = self.roifile_fn.stem.split("Roi-")[1]
        else:
            self.roifile_fn = roifile_fn / roifile_fn.parts[-1]
            self.stem = self.roifile_fn.stem

        print(self.roifile_fn)
        print(self.stem)
        self.image, self.image_props = self._load_images()

    @staticmethod
    def _parse_indices(suffix: str):
        """Return ``(channel, plane, tile)`` parsed from ``p<P>-ch<C>t<T>``."""
        channel = int(suffix.split("-ch")[-1].split("t")[0])
        plane = int(suffix.split("-")[0].split("p")[1])
        tile = int(suffix.split("t")[1])
        return channel, plane, tile

    def _load_images(self):
        """Load the TIFF series into one (channel, plane, row, column) array.

        Files are matched with the lower-cased stem. The remainder of each
        file stem is parsed for its 1-based channel, plane and tile indices.

        Returns:
            Tuple ``(image_stack, image_props)``.

        Raises:
            FileNotFoundError: If no TIFF file matches the stem.
        """
        print("Loading images:", self.roifile_fn.parent, self.stem)
        pattern = f"{self.stem.lower()}*.tif*"
        fns = list(self.roifile_fn.parent.glob(pattern))
        if not fns:
            raise FileNotFoundError(
                f"No TIFF files matching '{pattern}' in {self.roifile_fn.parent}"
            )

        # Strip the stem by length rather than by value: the glob used the
        # lower-cased stem, so the file name may differ from it in case.
        prefix_len = len(self.stem)
        entries = [(fn, *self._parse_indices(fn.stem[prefix_len:])) for fn in fns]
        n_ch = len({ch for _, ch, _, _ in entries})
        n_p = len({p for _, _, p, _ in entries})
        n_t = len({t for _, _, _, t in entries})

        with TiffFile(fns[0]) as tif:
            first = tif.asarray()
        # Image arrays are (rows, columns), i.e. (height, width).
        h, w = first.shape
        dtype = first.dtype
        self.image_props = {
            "channels": n_ch,
            "planes": n_p,
            "tiles": n_t,
            "width": w,
            "height": h,
            "dtype": dtype,
        }
        print("Image props", self.image_props)

        image_stack = np.zeros((n_ch, n_p, h, w), dtype=dtype)
        for fn, ch, p, t in entries:
            with TiffFile(fn) as tif:
                img = tif.asarray()
            print(fn.stem, "ch", ch, "p", p, "t", t)
            # The tile index is not an axis of the stack: files that differ
            # only in their tile number overwrite each other.
            image_stack[ch - 1, p - 1] = img

        print(image_stack.shape)

        return image_stack, self.image_props

    @property
    def width(self):
        """Image width in pixels (number of columns)."""
        return self.image_props["width"]

    @property
    def height(self):
        """Image height in pixels (number of rows)."""
        return self.image_props["height"]

    @property
    def nchannels(self):
        """Number of distinct channel indices found in the file names."""
        return self.image_props["channels"]

    @property
    def nplanes(self):
        """Number of distinct plane indices found in the file names."""
        return self.image_props["planes"]

    def export_rois(
        self,
        path: Path,
        subfolder: str = "labels",
        class_index: int = 0,
    ):
        """Write one label line per ROI to ``path/subfolder/<stem>.txt``.

        Each line is ``class xc yc w h x1 y1 ... xn yn``: the bounding box
        of the ROI followed by its outline, with x values divided by the
        image width and y values by the image height. ROIs without
        coordinates are reported and skipped. The output folder is created
        when missing.

        Raises:
            ValueError: If the instance was created without ROIs.
        """
        if self.rois is None:
            raise ValueError("No ROIs loaded; nothing to export")

        out_dir = path / subfolder
        out_dir.mkdir(parents=True, exist_ok=True)
        with open(out_dir / f"{self.stem}.txt", "w") as f:
            for i, roi in enumerate(self.rois):
                rc = roi.subpixel_coordinates
                if rc is None:
                    print(
                        f"No coordinates: {self.roifile_fn}, element {i}, out of {len(self.rois)}"
                    )
                    continue
                xmn, ymn = rc.min(axis=0)
                xmx, ymx = rc.max(axis=0)
                xc = (xmn + xmx) / 2
                yc = (ymn + ymx) / 2
                bw = xmx - xmn
                bh = ymx - ymn
                values = [
                    xc / self.width,
                    yc / self.height,
                    bw / self.width,
                    bh / self.height,
                ]
                for x, y in rc:
                    values.append(x / self.width)
                    values.append(y / self.height)
                coords = " ".join(f"{value}" for value in values)
                f.write(f"{class_index} {coords}\n")

    def export_image(
        self,
        path: Path,
        subfolder: str = "images",
        plane_mode: str = "max projection",
        channel: int = 0,
    ):
        """Write the image to ``path/subfolder/<stem>.tif``.

        With ``plane_mode="max projection"`` the maximum over all planes of
        the selected channel (0-based index) is written. Any other value
        writes the complete stack. The loaded stack in ``self.image`` is left
        untouched, so the method can be called repeatedly. The output folder
        is created when missing.
        """
        if plane_mode == "max projection":
            data = np.max(self.image[channel], axis=0)
            print(data.shape)
        else:
            data = self.image

        out_dir = path / subfolder
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{self.stem}.tif"
        print(out_path)
        with TiffWriter(out_path) as tif:
            tif.write(data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Export Operetta TIFF series (and ImageJ ROIs) as images and label files"
    )
    # Both paths are mandatory: without them the loop below fails on None.
    parser.add_argument(
        "-i",
        "--input",
        nargs="*",
        type=Path,
        required=True,
        help="Folders with ROI zip files, or image-series paths with --no-labels",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=Path,
        required=True,
        help="Destination folder for the exported images and labels",
    )
    # store_false: the stored value is True unless --no-labels is given, so
    # it is kept under a name that says what True means.
    parser.add_argument(
        "--no-labels",
        dest="has_labels",
        action="store_false",
        help="Source does not have labels, export only images",
    )
    args = parser.parse_args()

    for path in args.input:
        print("Path:", path)
        if not args.has_labels:
            print("No labels")
            ut = UT(path, args.has_labels)
            ut.export_image(args.output, plane_mode="max projection", channel=0)

        else:
            for rfn in Path(path).glob("*.zip"):
                print("Roi FN:", rfn)
                ut = UT(rfn, args.has_labels)
                ut.export_rois(args.output, class_index=0)
                ut.export_image(args.output, plane_mode="max projection", channel=0)

        print()
|
||||
184
tests/show_yolo_seg.py
Normal file
184
tests/show_yolo_seg.py
Normal file
@@ -0,0 +1,184 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
show_yolo_seg.py
|
||||
|
||||
Usage:
|
||||
python show_yolo_seg.py /path/to/image.jpg /path/to/labels.txt
|
||||
|
||||
Supports:
|
||||
- Segmentation polygons preceded by their bounding box: "class x_center y_center width height x1 y1 x2 y2 ... xn yn"
|
||||
- YOLO bbox lines as fallback: "class x_center y_center width height"
|
||||
Coordinates can be normalized [0..1] or absolute pixels (auto-detected).
|
||||
"""
|
||||
import sys
|
||||
import cv2
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import random
|
||||
|
||||
|
||||
def parse_label_line(line):
    """Split one label line into ``(class_id, coordinates)``.

    The first whitespace-separated token is the class id, which may be
    written as a float such as ``"1.0"``. All remaining tokens are returned
    as a list of floats. A line containing only whitespace gives ``None``.
    """
    tokens = line.split()
    if len(tokens) == 0:
        return None
    class_token, *value_tokens = tokens
    return int(float(class_token)), list(map(float, value_tokens))
|
||||
|
||||
|
||||
def coords_are_normalized(coords):
    """Guess whether the values are normalized rather than pixel coordinates.

    Only the largest value is inspected: it must not exceed 1.001, which
    allows a small tolerance above 1. An empty sequence counts as not
    normalized.
    """
    return bool(coords) and max(coords) <= 1.001
|
||||
|
||||
|
||||
def yolo_bbox_to_xyxy(coords, img_w, img_h):
    """Turn a centre/size box into integer corner coordinates.

    ``coords`` starts with ``[x_center, y_center, width, height]``. When no
    value in ``coords`` exceeds 1.001 the box is taken as normalized and is
    scaled by the image size first. The corners are rounded to the nearest
    integer and returned as ``(x1, y1, x2, y2)``.
    """
    cx, cy, box_w, box_h = coords[:4]
    if max(coords) <= 1.001:
        cx, box_w = cx * img_w, box_w * img_w
        cy, box_h = cy * img_h, box_h * img_h

    half_w = box_w / 2
    half_h = box_h / 2
    left = int(round(cx - half_w))
    top = int(round(cy - half_h))
    right = int(round(cx + half_w))
    bottom = int(round(cy + half_h))
    return left, top, right, bottom
|
||||
|
||||
|
||||
def poly_to_pts(coords, img_w, img_h):
    """Convert a flat ``[x1, y1, x2, y2, ...]`` list to an ``(N, 2)`` int32 array.

    The whole list is checked for being normalized. The caller has already
    removed the leading bounding-box values, so nothing is skipped here.
    Normalized x values are scaled by ``img_w`` and y values by ``img_h``.
    Coordinates are rounded to the nearest pixel, as the bounding-box
    conversion does, instead of being truncated.

    Args:
        coords: Polygon vertices only, normalized or in pixels.
        img_w: Image width in pixels.
        img_h: Image height in pixels.

    Returns:
        ``numpy.ndarray`` of shape ``(N, 2)`` and dtype ``int32``.
    """
    if coords_are_normalized(coords):
        coords = [
            value * (img_w if index % 2 == 0 else img_h)
            for index, value in enumerate(coords)
        ]
    pixels = np.rint(np.asarray(coords, dtype=np.float64))
    return pixels.astype(np.int32).reshape(-1, 2)
|
||||
|
||||
|
||||
def random_color_for_class(cls):
    """Return a deterministic pseudo-random colour for a class id.

    The result is a tuple of three ints in 0..255. A private generator
    seeded with the class id produces the same values as seeding the
    module-level generator would, but leaves the global ``random`` state
    untouched.
    """
    rng = random.Random(cls)
    return tuple(rng.randint(0, 255) for _ in range(3))
|
||||
|
||||
|
||||
def _draw_class_label(img, cls, origin):
    """Write the class id in white at ``origin`` (x, y) on ``img``."""
    cv2.putText(
        img,
        str(cls),
        origin,
        cv2.FONT_HERSHEY_SIMPLEX,
        0.6,
        (255, 255, 255),
        2,
        cv2.LINE_AA,
    )


def draw_annotations(img, labels, alpha=0.4, draw_bbox_for_poly=True):
    """Draw boxes and filled polygons for all labels onto ``img``.

    Args:
        img: BGR numpy image; it is modified in place.
        labels: Iterable of ``(class_id, coords)``. Four values are a
            centre/size bounding box. Six or more values are a bounding box
            followed by polygon vertices. Anything else, including a polygon
            with an odd number of values, is skipped as invalid.
        alpha: Weight of the filled-polygon overlay in the final blend.
        draw_bbox_for_poly: Draw the bounding box that precedes a polygon.

    Returns:
        The annotated image (the same array as ``img``).
    """
    overlay = img.copy()
    h, w = img.shape[:2]
    for cls, coords in labels:
        n_values = len(coords)

        if n_values == 4:
            # Plain YOLO bounding box.
            x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords, w, h)
            color = random_color_for_class(cls)
            cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
            _draw_class_label(img, cls, (x1, max(6, y1 - 4)))

        elif n_values >= 6 and n_values % 2 == 0:
            # Bounding box (4 values) followed by the polygon vertices.
            color = random_color_for_class(cls)

            if draw_bbox_for_poly:
                x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords[:4], w, h)
                cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)

            pts = poly_to_pts(coords[4:], w, h)
            # Fill on the overlay, outline on the base image.
            cv2.fillPoly(overlay, [pts], color)
            cv2.polylines(img, [pts], isClosed=True, color=color, thickness=2)
            # Class text goes slightly above the first vertex.
            x, y = int(pts[0, 0]), int(pts[0, 1]) - 6
            _draw_class_label(img, cls, (x, max(6, y)))

        else:
            # Empty, unknown or malformed entry: skip.
            continue

    # Blend the filled polygons into the image.
    cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img)
    return img
|
||||
|
||||
|
||||
def load_labels_file(label_path):
    """Read a label file and return the parsed entry of every non-blank line.

    Each entry is the ``(class_id, coords)`` tuple produced by
    ``parse_label_line``. Lines for which it returns nothing are dropped.
    """
    with open(label_path, "r") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        parsed_lines = [parse_label_line(text) for text in stripped_lines if text]
    return [entry for entry in parsed_lines if entry]
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: overlay a label file on an image and show it.

    Exits with status 1 when the image or label file does not exist or the
    image cannot be decoded. An empty label file only prints a notice; the
    image is still displayed.
    """
    cli = argparse.ArgumentParser(
        description="Show YOLO segmentation / polygon annotations"
    )
    cli.add_argument("image", type=str, help="Path to image file")
    cli.add_argument("labels", type=str, help="Path to YOLO label file (polygons)")
    cli.add_argument(
        "--alpha", type=float, default=0.4, help="Polygon fill alpha (0..1)"
    )
    cli.add_argument(
        "--no-bbox", action="store_true", help="Don't draw bounding boxes for polygons"
    )
    opts = cli.parse_args()

    img_path, lbl_path = Path(opts.image), Path(opts.labels)

    # Both inputs must exist; the image is checked first.
    for message, candidate in (
        ("Image not found:", img_path),
        ("Label file not found:", lbl_path),
    ):
        if not candidate.exists():
            print(message, candidate)
            sys.exit(1)

    img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
    if img is None:
        print("Could not load image:", img_path)
        sys.exit(1)

    labels = load_labels_file(str(lbl_path))
    if not labels:
        # Not fatal: the bare image is shown.
        print("No labels parsed from", lbl_path)

    annotated = draw_annotations(
        img.copy(),
        labels,
        alpha=opts.alpha,
        draw_bbox_for_poly=not opts.no_bbox,
    )

    # matplotlib expects RGB, OpenCV works in BGR.
    annotated_rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
    rows, cols = annotated.shape[0], annotated.shape[1]
    plt.figure(figsize=(10, 10 * rows / cols))
    plt.imshow(annotated_rgb)
    plt.axis("off")
    plt.title(f"{img_path.name} ({lbl_path.name})")
    plt.show()


if __name__ == "__main__":
    main()
|
||||
@@ -27,7 +27,7 @@ class TestImage:
|
||||
|
||||
def test_supported_extensions(self):
|
||||
"""Test that supported extensions are correctly defined."""
|
||||
expected_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
|
||||
expected_extensions = Image.SUPPORTED_EXTENSIONS
|
||||
assert Image.SUPPORTED_EXTENSIONS == expected_extensions
|
||||
|
||||
def test_image_properties(self, tmp_path):
|
||||
|
||||
Reference in New Issue
Block a user