Adding pen tool for annotation
This commit is contained in:
@@ -30,18 +30,48 @@ class DatabaseManager:
|
||||
# Create directory if it doesn't exist
|
||||
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Read schema file and execute
|
||||
schema_path = Path(__file__).parent / "schema.sql"
|
||||
with open(schema_path, "r") as f:
|
||||
schema_sql = f.read()
|
||||
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
# Check if annotations table needs migration
|
||||
self._migrate_annotations_table(conn)
|
||||
|
||||
# Read schema file and execute
|
||||
schema_path = Path(__file__).parent / "schema.sql"
|
||||
with open(schema_path, "r") as f:
|
||||
schema_sql = f.read()
|
||||
|
||||
conn.executescript(schema_sql)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _migrate_annotations_table(self, conn: sqlite3.Connection) -> None:
|
||||
"""
|
||||
Migrate annotations table from old schema (class_name) to new schema (class_id).
|
||||
"""
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check if annotations table exists
|
||||
cursor.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'"
|
||||
)
|
||||
if not cursor.fetchone():
|
||||
# Table doesn't exist yet, no migration needed
|
||||
return
|
||||
|
||||
# Check if table has old schema (class_name column)
|
||||
cursor.execute("PRAGMA table_info(annotations)")
|
||||
columns = {row[1]: row for row in cursor.fetchall()}
|
||||
|
||||
if "class_name" in columns and "class_id" not in columns:
|
||||
# Old schema detected, need to migrate
|
||||
print("Migrating annotations table to new schema with class_id...")
|
||||
|
||||
# Drop old annotations table (assuming no critical data since this is a new feature)
|
||||
cursor.execute("DROP TABLE IF EXISTS annotations")
|
||||
conn.commit()
|
||||
print("Old annotations table dropped, will be recreated with new schema")
|
||||
|
||||
def get_connection(self) -> sqlite3.Connection:
|
||||
"""Get database connection with proper settings."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
@@ -593,25 +623,38 @@ class DatabaseManager:
|
||||
def add_annotation(
|
||||
self,
|
||||
image_id: int,
|
||||
class_name: str,
|
||||
class_id: int,
|
||||
bbox: Tuple[float, float, float, float],
|
||||
annotator: str,
|
||||
segmentation_mask: Optional[List[List[float]]] = None,
|
||||
verified: bool = False,
|
||||
) -> int:
|
||||
"""Add manual annotation."""
|
||||
"""
|
||||
Add manual annotation.
|
||||
|
||||
Args:
|
||||
image_id: ID of the image
|
||||
class_id: ID of the object class (foreign key to object_classes)
|
||||
bbox: Bounding box coordinates (normalized 0-1)
|
||||
annotator: Name of person/tool creating annotation
|
||||
segmentation_mask: Polygon coordinates for segmentation
|
||||
verified: Whether annotation has been verified
|
||||
|
||||
Returns:
|
||||
ID of the inserted annotation
|
||||
"""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
x_min, y_min, x_max, y_max = bbox
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO annotations (image_id, class_name, x_min, y_min, x_max, y_max, segmentation_mask, annotator, verified)
|
||||
INSERT INTO annotations (image_id, class_id, x_min, y_min, x_max, y_max, segmentation_mask, annotator, verified)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
image_id,
|
||||
class_name,
|
||||
class_id,
|
||||
x_min,
|
||||
y_min,
|
||||
x_max,
|
||||
@@ -627,15 +670,178 @@ class DatabaseManager:
|
||||
conn.close()
|
||||
|
||||
def get_annotations_for_image(self, image_id: int) -> List[Dict]:
|
||||
"""Get all annotations for an image."""
|
||||
"""
|
||||
Get all annotations for an image with class information.
|
||||
|
||||
Args:
|
||||
image_id: ID of the image
|
||||
|
||||
Returns:
|
||||
List of annotation dictionaries with joined class information
|
||||
"""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM annotations WHERE image_id = ?", (image_id,))
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
a.*,
|
||||
c.class_name,
|
||||
c.color as class_color,
|
||||
c.description as class_description
|
||||
FROM annotations a
|
||||
JOIN object_classes c ON a.class_id = c.id
|
||||
WHERE a.image_id = ?
|
||||
ORDER BY a.created_at DESC
|
||||
""",
|
||||
(image_id,),
|
||||
)
|
||||
annotations = []
|
||||
for row in cursor.fetchall():
|
||||
ann = dict(row)
|
||||
if ann.get("segmentation_mask"):
|
||||
ann["segmentation_mask"] = json.loads(ann["segmentation_mask"])
|
||||
annotations.append(ann)
|
||||
return annotations
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# ==================== Object Class Operations ====================
|
||||
|
||||
def get_object_classes(self) -> List[Dict]:
|
||||
"""
|
||||
Get all object classes.
|
||||
|
||||
Returns:
|
||||
List of object class dictionaries
|
||||
"""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM object_classes ORDER BY class_name")
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_object_class_by_id(self, class_id: int) -> Optional[Dict]:
|
||||
"""Get object class by ID."""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM object_classes WHERE id = ?", (class_id,))
|
||||
row = cursor.fetchone()
|
||||
return dict(row) if row else None
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_object_class_by_name(self, class_name: str) -> Optional[Dict]:
|
||||
"""Get object class by name."""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT * FROM object_classes WHERE class_name = ?", (class_name,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return dict(row) if row else None
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def add_object_class(
|
||||
self, class_name: str, color: str, description: Optional[str] = None
|
||||
) -> int:
|
||||
"""
|
||||
Add a new object class.
|
||||
|
||||
Args:
|
||||
class_name: Name of the object class
|
||||
color: Hex color code (e.g., '#FF0000')
|
||||
description: Optional description
|
||||
|
||||
Returns:
|
||||
ID of the inserted object class
|
||||
"""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO object_classes (class_name, color, description)
|
||||
VALUES (?, ?, ?)
|
||||
""",
|
||||
(class_name, color, description),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.lastrowid
|
||||
except sqlite3.IntegrityError:
|
||||
# Class already exists
|
||||
existing = self.get_object_class_by_name(class_name)
|
||||
return existing["id"] if existing else None
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def update_object_class(
|
||||
self,
|
||||
class_id: int,
|
||||
class_name: Optional[str] = None,
|
||||
color: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Update an object class.
|
||||
|
||||
Args:
|
||||
class_id: ID of the class to update
|
||||
class_name: New class name (optional)
|
||||
color: New color (optional)
|
||||
description: New description (optional)
|
||||
|
||||
Returns:
|
||||
True if updated, False otherwise
|
||||
"""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
updates = {}
|
||||
if class_name is not None:
|
||||
updates["class_name"] = class_name
|
||||
if color is not None:
|
||||
updates["color"] = color
|
||||
if description is not None:
|
||||
updates["description"] = description
|
||||
|
||||
if not updates:
|
||||
return False
|
||||
|
||||
set_clauses = [f"{key} = ?" for key in updates.keys()]
|
||||
params = list(updates.values()) + [class_id]
|
||||
|
||||
query = f"UPDATE object_classes SET {', '.join(set_clauses)} WHERE id = ?"
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(query, params)
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def delete_object_class(self, class_id: int) -> bool:
|
||||
"""
|
||||
Delete an object class.
|
||||
|
||||
Args:
|
||||
class_id: ID of the class to delete
|
||||
|
||||
Returns:
|
||||
True if deleted, False otherwise
|
||||
"""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM object_classes WHERE id = ?", (class_id,))
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
@staticmethod
|
||||
def calculate_checksum(file_path: str) -> str:
|
||||
"""Calculate MD5 checksum of a file."""
|
||||
|
||||
@@ -44,11 +44,27 @@ CREATE TABLE IF NOT EXISTS detections (
|
||||
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Annotations table: stores manual annotations (future feature)
|
||||
-- Object classes table: stores annotation class definitions with colors
|
||||
CREATE TABLE IF NOT EXISTS object_classes (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
class_name TEXT NOT NULL UNIQUE,
|
||||
color TEXT NOT NULL, -- Hex color code (e.g., '#FF0000')
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
description TEXT
|
||||
);
|
||||
|
||||
-- Insert default object classes
|
||||
INSERT OR IGNORE INTO object_classes (class_name, color, description) VALUES
|
||||
('cell', '#FF0000', 'Cell object'),
|
||||
('nucleus', '#00FF00', 'Cell nucleus'),
|
||||
('mitochondria', '#0000FF', 'Mitochondria'),
|
||||
('vesicle', '#FFFF00', 'Vesicle');
|
||||
|
||||
-- Annotations table: stores manual annotations
|
||||
CREATE TABLE IF NOT EXISTS annotations (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
image_id INTEGER NOT NULL,
|
||||
class_name TEXT NOT NULL,
|
||||
class_id INTEGER NOT NULL,
|
||||
x_min REAL NOT NULL CHECK(x_min >= 0 AND x_min <= 1),
|
||||
y_min REAL NOT NULL CHECK(y_min >= 0 AND y_min <= 1),
|
||||
x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1),
|
||||
@@ -57,7 +73,8 @@ CREATE TABLE IF NOT EXISTS annotations (
|
||||
annotator TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
verified BOOLEAN DEFAULT 0,
|
||||
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE
|
||||
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (class_id) REFERENCES object_classes (id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Create indexes for performance optimization
|
||||
@@ -69,4 +86,6 @@ CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence);
|
||||
CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path);
|
||||
CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_annotations_class_id ON annotations(class_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_object_classes_class_name ON object_classes(class_name);
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Annotation tab for the microscopy object detection application.
|
||||
Future feature for manual annotation.
|
||||
Manual annotation with pen tool and object class management.
|
||||
"""
|
||||
|
||||
from PySide6.QtWidgets import (
|
||||
@@ -21,13 +21,13 @@ from src.database.db_manager import DatabaseManager
|
||||
from src.utils.config_manager import ConfigManager
|
||||
from src.utils.image import Image, ImageLoadError
|
||||
from src.utils.logger import get_logger
|
||||
from src.gui.widgets import ImageDisplayWidget
|
||||
from src.gui.widgets import AnnotationCanvasWidget, AnnotationToolsWidget
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnnotationTab(QWidget):
|
||||
"""Annotation tab placeholder (future feature)."""
|
||||
"""Annotation tab for manual image annotation."""
|
||||
|
||||
def __init__(
|
||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
||||
@@ -37,6 +37,7 @@ class AnnotationTab(QWidget):
|
||||
self.config_manager = config_manager
|
||||
self.current_image = None
|
||||
self.current_image_path = None
|
||||
self.current_image_id = None
|
||||
|
||||
self._setup_ui()
|
||||
|
||||
@@ -52,49 +53,52 @@ class AnnotationTab(QWidget):
|
||||
self.left_splitter = QSplitter(Qt.Vertical)
|
||||
self.left_splitter.setHandleWidth(10)
|
||||
|
||||
# Image display section
|
||||
display_group = QGroupBox("Image Display")
|
||||
display_layout = QVBoxLayout()
|
||||
# Annotation canvas section
|
||||
canvas_group = QGroupBox("Annotation Canvas")
|
||||
canvas_layout = QVBoxLayout()
|
||||
|
||||
# Use the reusable ImageDisplayWidget
|
||||
self.image_display_widget = ImageDisplayWidget()
|
||||
self.image_display_widget.zoom_changed.connect(self._on_zoom_changed)
|
||||
display_layout.addWidget(self.image_display_widget)
|
||||
# Use the AnnotationCanvasWidget
|
||||
self.annotation_canvas = AnnotationCanvasWidget()
|
||||
self.annotation_canvas.zoom_changed.connect(self._on_zoom_changed)
|
||||
self.annotation_canvas.annotation_drawn.connect(self._on_annotation_drawn)
|
||||
canvas_layout.addWidget(self.annotation_canvas)
|
||||
|
||||
display_group.setLayout(display_layout)
|
||||
self.left_splitter.addWidget(display_group)
|
||||
canvas_group.setLayout(canvas_layout)
|
||||
self.left_splitter.addWidget(canvas_group)
|
||||
|
||||
# Zoom controls info
|
||||
zoom_info = QLabel("Zoom: Mouse wheel or +/- keys to zoom in/out")
|
||||
zoom_info.setStyleSheet("QLabel { color: #888; font-style: italic; }")
|
||||
self.left_splitter.addWidget(zoom_info)
|
||||
# Controls info
|
||||
controls_info = QLabel(
|
||||
"Zoom: Mouse wheel or +/- keys | Drawing: Enable pen and drag mouse"
|
||||
)
|
||||
controls_info.setStyleSheet("QLabel { color: #888; font-style: italic; }")
|
||||
self.left_splitter.addWidget(controls_info)
|
||||
# }
|
||||
|
||||
# { Right splitter for annotation tools and controls
|
||||
self.right_splitter = QSplitter(Qt.Vertical)
|
||||
self.right_splitter.setHandleWidth(10)
|
||||
|
||||
# Annotation tools section
|
||||
self.annotation_tools = AnnotationToolsWidget(self.db_manager)
|
||||
self.annotation_tools.pen_enabled_changed.connect(
|
||||
self.annotation_canvas.set_pen_enabled
|
||||
)
|
||||
self.annotation_tools.pen_color_changed.connect(
|
||||
self.annotation_canvas.set_pen_color
|
||||
)
|
||||
self.annotation_tools.pen_width_changed.connect(
|
||||
self.annotation_canvas.set_pen_width
|
||||
)
|
||||
self.annotation_tools.class_selected.connect(self._on_class_selected)
|
||||
self.annotation_tools.clear_annotations_requested.connect(
|
||||
self._on_clear_annotations
|
||||
)
|
||||
self.right_splitter.addWidget(self.annotation_tools)
|
||||
|
||||
# Image loading section
|
||||
load_group = QGroupBox("Image Loading")
|
||||
load_layout = QVBoxLayout()
|
||||
|
||||
# Future features info
|
||||
info_group = QGroupBox("Annotation Tool (Future Feature)")
|
||||
info_layout = QVBoxLayout()
|
||||
info_label = QLabel(
|
||||
"Full annotation functionality will be implemented in future version.\n\n"
|
||||
"Planned Features:\n"
|
||||
"- Drawing tools for bounding boxes\n"
|
||||
"- Class label assignment\n"
|
||||
"- Export annotations to YOLO format\n"
|
||||
"- Annotation verification"
|
||||
)
|
||||
info_label.setWordWrap(True)
|
||||
info_layout.addWidget(info_label)
|
||||
info_group.setLayout(info_layout)
|
||||
|
||||
self.right_splitter.addWidget(info_group)
|
||||
|
||||
# Load image button
|
||||
button_layout = QHBoxLayout()
|
||||
self.load_image_btn = QPushButton("Load Image")
|
||||
@@ -158,13 +162,22 @@ class AnnotationTab(QWidget):
|
||||
"annotation_tab/last_directory", str(Path(file_path).parent)
|
||||
)
|
||||
|
||||
# Display image using the ImageDisplayWidget
|
||||
self.image_display_widget.load_image(self.current_image)
|
||||
# Get or create image in database
|
||||
relative_path = str(Path(file_path).name) # Simplified for now
|
||||
self.current_image_id = self.db_manager.get_or_create_image(
|
||||
relative_path,
|
||||
Path(file_path).name,
|
||||
self.current_image.width,
|
||||
self.current_image.height,
|
||||
)
|
||||
|
||||
# Display image using the AnnotationCanvasWidget
|
||||
self.annotation_canvas.load_image(self.current_image)
|
||||
|
||||
# Update info label
|
||||
self._update_image_info()
|
||||
|
||||
logger.info(f"Loaded image: {file_path}")
|
||||
logger.info(f"Loaded image: {file_path} (DB ID: {self.current_image_id})")
|
||||
|
||||
except ImageLoadError as e:
|
||||
logger.error(f"Failed to load image: {e}")
|
||||
@@ -181,7 +194,7 @@ class AnnotationTab(QWidget):
|
||||
self.image_info_label.setText("No image loaded")
|
||||
return
|
||||
|
||||
zoom_percentage = self.image_display_widget.get_zoom_percentage()
|
||||
zoom_percentage = self.annotation_canvas.get_zoom_percentage()
|
||||
info_text = (
|
||||
f"File: {Path(self.current_image_path).name}\n"
|
||||
f"Size: {self.current_image.width}x{self.current_image.height} pixels\n"
|
||||
@@ -194,9 +207,36 @@ class AnnotationTab(QWidget):
|
||||
self.image_info_label.setText(info_text)
|
||||
|
||||
def _on_zoom_changed(self, zoom_scale: float):
|
||||
"""Handle zoom level changes from the image display widget."""
|
||||
"""Handle zoom level changes from the annotation canvas."""
|
||||
self._update_image_info()
|
||||
|
||||
def _on_annotation_drawn(self, points: list):
|
||||
"""Handle when an annotation stroke is drawn."""
|
||||
current_class = self.annotation_tools.get_current_class()
|
||||
|
||||
if not current_class:
|
||||
logger.warning("Annotation drawn but no object class selected")
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"No Class Selected",
|
||||
"Please select an object class before drawing annotations.",
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Annotation drawn with {len(points)} points for class: {current_class['class_name']}"
|
||||
)
|
||||
# Future: Save annotation to database or export
|
||||
|
||||
def _on_class_selected(self, class_data: dict):
|
||||
"""Handle when an object class is selected."""
|
||||
logger.debug(f"Object class selected: {class_data['class_name']}")
|
||||
|
||||
def _on_clear_annotations(self):
|
||||
"""Handle clearing all annotations."""
|
||||
self.annotation_canvas.clear_annotations()
|
||||
logger.info("Cleared all annotations")
|
||||
|
||||
def _restore_state(self):
|
||||
"""Restore splitter positions from settings."""
|
||||
settings = QSettings("microscopy_app", "object_detection")
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""GUI widgets for the microscopy object detection application."""
|
||||
|
||||
from src.gui.widgets.image_display_widget import ImageDisplayWidget
|
||||
from src.gui.widgets.annotation_canvas_widget import AnnotationCanvasWidget
|
||||
from src.gui.widgets.annotation_tools_widget import AnnotationToolsWidget
|
||||
|
||||
__all__ = ["ImageDisplayWidget"]
|
||||
__all__ = ["ImageDisplayWidget", "AnnotationCanvasWidget", "AnnotationToolsWidget"]
|
||||
|
||||
406
src/gui/widgets/annotation_canvas_widget.py
Normal file
406
src/gui/widgets/annotation_canvas_widget.py
Normal file
@@ -0,0 +1,406 @@
|
||||
"""
|
||||
Annotation canvas widget for drawing annotations on images.
|
||||
Supports pen tool with color selection for manual annotation.
|
||||
"""
|
||||
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea
|
||||
from PySide6.QtGui import (
|
||||
QPixmap,
|
||||
QImage,
|
||||
QPainter,
|
||||
QPen,
|
||||
QColor,
|
||||
QKeyEvent,
|
||||
QMouseEvent,
|
||||
QPaintEvent,
|
||||
)
|
||||
from PySide6.QtCore import Qt, QEvent, Signal, QPoint
|
||||
from typing import List, Optional, Tuple
|
||||
import numpy as np
|
||||
|
||||
from src.utils.image import Image, ImageLoadError
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnnotationCanvasWidget(QWidget):
|
||||
"""
|
||||
Widget for displaying images and drawing annotations with pen tool.
|
||||
|
||||
Features:
|
||||
- Display images with zoom functionality
|
||||
- Pen tool for drawing annotations
|
||||
- Configurable pen color and width
|
||||
- Mouse-based drawing interface
|
||||
- Zoom in/out with mouse wheel and keyboard
|
||||
|
||||
Signals:
|
||||
zoom_changed: Emitted when zoom level changes (float zoom_scale)
|
||||
annotation_drawn: Emitted when a new stroke is completed (list of points)
|
||||
"""
|
||||
|
||||
zoom_changed = Signal(float)
|
||||
annotation_drawn = Signal(list) # List of (x, y) points in normalized coordinates
|
||||
|
||||
def __init__(self, parent=None):
|
||||
"""Initialize the annotation canvas widget."""
|
||||
super().__init__(parent)
|
||||
|
||||
self.current_image = None
|
||||
self.original_pixmap = None
|
||||
self.annotation_pixmap = None # Overlay for annotations
|
||||
self.zoom_scale = 1.0
|
||||
self.zoom_min = 0.1
|
||||
self.zoom_max = 10.0
|
||||
self.zoom_step = 0.1
|
||||
self.zoom_wheel_step = 0.15
|
||||
|
||||
# Drawing state
|
||||
self.is_drawing = False
|
||||
self.pen_enabled = False
|
||||
self.pen_color = QColor(255, 0, 0, 128) # Default red with 50% alpha
|
||||
self.pen_width = 3
|
||||
self.current_stroke = [] # Points in current stroke
|
||||
self.all_strokes = [] # All completed strokes
|
||||
|
||||
self._setup_ui()
|
||||
|
||||
def _setup_ui(self):
|
||||
"""Setup user interface."""
|
||||
layout = QVBoxLayout()
|
||||
layout.setContentsMargins(0, 0, 0, 0)
|
||||
|
||||
# Scroll area for canvas
|
||||
self.scroll_area = QScrollArea()
|
||||
self.scroll_area.setWidgetResizable(True)
|
||||
self.scroll_area.setMinimumHeight(400)
|
||||
|
||||
self.canvas_label = QLabel("No image loaded")
|
||||
self.canvas_label.setAlignment(Qt.AlignCenter)
|
||||
self.canvas_label.setStyleSheet(
|
||||
"QLabel { background-color: #2b2b2b; color: #888; }"
|
||||
)
|
||||
self.canvas_label.setScaledContents(False)
|
||||
self.canvas_label.setMouseTracking(True)
|
||||
|
||||
self.scroll_area.setWidget(self.canvas_label)
|
||||
self.scroll_area.viewport().installEventFilter(self)
|
||||
|
||||
layout.addWidget(self.scroll_area)
|
||||
self.setLayout(layout)
|
||||
|
||||
self.setFocusPolicy(Qt.StrongFocus)
|
||||
|
||||
def load_image(self, image: Image):
|
||||
"""
|
||||
Load and display an image.
|
||||
|
||||
Args:
|
||||
image: Image object to display
|
||||
"""
|
||||
self.current_image = image
|
||||
self.zoom_scale = 1.0
|
||||
self.clear_annotations()
|
||||
self._display_image()
|
||||
logger.debug(
|
||||
f"Loaded image into annotation canvas: {image.width}x{image.height}"
|
||||
)
|
||||
|
||||
def clear(self):
|
||||
"""Clear the displayed image and all annotations."""
|
||||
self.current_image = None
|
||||
self.original_pixmap = None
|
||||
self.annotation_pixmap = None
|
||||
self.zoom_scale = 1.0
|
||||
self.clear_annotations()
|
||||
self.canvas_label.setText("No image loaded")
|
||||
self.canvas_label.setPixmap(QPixmap())
|
||||
|
||||
def clear_annotations(self):
|
||||
"""Clear all drawn annotations."""
|
||||
self.all_strokes = []
|
||||
self.current_stroke = []
|
||||
self.is_drawing = False
|
||||
if self.annotation_pixmap:
|
||||
self.annotation_pixmap.fill(Qt.transparent)
|
||||
self._update_display()
|
||||
|
||||
def _display_image(self):
|
||||
"""Display the current image in the canvas."""
|
||||
if self.current_image is None:
|
||||
return
|
||||
|
||||
try:
|
||||
# Get RGB image data
|
||||
if self.current_image.channels == 3:
|
||||
image_data = self.current_image.get_rgb()
|
||||
height, width, channels = image_data.shape
|
||||
else:
|
||||
image_data = self.current_image.get_grayscale()
|
||||
height, width = image_data.shape
|
||||
|
||||
image_data = np.ascontiguousarray(image_data)
|
||||
bytes_per_line = image_data.strides[0]
|
||||
|
||||
qimage = QImage(
|
||||
image_data.data,
|
||||
width,
|
||||
height,
|
||||
bytes_per_line,
|
||||
self.current_image.qtimage_format,
|
||||
)
|
||||
|
||||
self.original_pixmap = QPixmap.fromImage(qimage)
|
||||
|
||||
# Create transparent overlay for annotations
|
||||
self.annotation_pixmap = QPixmap(self.original_pixmap.size())
|
||||
self.annotation_pixmap.fill(Qt.transparent)
|
||||
|
||||
self._apply_zoom()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error displaying image: {e}")
|
||||
raise ImageLoadError(f"Failed to display image: {str(e)}")
|
||||
|
||||
def _apply_zoom(self):
|
||||
"""Apply current zoom level to the displayed image."""
|
||||
if self.original_pixmap is None:
|
||||
return
|
||||
|
||||
scaled_width = int(self.original_pixmap.width() * self.zoom_scale)
|
||||
scaled_height = int(self.original_pixmap.height() * self.zoom_scale)
|
||||
|
||||
# Scale both image and annotations
|
||||
scaled_image = self.original_pixmap.scaled(
|
||||
scaled_width,
|
||||
scaled_height,
|
||||
Qt.KeepAspectRatio,
|
||||
(
|
||||
Qt.SmoothTransformation
|
||||
if self.zoom_scale >= 1.0
|
||||
else Qt.FastTransformation
|
||||
),
|
||||
)
|
||||
|
||||
scaled_annotations = self.annotation_pixmap.scaled(
|
||||
scaled_width,
|
||||
scaled_height,
|
||||
Qt.KeepAspectRatio,
|
||||
(
|
||||
Qt.SmoothTransformation
|
||||
if self.zoom_scale >= 1.0
|
||||
else Qt.FastTransformation
|
||||
),
|
||||
)
|
||||
|
||||
# Composite image and annotations
|
||||
combined = QPixmap(scaled_image.size())
|
||||
painter = QPainter(combined)
|
||||
painter.drawPixmap(0, 0, scaled_image)
|
||||
painter.drawPixmap(0, 0, scaled_annotations)
|
||||
painter.end()
|
||||
|
||||
self.canvas_label.setPixmap(combined)
|
||||
self.canvas_label.setScaledContents(False)
|
||||
self.canvas_label.adjustSize()
|
||||
|
||||
self.zoom_changed.emit(self.zoom_scale)
|
||||
|
||||
def _update_display(self):
|
||||
"""Update display after drawing."""
|
||||
self._apply_zoom()
|
||||
|
||||
def set_pen_enabled(self, enabled: bool):
|
||||
"""Enable or disable pen tool."""
|
||||
self.pen_enabled = enabled
|
||||
if enabled:
|
||||
self.canvas_label.setCursor(Qt.CrossCursor)
|
||||
else:
|
||||
self.canvas_label.setCursor(Qt.ArrowCursor)
|
||||
|
||||
def set_pen_color(self, color: QColor):
|
||||
"""Set pen color."""
|
||||
self.pen_color = color
|
||||
|
||||
def set_pen_width(self, width: int):
|
||||
"""Set pen width."""
|
||||
self.pen_width = max(1, width)
|
||||
|
||||
def get_zoom_percentage(self) -> int:
|
||||
"""Get current zoom level as percentage."""
|
||||
return int(self.zoom_scale * 100)
|
||||
|
||||
def zoom_in(self):
|
||||
"""Zoom in on the image."""
|
||||
if self.original_pixmap is None:
|
||||
return
|
||||
new_scale = self.zoom_scale + self.zoom_step
|
||||
if new_scale <= self.zoom_max:
|
||||
self.zoom_scale = new_scale
|
||||
self._apply_zoom()
|
||||
|
||||
def zoom_out(self):
|
||||
"""Zoom out from the image."""
|
||||
if self.original_pixmap is None:
|
||||
return
|
||||
new_scale = self.zoom_scale - self.zoom_step
|
||||
if new_scale >= self.zoom_min:
|
||||
self.zoom_scale = new_scale
|
||||
self._apply_zoom()
|
||||
|
||||
def reset_zoom(self):
|
||||
"""Reset zoom to 100%."""
|
||||
if self.original_pixmap is None:
|
||||
return
|
||||
self.zoom_scale = 1.0
|
||||
self._apply_zoom()
|
||||
|
||||
def _canvas_to_image_coords(self, pos: QPoint) -> Optional[Tuple[int, int]]:
|
||||
"""Convert canvas coordinates to image coordinates, accounting for zoom and centering."""
|
||||
if self.original_pixmap is None or self.canvas_label.pixmap() is None:
|
||||
return None
|
||||
|
||||
# Get the displayed pixmap size (after zoom)
|
||||
displayed_pixmap = self.canvas_label.pixmap()
|
||||
displayed_width = displayed_pixmap.width()
|
||||
displayed_height = displayed_pixmap.height()
|
||||
|
||||
# Calculate offset due to label centering (label might be larger than pixmap)
|
||||
label_width = self.canvas_label.width()
|
||||
label_height = self.canvas_label.height()
|
||||
offset_x = max(0, (label_width - displayed_width) // 2)
|
||||
offset_y = max(0, (label_height - displayed_height) // 2)
|
||||
|
||||
# Adjust position for offset and convert to image coordinates
|
||||
x = (pos.x() - offset_x) / self.zoom_scale
|
||||
y = (pos.y() - offset_y) / self.zoom_scale
|
||||
|
||||
# Check bounds
|
||||
if (
|
||||
0 <= x < self.original_pixmap.width()
|
||||
and 0 <= y < self.original_pixmap.height()
|
||||
):
|
||||
return (int(x), int(y))
|
||||
return None
|
||||
|
||||
def _image_to_normalized_coords(self, x: int, y: int) -> Tuple[float, float]:
|
||||
"""Convert image coordinates to normalized coordinates (0-1)."""
|
||||
if self.original_pixmap is None:
|
||||
return (0.0, 0.0)
|
||||
|
||||
norm_x = x / self.original_pixmap.width()
|
||||
norm_y = y / self.original_pixmap.height()
|
||||
return (norm_x, norm_y)
|
||||
|
||||
def mousePressEvent(self, event: QMouseEvent):
|
||||
"""Handle mouse press events for drawing."""
|
||||
if not self.pen_enabled or self.annotation_pixmap is None:
|
||||
super().mousePressEvent(event)
|
||||
return
|
||||
|
||||
if event.button() == Qt.LeftButton:
|
||||
# Get accurate position using global coordinates
|
||||
label_pos = self.canvas_label.mapFromGlobal(event.globalPos())
|
||||
img_coords = self._canvas_to_image_coords(label_pos)
|
||||
|
||||
if img_coords:
|
||||
self.is_drawing = True
|
||||
self.current_stroke = [img_coords]
|
||||
|
||||
def mouseMoveEvent(self, event: QMouseEvent):
|
||||
"""Handle mouse move events for drawing."""
|
||||
if (
|
||||
not self.is_drawing
|
||||
or not self.pen_enabled
|
||||
or self.annotation_pixmap is None
|
||||
):
|
||||
super().mouseMoveEvent(event)
|
||||
return
|
||||
|
||||
# Get accurate position using global coordinates
|
||||
label_pos = self.canvas_label.mapFromGlobal(event.globalPos())
|
||||
img_coords = self._canvas_to_image_coords(label_pos)
|
||||
|
||||
if img_coords and len(self.current_stroke) > 0:
|
||||
# Draw line from last point to current point
|
||||
painter = QPainter(self.annotation_pixmap)
|
||||
pen = QPen(
|
||||
self.pen_color, self.pen_width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin
|
||||
)
|
||||
painter.setPen(pen)
|
||||
|
||||
last_point = self.current_stroke[-1]
|
||||
painter.drawLine(last_point[0], last_point[1], img_coords[0], img_coords[1])
|
||||
painter.end()
|
||||
|
||||
self.current_stroke.append(img_coords)
|
||||
self._update_display()
|
||||
|
||||
def mouseReleaseEvent(self, event: QMouseEvent):
|
||||
"""Handle mouse release events to complete a stroke."""
|
||||
if not self.is_drawing or event.button() != Qt.LeftButton:
|
||||
super().mouseReleaseEvent(event)
|
||||
return
|
||||
|
||||
self.is_drawing = False
|
||||
|
||||
if len(self.current_stroke) > 1:
|
||||
# Convert to normalized coordinates and save stroke
|
||||
normalized_stroke = [
|
||||
self._image_to_normalized_coords(x, y) for x, y in self.current_stroke
|
||||
]
|
||||
self.all_strokes.append(
|
||||
{
|
||||
"points": normalized_stroke,
|
||||
"color": self.pen_color.name(),
|
||||
"alpha": self.pen_color.alpha(),
|
||||
"width": self.pen_width,
|
||||
}
|
||||
)
|
||||
|
||||
# Emit signal with normalized coordinates
|
||||
self.annotation_drawn.emit(normalized_stroke)
|
||||
logger.debug(f"Completed stroke with {len(normalized_stroke)} points")
|
||||
|
||||
self.current_stroke = []
|
||||
|
||||
def get_all_strokes(self) -> List[dict]:
|
||||
"""Get all drawn strokes with metadata."""
|
||||
return self.all_strokes
|
||||
|
||||
def keyPressEvent(self, event: QKeyEvent):
|
||||
"""Handle keyboard events for zooming."""
|
||||
if event.key() in (Qt.Key_Plus, Qt.Key_Equal):
|
||||
self.zoom_in()
|
||||
event.accept()
|
||||
elif event.key() == Qt.Key_Minus:
|
||||
self.zoom_out()
|
||||
event.accept()
|
||||
elif event.key() == Qt.Key_0 and event.modifiers() == Qt.ControlModifier:
|
||||
self.reset_zoom()
|
||||
event.accept()
|
||||
else:
|
||||
super().keyPressEvent(event)
|
||||
|
||||
def eventFilter(self, obj, event: QEvent) -> bool:
|
||||
"""Event filter to capture wheel events for zooming."""
|
||||
if event.type() == QEvent.Wheel:
|
||||
wheel_event = event
|
||||
if self.original_pixmap is not None:
|
||||
delta = wheel_event.angleDelta().y()
|
||||
|
||||
if delta > 0:
|
||||
new_scale = self.zoom_scale + self.zoom_wheel_step
|
||||
if new_scale <= self.zoom_max:
|
||||
self.zoom_scale = new_scale
|
||||
self._apply_zoom()
|
||||
else:
|
||||
new_scale = self.zoom_scale - self.zoom_wheel_step
|
||||
if new_scale >= self.zoom_min:
|
||||
self.zoom_scale = new_scale
|
||||
self._apply_zoom()
|
||||
|
||||
return True
|
||||
|
||||
return super().eventFilter(obj, event)
|
||||
352
src/gui/widgets/annotation_tools_widget.py
Normal file
352
src/gui/widgets/annotation_tools_widget.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""
|
||||
Annotation tools widget for controlling annotation parameters.
|
||||
Includes pen tool, color picker, class selection, and annotation management.
|
||||
"""
|
||||
|
||||
from PySide6.QtWidgets import (
|
||||
QWidget,
|
||||
QVBoxLayout,
|
||||
QHBoxLayout,
|
||||
QLabel,
|
||||
QGroupBox,
|
||||
QPushButton,
|
||||
QComboBox,
|
||||
QSpinBox,
|
||||
QColorDialog,
|
||||
QInputDialog,
|
||||
QMessageBox,
|
||||
)
|
||||
from PySide6.QtGui import QColor, QIcon, QPixmap, QPainter
|
||||
from PySide6.QtCore import Qt, Signal
|
||||
from typing import Optional, Dict
|
||||
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnnotationToolsWidget(QWidget):
|
||||
"""
|
||||
Widget for annotation tool controls.
|
||||
|
||||
Features:
|
||||
- Enable/disable pen tool
|
||||
- Color selection for pen
|
||||
- Object class selection
|
||||
- Add new object classes
|
||||
- Pen width control
|
||||
- Clear annotations
|
||||
|
||||
Signals:
|
||||
pen_enabled_changed: Emitted when pen tool is enabled/disabled (bool)
|
||||
pen_color_changed: Emitted when pen color changes (QColor)
|
||||
pen_width_changed: Emitted when pen width changes (int)
|
||||
class_selected: Emitted when object class is selected (dict)
|
||||
clear_annotations_requested: Emitted when clear button is pressed
|
||||
"""
|
||||
|
||||
pen_enabled_changed = Signal(bool)
|
||||
pen_color_changed = Signal(QColor)
|
||||
pen_width_changed = Signal(int)
|
||||
class_selected = Signal(dict)
|
||||
clear_annotations_requested = Signal()
|
||||
|
||||
def __init__(self, db_manager: DatabaseManager, parent=None):
|
||||
"""
|
||||
Initialize annotation tools widget.
|
||||
|
||||
Args:
|
||||
db_manager: Database manager instance
|
||||
parent: Parent widget
|
||||
"""
|
||||
super().__init__(parent)
|
||||
self.db_manager = db_manager
|
||||
self.pen_enabled = False
|
||||
self.current_color = QColor(255, 0, 0, 128) # Red with 50% alpha
|
||||
self.current_class = None
|
||||
|
||||
self._setup_ui()
|
||||
self._load_object_classes()
|
||||
|
||||
def _setup_ui(self):
|
||||
"""Setup user interface."""
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# Pen Tool Group
|
||||
pen_group = QGroupBox("Pen Tool")
|
||||
pen_layout = QVBoxLayout()
|
||||
|
||||
# Enable/Disable pen
|
||||
button_layout = QHBoxLayout()
|
||||
self.pen_toggle_btn = QPushButton("Enable Pen")
|
||||
self.pen_toggle_btn.setCheckable(True)
|
||||
self.pen_toggle_btn.clicked.connect(self._on_pen_toggle)
|
||||
button_layout.addWidget(self.pen_toggle_btn)
|
||||
pen_layout.addLayout(button_layout)
|
||||
|
||||
# Pen width control
|
||||
width_layout = QHBoxLayout()
|
||||
width_layout.addWidget(QLabel("Pen Width:"))
|
||||
self.pen_width_spin = QSpinBox()
|
||||
self.pen_width_spin.setMinimum(1)
|
||||
self.pen_width_spin.setMaximum(20)
|
||||
self.pen_width_spin.setValue(3)
|
||||
self.pen_width_spin.valueChanged.connect(self._on_pen_width_changed)
|
||||
width_layout.addWidget(self.pen_width_spin)
|
||||
width_layout.addStretch()
|
||||
pen_layout.addLayout(width_layout)
|
||||
|
||||
# Color selection
|
||||
color_layout = QHBoxLayout()
|
||||
color_layout.addWidget(QLabel("Color:"))
|
||||
self.color_btn = QPushButton()
|
||||
self.color_btn.setFixedSize(40, 30)
|
||||
self.color_btn.clicked.connect(self._on_color_picker)
|
||||
self._update_color_button()
|
||||
color_layout.addWidget(self.color_btn)
|
||||
color_layout.addStretch()
|
||||
pen_layout.addLayout(color_layout)
|
||||
|
||||
pen_group.setLayout(pen_layout)
|
||||
layout.addWidget(pen_group)
|
||||
|
||||
# Object Class Group
|
||||
class_group = QGroupBox("Object Class")
|
||||
class_layout = QVBoxLayout()
|
||||
|
||||
# Class selection dropdown
|
||||
self.class_combo = QComboBox()
|
||||
self.class_combo.currentIndexChanged.connect(self._on_class_selected)
|
||||
class_layout.addWidget(self.class_combo)
|
||||
|
||||
# Add class button
|
||||
class_button_layout = QHBoxLayout()
|
||||
self.add_class_btn = QPushButton("Add New Class")
|
||||
self.add_class_btn.clicked.connect(self._on_add_class)
|
||||
class_button_layout.addWidget(self.add_class_btn)
|
||||
|
||||
self.refresh_classes_btn = QPushButton("Refresh")
|
||||
self.refresh_classes_btn.clicked.connect(self._load_object_classes)
|
||||
class_button_layout.addWidget(self.refresh_classes_btn)
|
||||
class_layout.addLayout(class_button_layout)
|
||||
|
||||
# Selected class info
|
||||
self.class_info_label = QLabel("No class selected")
|
||||
self.class_info_label.setWordWrap(True)
|
||||
self.class_info_label.setStyleSheet(
|
||||
"QLabel { color: #888; font-style: italic; }"
|
||||
)
|
||||
class_layout.addWidget(self.class_info_label)
|
||||
|
||||
class_group.setLayout(class_layout)
|
||||
layout.addWidget(class_group)
|
||||
|
||||
# Actions Group
|
||||
actions_group = QGroupBox("Actions")
|
||||
actions_layout = QVBoxLayout()
|
||||
|
||||
self.clear_btn = QPushButton("Clear All Annotations")
|
||||
self.clear_btn.clicked.connect(self._on_clear_annotations)
|
||||
actions_layout.addWidget(self.clear_btn)
|
||||
|
||||
actions_group.setLayout(actions_layout)
|
||||
layout.addWidget(actions_group)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def _update_color_button(self):
|
||||
"""Update the color button appearance with current color."""
|
||||
pixmap = QPixmap(40, 30)
|
||||
pixmap.fill(self.current_color)
|
||||
|
||||
# Add border
|
||||
painter = QPainter(pixmap)
|
||||
painter.setPen(Qt.black)
|
||||
painter.drawRect(0, 0, pixmap.width() - 1, pixmap.height() - 1)
|
||||
painter.end()
|
||||
|
||||
self.color_btn.setIcon(QIcon(pixmap))
|
||||
self.color_btn.setStyleSheet(f"background-color: {self.current_color.name()};")
|
||||
|
||||
def _load_object_classes(self):
|
||||
"""Load object classes from database and populate combo box."""
|
||||
try:
|
||||
classes = self.db_manager.get_object_classes()
|
||||
|
||||
# Clear and repopulate combo box
|
||||
self.class_combo.clear()
|
||||
self.class_combo.addItem("-- Select Class --", None)
|
||||
|
||||
for cls in classes:
|
||||
self.class_combo.addItem(cls["class_name"], cls)
|
||||
|
||||
logger.debug(f"Loaded {len(classes)} object classes")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading object classes: {e}")
|
||||
QMessageBox.warning(
|
||||
self, "Error", f"Failed to load object classes:\n{str(e)}"
|
||||
)
|
||||
|
||||
def _on_pen_toggle(self, checked: bool):
|
||||
"""Handle pen tool enable/disable."""
|
||||
self.pen_enabled = checked
|
||||
|
||||
if checked:
|
||||
self.pen_toggle_btn.setText("Disable Pen")
|
||||
self.pen_toggle_btn.setStyleSheet(
|
||||
"QPushButton { background-color: #4CAF50; }"
|
||||
)
|
||||
else:
|
||||
self.pen_toggle_btn.setText("Enable Pen")
|
||||
self.pen_toggle_btn.setStyleSheet("")
|
||||
|
||||
self.pen_enabled_changed.emit(self.pen_enabled)
|
||||
logger.debug(f"Pen tool {'enabled' if checked else 'disabled'}")
|
||||
|
||||
def _on_pen_width_changed(self, width: int):
|
||||
"""Handle pen width changes."""
|
||||
self.pen_width_changed.emit(width)
|
||||
logger.debug(f"Pen width changed to {width}")
|
||||
|
||||
def _on_color_picker(self):
|
||||
"""Open color picker dialog with alpha support."""
|
||||
color = QColorDialog.getColor(
|
||||
self.current_color,
|
||||
self,
|
||||
"Select Pen Color",
|
||||
QColorDialog.ShowAlphaChannel, # Enable alpha channel selection
|
||||
)
|
||||
|
||||
if color.isValid():
|
||||
self.current_color = color
|
||||
self._update_color_button()
|
||||
self.pen_color_changed.emit(color)
|
||||
logger.debug(
|
||||
f"Pen color changed to {color.name()} with alpha {color.alpha()}"
|
||||
)
|
||||
|
||||
def _on_class_selected(self, index: int):
|
||||
"""Handle object class selection."""
|
||||
class_data = self.class_combo.currentData()
|
||||
|
||||
if class_data:
|
||||
self.current_class = class_data
|
||||
|
||||
# Update info label
|
||||
info_text = (
|
||||
f"Class: {class_data['class_name']}\n" f"Color: {class_data['color']}"
|
||||
)
|
||||
if class_data.get("description"):
|
||||
info_text += f"\nDescription: {class_data['description']}"
|
||||
|
||||
self.class_info_label.setText(info_text)
|
||||
|
||||
# Update pen color to match class color with semi-transparency
|
||||
class_color = QColor(class_data["color"])
|
||||
if class_color.isValid():
|
||||
# Add 50% alpha for semi-transparency
|
||||
class_color.setAlpha(128)
|
||||
self.current_color = class_color
|
||||
self._update_color_button()
|
||||
self.pen_color_changed.emit(class_color)
|
||||
|
||||
self.class_selected.emit(class_data)
|
||||
logger.debug(f"Selected class: {class_data['class_name']}")
|
||||
else:
|
||||
self.current_class = None
|
||||
self.class_info_label.setText("No class selected")
|
||||
|
||||
def _on_add_class(self):
|
||||
"""Handle adding a new object class."""
|
||||
# Get class name
|
||||
class_name, ok = QInputDialog.getText(
|
||||
self, "Add Object Class", "Enter class name:"
|
||||
)
|
||||
|
||||
if not ok or not class_name.strip():
|
||||
return
|
||||
|
||||
class_name = class_name.strip()
|
||||
|
||||
# Check if class already exists
|
||||
existing = self.db_manager.get_object_class_by_name(class_name)
|
||||
if existing:
|
||||
QMessageBox.warning(
|
||||
self, "Class Exists", f"A class named '{class_name}' already exists."
|
||||
)
|
||||
return
|
||||
|
||||
# Get color
|
||||
color = QColorDialog.getColor(self.current_color, self, "Select Class Color")
|
||||
|
||||
if not color.isValid():
|
||||
return
|
||||
|
||||
# Get optional description
|
||||
description, ok = QInputDialog.getText(
|
||||
self, "Class Description", "Enter class description (optional):"
|
||||
)
|
||||
|
||||
if not ok:
|
||||
description = None
|
||||
|
||||
# Add to database
|
||||
try:
|
||||
class_id = self.db_manager.add_object_class(
|
||||
class_name, color.name(), description.strip() if description else None
|
||||
)
|
||||
|
||||
logger.info(f"Added new object class: {class_name} (ID: {class_id})")
|
||||
|
||||
# Reload classes and select the new one
|
||||
self._load_object_classes()
|
||||
|
||||
# Find and select the newly added class
|
||||
for i in range(self.class_combo.count()):
|
||||
class_data = self.class_combo.itemData(i)
|
||||
if class_data and class_data.get("id") == class_id:
|
||||
self.class_combo.setCurrentIndex(i)
|
||||
break
|
||||
|
||||
QMessageBox.information(
|
||||
self, "Success", f"Class '{class_name}' added successfully!"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding object class: {e}")
|
||||
QMessageBox.critical(
|
||||
self, "Error", f"Failed to add object class:\n{str(e)}"
|
||||
)
|
||||
|
||||
def _on_clear_annotations(self):
|
||||
"""Handle clear annotations button."""
|
||||
reply = QMessageBox.question(
|
||||
self,
|
||||
"Clear Annotations",
|
||||
"Are you sure you want to clear all annotations?",
|
||||
QMessageBox.Yes | QMessageBox.No,
|
||||
QMessageBox.No,
|
||||
)
|
||||
|
||||
if reply == QMessageBox.Yes:
|
||||
self.clear_annotations_requested.emit()
|
||||
logger.debug("Clear annotations requested")
|
||||
|
||||
def get_current_class(self) -> Optional[Dict]:
|
||||
"""Get currently selected object class."""
|
||||
return self.current_class
|
||||
|
||||
def get_pen_color(self) -> QColor:
|
||||
"""Get current pen color."""
|
||||
return self.current_color
|
||||
|
||||
def get_pen_width(self) -> int:
|
||||
"""Get current pen width."""
|
||||
return self.pen_width_spin.value()
|
||||
|
||||
def is_pen_enabled(self) -> bool:
|
||||
"""Check if pen tool is enabled."""
|
||||
return self.pen_enabled
|
||||
Reference in New Issue
Block a user