From fc22479621e5802967dabe24d8d0eabe7c71aa1d Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Mon, 8 Dec 2025 23:15:54 +0200 Subject: [PATCH] Adding pen tool for annotation --- src/database/db_manager.py | 228 ++++++++++- src/database/schema.sql | 27 +- src/gui/tabs/annotation_tab.py | 116 ++++-- src/gui/widgets/__init__.py | 4 +- src/gui/widgets/annotation_canvas_widget.py | 406 ++++++++++++++++++++ src/gui/widgets/annotation_tools_widget.py | 352 +++++++++++++++++ 6 files changed, 1079 insertions(+), 54 deletions(-) create mode 100644 src/gui/widgets/annotation_canvas_widget.py create mode 100644 src/gui/widgets/annotation_tools_widget.py diff --git a/src/database/db_manager.py b/src/database/db_manager.py index 5dcf5c2..db2da9e 100644 --- a/src/database/db_manager.py +++ b/src/database/db_manager.py @@ -30,18 +30,48 @@ class DatabaseManager: # Create directory if it doesn't exist Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) - # Read schema file and execute - schema_path = Path(__file__).parent / "schema.sql" - with open(schema_path, "r") as f: - schema_sql = f.read() - conn = self.get_connection() try: + # Check if annotations table needs migration + self._migrate_annotations_table(conn) + + # Read schema file and execute + schema_path = Path(__file__).parent / "schema.sql" + with open(schema_path, "r") as f: + schema_sql = f.read() + conn.executescript(schema_sql) conn.commit() finally: conn.close() + def _migrate_annotations_table(self, conn: sqlite3.Connection) -> None: + """ + Migrate annotations table from old schema (class_name) to new schema (class_id). + """ + cursor = conn.cursor() + + # Check if annotations table exists + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'" + ) + if not cursor.fetchone(): + # Table doesn't exist yet, no migration needed + return + + # Check if table has old schema (class_name column) + cursor.execute("PRAGMA table_info(annotations)") + columns = {row[1]: row for row in cursor.fetchall()} + + if "class_name" in columns and "class_id" not in columns: + # Old schema detected, need to migrate + print("Migrating annotations table to new schema with class_id...") + + # Drop old annotations table (assuming no critical data since this is a new feature) + cursor.execute("DROP TABLE IF EXISTS annotations") + conn.commit() + print("Old annotations table dropped, will be recreated with new schema") + def get_connection(self) -> sqlite3.Connection: """Get database connection with proper settings.""" conn = sqlite3.connect(self.db_path) @@ -593,25 +623,38 @@ class DatabaseManager: def add_annotation( self, image_id: int, - class_name: str, + class_id: int, bbox: Tuple[float, float, float, float], annotator: str, segmentation_mask: Optional[List[List[float]]] = None, verified: bool = False, ) -> int: - """Add manual annotation.""" + """ + Add manual annotation. + + Args: + image_id: ID of the image + class_id: ID of the object class (foreign key to object_classes) + bbox: Bounding box coordinates (normalized 0-1) + annotator: Name of person/tool creating annotation + segmentation_mask: Polygon coordinates for segmentation + verified: Whether annotation has been verified + + Returns: + ID of the inserted annotation + """ conn = self.get_connection() try: cursor = conn.cursor() x_min, y_min, x_max, y_max = bbox cursor.execute( """ - INSERT INTO annotations (image_id, class_name, x_min, y_min, x_max, y_max, segmentation_mask, annotator, verified) + INSERT INTO annotations (image_id, class_id, x_min, y_min, x_max, y_max, segmentation_mask, annotator, verified) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( image_id, - class_name, + class_id, x_min, y_min, x_max, @@ -627,15 +670,178 @@ class DatabaseManager: conn.close() def get_annotations_for_image(self, image_id: int) -> List[Dict]: - """Get all annotations for an image.""" + """ + Get all annotations for an image with class information. + + Args: + image_id: ID of the image + + Returns: + List of annotation dictionaries with joined class information + """ conn = self.get_connection() try: cursor = conn.cursor() - cursor.execute("SELECT * FROM annotations WHERE image_id = ?", (image_id,)) + cursor.execute( + """ + SELECT + a.*, + c.class_name, + c.color as class_color, + c.description as class_description + FROM annotations a + JOIN object_classes c ON a.class_id = c.id + WHERE a.image_id = ? + ORDER BY a.created_at DESC + """, + (image_id,), + ) + annotations = [] + for row in cursor.fetchall(): + ann = dict(row) + if ann.get("segmentation_mask"): + ann["segmentation_mask"] = json.loads(ann["segmentation_mask"]) + annotations.append(ann) + return annotations + finally: + conn.close() + + # ==================== Object Class Operations ==================== + + def get_object_classes(self) -> List[Dict]: + """ + Get all object classes. + + Returns: + List of object class dictionaries + """ + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute("SELECT * FROM object_classes ORDER BY class_name") return [dict(row) for row in cursor.fetchall()] finally: conn.close() + def get_object_class_by_id(self, class_id: int) -> Optional[Dict]: + """Get object class by ID.""" + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute("SELECT * FROM object_classes WHERE id = ?", (class_id,)) + row = cursor.fetchone() + return dict(row) if row else None + finally: + conn.close() + + def get_object_class_by_name(self, class_name: str) -> Optional[Dict]: + """Get object class by name.""" + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM object_classes WHERE class_name = ?", (class_name,) + ) + row = cursor.fetchone() + return dict(row) if row else None + finally: + conn.close() + + def add_object_class( + self, class_name: str, color: str, description: Optional[str] = None + ) -> int: + """ + Add a new object class. + + Args: + class_name: Name of the object class + color: Hex color code (e.g., '#FF0000') + description: Optional description + + Returns: + ID of the inserted object class + """ + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO object_classes (class_name, color, description) + VALUES (?, ?, ?) + """, + (class_name, color, description), + ) + conn.commit() + return cursor.lastrowid + except sqlite3.IntegrityError: + # Class already exists + existing = self.get_object_class_by_name(class_name) + return existing["id"] if existing else None + finally: + conn.close() + + def update_object_class( + self, + class_id: int, + class_name: Optional[str] = None, + color: Optional[str] = None, + description: Optional[str] = None, + ) -> bool: + """ + Update an object class. + + Args: + class_id: ID of the class to update + class_name: New class name (optional) + color: New color (optional) + description: New description (optional) + + Returns: + True if updated, False otherwise + """ + conn = self.get_connection() + try: + updates = {} + if class_name is not None: + updates["class_name"] = class_name + if color is not None: + updates["color"] = color + if description is not None: + updates["description"] = description + + if not updates: + return False + + set_clauses = [f"{key} = ?" for key in updates.keys()] + params = list(updates.values()) + [class_id] + + query = f"UPDATE object_classes SET {', '.join(set_clauses)} WHERE id = ?" + cursor = conn.cursor() + cursor.execute(query, params) + conn.commit() + return cursor.rowcount > 0 + finally: + conn.close() + + def delete_object_class(self, class_id: int) -> bool: + """ + Delete an object class. + + Args: + class_id: ID of the class to delete + + Returns: + True if deleted, False otherwise + """ + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute("DELETE FROM object_classes WHERE id = ?", (class_id,)) + conn.commit() + return cursor.rowcount > 0 + finally: + conn.close() + @staticmethod def calculate_checksum(file_path: str) -> str: """Calculate MD5 checksum of a file.""" diff --git a/src/database/schema.sql b/src/database/schema.sql index b09ffee..64123eb 100644 --- a/src/database/schema.sql +++ b/src/database/schema.sql @@ -44,11 +44,27 @@ CREATE TABLE IF NOT EXISTS detections ( FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE ); --- Annotations table: stores manual annotations (future feature) +-- Object classes table: stores annotation class definitions with colors +CREATE TABLE IF NOT EXISTS object_classes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + class_name TEXT NOT NULL UNIQUE, + color TEXT NOT NULL, -- Hex color code (e.g., '#FF0000') + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + description TEXT +); + +-- Insert default object classes +INSERT OR IGNORE INTO object_classes (class_name, color, description) VALUES + ('cell', '#FF0000', 'Cell object'), + ('nucleus', '#00FF00', 'Cell nucleus'), + ('mitochondria', '#0000FF', 'Mitochondria'), + ('vesicle', '#FFFF00', 'Vesicle'); + +-- Annotations table: stores manual annotations CREATE TABLE IF NOT EXISTS annotations ( id INTEGER PRIMARY KEY AUTOINCREMENT, image_id INTEGER NOT NULL, - class_name TEXT NOT NULL, + class_id INTEGER NOT NULL, x_min REAL NOT NULL CHECK(x_min >= 0 AND x_min <= 1), y_min REAL NOT NULL CHECK(y_min >= 0 AND y_min <= 1), x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1), @@ -57,7 +73,8 @@ CREATE TABLE IF NOT EXISTS annotations ( annotator TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, verified BOOLEAN DEFAULT 0, - FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE + FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE, + FOREIGN KEY (class_id) REFERENCES object_classes (id) ON DELETE CASCADE ); -- Create indexes for performance optimization @@ -69,4 +86,6 @@ CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence); CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path); CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at); CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id); -CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at); \ No newline at end of file +CREATE INDEX IF NOT EXISTS idx_annotations_class_id ON annotations(class_id); +CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at); +CREATE INDEX IF NOT EXISTS idx_object_classes_class_name ON object_classes(class_name); \ No newline at end of file diff --git a/src/gui/tabs/annotation_tab.py b/src/gui/tabs/annotation_tab.py index a373ad0..7933caf 100644 --- a/src/gui/tabs/annotation_tab.py +++ b/src/gui/tabs/annotation_tab.py @@ -1,6 +1,6 @@ """ Annotation tab for the microscopy object detection application. -Future feature for manual annotation. +Manual annotation with pen tool and object class management. """ from PySide6.QtWidgets import ( @@ -21,13 +21,13 @@ from src.database.db_manager import DatabaseManager from src.utils.config_manager import ConfigManager from src.utils.image import Image, ImageLoadError from src.utils.logger import get_logger -from src.gui.widgets import ImageDisplayWidget +from src.gui.widgets import AnnotationCanvasWidget, AnnotationToolsWidget logger = get_logger(__name__) class AnnotationTab(QWidget): - """Annotation tab placeholder (future feature).""" + """Annotation tab for manual image annotation.""" def __init__( self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None @@ -37,6 +37,7 @@ class AnnotationTab(QWidget): self.config_manager = config_manager self.current_image = None self.current_image_path = None + self.current_image_id = None self._setup_ui() @@ -52,49 +53,52 @@ class AnnotationTab(QWidget): self.left_splitter = QSplitter(Qt.Vertical) self.left_splitter.setHandleWidth(10) - # Image display section - display_group = QGroupBox("Image Display") - display_layout = QVBoxLayout() + # Annotation canvas section + canvas_group = QGroupBox("Annotation Canvas") + canvas_layout = QVBoxLayout() - # Use the reusable ImageDisplayWidget - self.image_display_widget = ImageDisplayWidget() - self.image_display_widget.zoom_changed.connect(self._on_zoom_changed) - display_layout.addWidget(self.image_display_widget) + # Use the AnnotationCanvasWidget + self.annotation_canvas = AnnotationCanvasWidget() + self.annotation_canvas.zoom_changed.connect(self._on_zoom_changed) + self.annotation_canvas.annotation_drawn.connect(self._on_annotation_drawn) + canvas_layout.addWidget(self.annotation_canvas) - display_group.setLayout(display_layout) - self.left_splitter.addWidget(display_group) + canvas_group.setLayout(canvas_layout) + self.left_splitter.addWidget(canvas_group) - # Zoom controls info - zoom_info = QLabel("Zoom: Mouse wheel or +/- keys to zoom in/out") - zoom_info.setStyleSheet("QLabel { color: #888; font-style: italic; }") - self.left_splitter.addWidget(zoom_info) + # Controls info + controls_info = QLabel( + "Zoom: Mouse wheel or +/- keys | Drawing: Enable pen and drag mouse" + ) + controls_info.setStyleSheet("QLabel { color: #888; font-style: italic; }") + self.left_splitter.addWidget(controls_info) # } # { Right splitter for annotation tools and controls self.right_splitter = QSplitter(Qt.Vertical) self.right_splitter.setHandleWidth(10) + # Annotation tools section + self.annotation_tools = AnnotationToolsWidget(self.db_manager) + self.annotation_tools.pen_enabled_changed.connect( + self.annotation_canvas.set_pen_enabled + ) + self.annotation_tools.pen_color_changed.connect( + self.annotation_canvas.set_pen_color + ) + self.annotation_tools.pen_width_changed.connect( + self.annotation_canvas.set_pen_width + ) + self.annotation_tools.class_selected.connect(self._on_class_selected) + self.annotation_tools.clear_annotations_requested.connect( + self._on_clear_annotations + ) + self.right_splitter.addWidget(self.annotation_tools) + # Image loading section load_group = QGroupBox("Image Loading") load_layout = QVBoxLayout() - # Future features info - info_group = QGroupBox("Annotation Tool (Future Feature)") - info_layout = QVBoxLayout() - info_label = QLabel( - "Full annotation functionality will be implemented in future version.\n\n" - "Planned Features:\n" - "- Drawing tools for bounding boxes\n" - "- Class label assignment\n" - "- Export annotations to YOLO format\n" - "- Annotation verification" - ) - info_label.setWordWrap(True) - info_layout.addWidget(info_label) - info_group.setLayout(info_layout) - - self.right_splitter.addWidget(info_group) - # Load image button button_layout = QHBoxLayout() self.load_image_btn = QPushButton("Load Image") @@ -158,13 +162,22 @@ class AnnotationTab(QWidget): "annotation_tab/last_directory", str(Path(file_path).parent) ) - # Display image using the ImageDisplayWidget - self.image_display_widget.load_image(self.current_image) + # Get or create image in database + relative_path = str(Path(file_path).name) # Simplified for now + self.current_image_id = self.db_manager.get_or_create_image( + relative_path, + Path(file_path).name, + self.current_image.width, + self.current_image.height, + ) + + # Display image using the AnnotationCanvasWidget + self.annotation_canvas.load_image(self.current_image) # Update info label self._update_image_info() - logger.info(f"Loaded image: {file_path}") + logger.info(f"Loaded image: {file_path} (DB ID: {self.current_image_id})") except ImageLoadError as e: logger.error(f"Failed to load image: {e}") @@ -181,7 +194,7 @@ class AnnotationTab(QWidget): self.image_info_label.setText("No image loaded") return - zoom_percentage = self.image_display_widget.get_zoom_percentage() + zoom_percentage = self.annotation_canvas.get_zoom_percentage() info_text = ( f"File: {Path(self.current_image_path).name}\n" f"Size: {self.current_image.width}x{self.current_image.height} pixels\n" @@ -194,9 +207,36 @@ class AnnotationTab(QWidget): self.image_info_label.setText(info_text) def _on_zoom_changed(self, zoom_scale: float): - """Handle zoom level changes from the image display widget.""" + """Handle zoom level changes from the annotation canvas.""" self._update_image_info() + def _on_annotation_drawn(self, points: list): + """Handle when an annotation stroke is drawn.""" + current_class = self.annotation_tools.get_current_class() + + if not current_class: + logger.warning("Annotation drawn but no object class selected") + QMessageBox.warning( + self, + "No Class Selected", + "Please select an object class before drawing annotations.", + ) + return + + logger.info( + f"Annotation drawn with {len(points)} points for class: {current_class['class_name']}" + ) + # Future: Save annotation to database or export + + def _on_class_selected(self, class_data: dict): + """Handle when an object class is selected.""" + logger.debug(f"Object class selected: {class_data['class_name']}") + + def _on_clear_annotations(self): + """Handle clearing all annotations.""" + self.annotation_canvas.clear_annotations() + logger.info("Cleared all annotations") + def _restore_state(self): """Restore splitter positions from settings.""" settings = QSettings("microscopy_app", "object_detection") diff --git a/src/gui/widgets/__init__.py b/src/gui/widgets/__init__.py index 2946406..df8fad7 100644 --- a/src/gui/widgets/__init__.py +++ b/src/gui/widgets/__init__.py @@ -1,5 +1,7 @@ """GUI widgets for the microscopy object detection application.""" from src.gui.widgets.image_display_widget import ImageDisplayWidget +from src.gui.widgets.annotation_canvas_widget import AnnotationCanvasWidget +from src.gui.widgets.annotation_tools_widget import AnnotationToolsWidget -__all__ = ["ImageDisplayWidget"] +__all__ = ["ImageDisplayWidget", "AnnotationCanvasWidget", "AnnotationToolsWidget"] diff --git a/src/gui/widgets/annotation_canvas_widget.py b/src/gui/widgets/annotation_canvas_widget.py new file mode 100644 index 0000000..7ff2bc5 --- /dev/null +++ b/src/gui/widgets/annotation_canvas_widget.py @@ -0,0 +1,406 @@ +""" +Annotation canvas widget for drawing annotations on images. +Supports pen tool with color selection for manual annotation. +""" + +from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea +from PySide6.QtGui import ( + QPixmap, + QImage, + QPainter, + QPen, + QColor, + QKeyEvent, + QMouseEvent, + QPaintEvent, +) +from PySide6.QtCore import Qt, QEvent, Signal, QPoint +from typing import List, Optional, Tuple +import numpy as np + +from src.utils.image import Image, ImageLoadError +from src.utils.logger import get_logger + +logger = get_logger(__name__) + + +class AnnotationCanvasWidget(QWidget): + """ + Widget for displaying images and drawing annotations with pen tool. + + Features: + - Display images with zoom functionality + - Pen tool for drawing annotations + - Configurable pen color and width + - Mouse-based drawing interface + - Zoom in/out with mouse wheel and keyboard + + Signals: + zoom_changed: Emitted when zoom level changes (float zoom_scale) + annotation_drawn: Emitted when a new stroke is completed (list of points) + """ + + zoom_changed = Signal(float) + annotation_drawn = Signal(list) # List of (x, y) points in normalized coordinates + + def __init__(self, parent=None): + """Initialize the annotation canvas widget.""" + super().__init__(parent) + + self.current_image = None + self.original_pixmap = None + self.annotation_pixmap = None # Overlay for annotations + self.zoom_scale = 1.0 + self.zoom_min = 0.1 + self.zoom_max = 10.0 + self.zoom_step = 0.1 + self.zoom_wheel_step = 0.15 + + # Drawing state + self.is_drawing = False + self.pen_enabled = False + self.pen_color = QColor(255, 0, 0, 128) # Default red with 50% alpha + self.pen_width = 3 + self.current_stroke = [] # Points in current stroke + self.all_strokes = [] # All completed strokes + + self._setup_ui() + + def _setup_ui(self): + """Setup user interface.""" + layout = QVBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + + # Scroll area for canvas + self.scroll_area = QScrollArea() + self.scroll_area.setWidgetResizable(True) + self.scroll_area.setMinimumHeight(400) + + self.canvas_label = QLabel("No image loaded") + self.canvas_label.setAlignment(Qt.AlignCenter) + self.canvas_label.setStyleSheet( + "QLabel { background-color: #2b2b2b; color: #888; }" + ) + self.canvas_label.setScaledContents(False) + self.canvas_label.setMouseTracking(True) + + self.scroll_area.setWidget(self.canvas_label) + self.scroll_area.viewport().installEventFilter(self) + + layout.addWidget(self.scroll_area) + self.setLayout(layout) + + self.setFocusPolicy(Qt.StrongFocus) + + def load_image(self, image: Image): + """ + Load and display an image. + + Args: + image: Image object to display + """ + self.current_image = image + self.zoom_scale = 1.0 + self.clear_annotations() + self._display_image() + logger.debug( + f"Loaded image into annotation canvas: {image.width}x{image.height}" + ) + + def clear(self): + """Clear the displayed image and all annotations.""" + self.current_image = None + self.original_pixmap = None + self.annotation_pixmap = None + self.zoom_scale = 1.0 + self.clear_annotations() + self.canvas_label.setText("No image loaded") + self.canvas_label.setPixmap(QPixmap()) + + def clear_annotations(self): + """Clear all drawn annotations.""" + self.all_strokes = [] + self.current_stroke = [] + self.is_drawing = False + if self.annotation_pixmap: + self.annotation_pixmap.fill(Qt.transparent) + self._update_display() + + def _display_image(self): + """Display the current image in the canvas.""" + if self.current_image is None: + return + + try: + # Get RGB image data + if self.current_image.channels == 3: + image_data = self.current_image.get_rgb() + height, width, channels = image_data.shape + else: + image_data = self.current_image.get_grayscale() + height, width = image_data.shape + + image_data = np.ascontiguousarray(image_data) + bytes_per_line = image_data.strides[0] + + qimage = QImage( + image_data.data, + width, + height, + bytes_per_line, + self.current_image.qtimage_format, + ) + + self.original_pixmap = QPixmap.fromImage(qimage) + + # Create transparent overlay for annotations + self.annotation_pixmap = QPixmap(self.original_pixmap.size()) + self.annotation_pixmap.fill(Qt.transparent) + + self._apply_zoom() + + except Exception as e: + logger.error(f"Error displaying image: {e}") + raise ImageLoadError(f"Failed to display image: {str(e)}") + + def _apply_zoom(self): + """Apply current zoom level to the displayed image.""" + if self.original_pixmap is None: + return + + scaled_width = int(self.original_pixmap.width() * self.zoom_scale) + scaled_height = int(self.original_pixmap.height() * self.zoom_scale) + + # Scale both image and annotations + scaled_image = self.original_pixmap.scaled( + scaled_width, + scaled_height, + Qt.KeepAspectRatio, + ( + Qt.SmoothTransformation + if self.zoom_scale >= 1.0 + else Qt.FastTransformation + ), + ) + + scaled_annotations = self.annotation_pixmap.scaled( + scaled_width, + scaled_height, + Qt.KeepAspectRatio, + ( + Qt.SmoothTransformation + if self.zoom_scale >= 1.0 + else Qt.FastTransformation + ), + ) + + # Composite image and annotations + combined = QPixmap(scaled_image.size()) + painter = QPainter(combined) + painter.drawPixmap(0, 0, scaled_image) + painter.drawPixmap(0, 0, scaled_annotations) + painter.end() + + self.canvas_label.setPixmap(combined) + self.canvas_label.setScaledContents(False) + self.canvas_label.adjustSize() + + self.zoom_changed.emit(self.zoom_scale) + + def _update_display(self): + """Update display after drawing.""" + self._apply_zoom() + + def set_pen_enabled(self, enabled: bool): + """Enable or disable pen tool.""" + self.pen_enabled = enabled + if enabled: + self.canvas_label.setCursor(Qt.CrossCursor) + else: + self.canvas_label.setCursor(Qt.ArrowCursor) + + def set_pen_color(self, color: QColor): + """Set pen color.""" + self.pen_color = color + + def set_pen_width(self, width: int): + """Set pen width.""" + self.pen_width = max(1, width) + + def get_zoom_percentage(self) -> int: + """Get current zoom level as percentage.""" + return int(self.zoom_scale * 100) + + def zoom_in(self): + """Zoom in on the image.""" + if self.original_pixmap is None: + return + new_scale = self.zoom_scale + self.zoom_step + if new_scale <= self.zoom_max: + self.zoom_scale = new_scale + self._apply_zoom() + + def zoom_out(self): + """Zoom out from the image.""" + if self.original_pixmap is None: + return + new_scale = self.zoom_scale - self.zoom_step + if new_scale >= self.zoom_min: + self.zoom_scale = new_scale + self._apply_zoom() + + def reset_zoom(self): + """Reset zoom to 100%.""" + if self.original_pixmap is None: + return + self.zoom_scale = 1.0 + self._apply_zoom() + + def _canvas_to_image_coords(self, pos: QPoint) -> Optional[Tuple[int, int]]: + """Convert canvas coordinates to image coordinates, accounting for zoom and centering.""" + if self.original_pixmap is None or self.canvas_label.pixmap() is None: + return None + + # Get the displayed pixmap size (after zoom) + displayed_pixmap = self.canvas_label.pixmap() + displayed_width = displayed_pixmap.width() + displayed_height = displayed_pixmap.height() + + # Calculate offset due to label centering (label might be larger than pixmap) + label_width = self.canvas_label.width() + label_height = self.canvas_label.height() + offset_x = max(0, (label_width - displayed_width) // 2) + offset_y = max(0, (label_height - displayed_height) // 2) + + # Adjust position for offset and convert to image coordinates + x = (pos.x() - offset_x) / self.zoom_scale + y = (pos.y() - offset_y) / self.zoom_scale + + # Check bounds + if ( + 0 <= x < self.original_pixmap.width() + and 0 <= y < self.original_pixmap.height() + ): + return (int(x), int(y)) + return None + + def _image_to_normalized_coords(self, x: int, y: int) -> Tuple[float, float]: + """Convert image coordinates to normalized coordinates (0-1).""" + if self.original_pixmap is None: + return (0.0, 0.0) + + norm_x = x / self.original_pixmap.width() + norm_y = y / self.original_pixmap.height() + return (norm_x, norm_y) + + def mousePressEvent(self, event: QMouseEvent): + """Handle mouse press events for drawing.""" + if not self.pen_enabled or self.annotation_pixmap is None: + super().mousePressEvent(event) + return + + if event.button() == Qt.LeftButton: + # Get accurate position using global coordinates + label_pos = self.canvas_label.mapFromGlobal(event.globalPos()) + img_coords = self._canvas_to_image_coords(label_pos) + + if img_coords: + self.is_drawing = True + self.current_stroke = [img_coords] + + def mouseMoveEvent(self, event: QMouseEvent): + """Handle mouse move events for drawing.""" + if ( + not self.is_drawing + or not self.pen_enabled + or self.annotation_pixmap is None + ): + super().mouseMoveEvent(event) + return + + # Get accurate position using global coordinates + label_pos = self.canvas_label.mapFromGlobal(event.globalPos()) + img_coords = self._canvas_to_image_coords(label_pos) + + if img_coords and len(self.current_stroke) > 0: + # Draw line from last point to current point + painter = QPainter(self.annotation_pixmap) + pen = QPen( + self.pen_color, self.pen_width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin + ) + painter.setPen(pen) + + last_point = self.current_stroke[-1] + painter.drawLine(last_point[0], last_point[1], img_coords[0], img_coords[1]) + painter.end() + + self.current_stroke.append(img_coords) + self._update_display() + + def mouseReleaseEvent(self, event: QMouseEvent): + """Handle mouse release events to complete a stroke.""" + if not self.is_drawing or event.button() != Qt.LeftButton: + super().mouseReleaseEvent(event) + return + + self.is_drawing = False + + if len(self.current_stroke) > 1: + # Convert to normalized coordinates and save stroke + normalized_stroke = [ + self._image_to_normalized_coords(x, y) for x, y in self.current_stroke + ] + self.all_strokes.append( + { + "points": normalized_stroke, + "color": self.pen_color.name(), + "alpha": self.pen_color.alpha(), + "width": self.pen_width, + } + ) + + # Emit signal with normalized coordinates + self.annotation_drawn.emit(normalized_stroke) + logger.debug(f"Completed stroke with {len(normalized_stroke)} points") + + self.current_stroke = [] + + def get_all_strokes(self) -> List[dict]: + """Get all drawn strokes with metadata.""" + return self.all_strokes + + def keyPressEvent(self, event: QKeyEvent): + """Handle keyboard events for zooming.""" + if event.key() in (Qt.Key_Plus, Qt.Key_Equal): + self.zoom_in() + event.accept() + elif event.key() == Qt.Key_Minus: + self.zoom_out() + event.accept() + elif event.key() == Qt.Key_0 and event.modifiers() == Qt.ControlModifier: + self.reset_zoom() + event.accept() + else: + super().keyPressEvent(event) + + def eventFilter(self, obj, event: QEvent) -> bool: + """Event filter to capture wheel events for zooming.""" + if event.type() == QEvent.Wheel: + wheel_event = event + if self.original_pixmap is not None: + delta = wheel_event.angleDelta().y() + + if delta > 0: + new_scale = self.zoom_scale + self.zoom_wheel_step + if new_scale <= self.zoom_max: + self.zoom_scale = new_scale + self._apply_zoom() + else: + new_scale = self.zoom_scale - self.zoom_wheel_step + if new_scale >= self.zoom_min: + self.zoom_scale = new_scale + self._apply_zoom() + + return True + + return super().eventFilter(obj, event) diff --git a/src/gui/widgets/annotation_tools_widget.py b/src/gui/widgets/annotation_tools_widget.py new file mode 100644 index 0000000..e59e1ff --- /dev/null +++ b/src/gui/widgets/annotation_tools_widget.py @@ -0,0 +1,352 @@ +""" +Annotation tools widget for controlling annotation parameters. +Includes pen tool, color picker, class selection, and annotation management. +""" + +from PySide6.QtWidgets import ( + QWidget, + QVBoxLayout, + QHBoxLayout, + QLabel, + QGroupBox, + QPushButton, + QComboBox, + QSpinBox, + QColorDialog, + QInputDialog, + QMessageBox, +) +from PySide6.QtGui import QColor, QIcon, QPixmap, QPainter +from PySide6.QtCore import Qt, Signal +from typing import Optional, Dict + +from src.database.db_manager import DatabaseManager +from src.utils.logger import get_logger + +logger = get_logger(__name__) + + +class AnnotationToolsWidget(QWidget): + """ + Widget for annotation tool controls. + + Features: + - Enable/disable pen tool + - Color selection for pen + - Object class selection + - Add new object classes + - Pen width control + - Clear annotations + + Signals: + pen_enabled_changed: Emitted when pen tool is enabled/disabled (bool) + pen_color_changed: Emitted when pen color changes (QColor) + pen_width_changed: Emitted when pen width changes (int) + class_selected: Emitted when object class is selected (dict) + clear_annotations_requested: Emitted when clear button is pressed + """ + + pen_enabled_changed = Signal(bool) + pen_color_changed = Signal(QColor) + pen_width_changed = Signal(int) + class_selected = Signal(dict) + clear_annotations_requested = Signal() + + def __init__(self, db_manager: DatabaseManager, parent=None): + """ + Initialize annotation tools widget. + + Args: + db_manager: Database manager instance + parent: Parent widget + """ + super().__init__(parent) + self.db_manager = db_manager + self.pen_enabled = False + self.current_color = QColor(255, 0, 0, 128) # Red with 50% alpha + self.current_class = None + + self._setup_ui() + self._load_object_classes() + + def _setup_ui(self): + """Setup user interface.""" + layout = QVBoxLayout() + + # Pen Tool Group + pen_group = QGroupBox("Pen Tool") + pen_layout = QVBoxLayout() + + # Enable/Disable pen + button_layout = QHBoxLayout() + self.pen_toggle_btn = QPushButton("Enable Pen") + self.pen_toggle_btn.setCheckable(True) + self.pen_toggle_btn.clicked.connect(self._on_pen_toggle) + button_layout.addWidget(self.pen_toggle_btn) + pen_layout.addLayout(button_layout) + + # Pen width control + width_layout = QHBoxLayout() + width_layout.addWidget(QLabel("Pen Width:")) + self.pen_width_spin = QSpinBox() + self.pen_width_spin.setMinimum(1) + self.pen_width_spin.setMaximum(20) + self.pen_width_spin.setValue(3) + self.pen_width_spin.valueChanged.connect(self._on_pen_width_changed) + width_layout.addWidget(self.pen_width_spin) + width_layout.addStretch() + pen_layout.addLayout(width_layout) + + # Color selection + color_layout = QHBoxLayout() + color_layout.addWidget(QLabel("Color:")) + self.color_btn = QPushButton() + self.color_btn.setFixedSize(40, 30) + self.color_btn.clicked.connect(self._on_color_picker) + self._update_color_button() + color_layout.addWidget(self.color_btn) + color_layout.addStretch() + pen_layout.addLayout(color_layout) + + pen_group.setLayout(pen_layout) + layout.addWidget(pen_group) + + # Object Class Group + class_group = QGroupBox("Object Class") + class_layout = QVBoxLayout() + + # Class selection dropdown + self.class_combo = QComboBox() + self.class_combo.currentIndexChanged.connect(self._on_class_selected) + class_layout.addWidget(self.class_combo) + + # Add class button + class_button_layout = QHBoxLayout() + self.add_class_btn = QPushButton("Add New Class") + self.add_class_btn.clicked.connect(self._on_add_class) + class_button_layout.addWidget(self.add_class_btn) + + self.refresh_classes_btn = QPushButton("Refresh") + self.refresh_classes_btn.clicked.connect(self._load_object_classes) + class_button_layout.addWidget(self.refresh_classes_btn) + class_layout.addLayout(class_button_layout) + + # Selected class info + self.class_info_label = QLabel("No class selected") + self.class_info_label.setWordWrap(True) + self.class_info_label.setStyleSheet( + "QLabel { color: #888; font-style: italic; }" + ) + class_layout.addWidget(self.class_info_label) + + class_group.setLayout(class_layout) + layout.addWidget(class_group) + + # Actions Group + actions_group = QGroupBox("Actions") + actions_layout = QVBoxLayout() + + self.clear_btn = QPushButton("Clear All Annotations") + self.clear_btn.clicked.connect(self._on_clear_annotations) + actions_layout.addWidget(self.clear_btn) + + actions_group.setLayout(actions_layout) + layout.addWidget(actions_group) + + layout.addStretch() + self.setLayout(layout) + + def _update_color_button(self): + """Update the color button appearance with current color.""" + pixmap = QPixmap(40, 30) + pixmap.fill(self.current_color) + + # Add border + painter = QPainter(pixmap) + painter.setPen(Qt.black) + painter.drawRect(0, 0, pixmap.width() - 1, pixmap.height() - 1) + painter.end() + + self.color_btn.setIcon(QIcon(pixmap)) + self.color_btn.setStyleSheet(f"background-color: {self.current_color.name()};") + + def _load_object_classes(self): + """Load object classes from database and populate combo box.""" + try: + classes = self.db_manager.get_object_classes() + + # Clear and repopulate combo box + self.class_combo.clear() + self.class_combo.addItem("-- Select Class --", None) + + for cls in classes: + self.class_combo.addItem(cls["class_name"], cls) + + logger.debug(f"Loaded {len(classes)} object classes") + + except Exception as e: + logger.error(f"Error loading object classes: {e}") + QMessageBox.warning( + self, "Error", f"Failed to load object classes:\n{str(e)}" + ) + + def _on_pen_toggle(self, checked: bool): + """Handle pen tool enable/disable.""" + self.pen_enabled = checked + + if checked: + self.pen_toggle_btn.setText("Disable Pen") + self.pen_toggle_btn.setStyleSheet( + "QPushButton { background-color: #4CAF50; }" + ) + else: + self.pen_toggle_btn.setText("Enable Pen") + self.pen_toggle_btn.setStyleSheet("") + + self.pen_enabled_changed.emit(self.pen_enabled) + logger.debug(f"Pen tool {'enabled' if checked else 'disabled'}") + + def _on_pen_width_changed(self, width: int): + """Handle pen width changes.""" + self.pen_width_changed.emit(width) + logger.debug(f"Pen width changed to {width}") + + def _on_color_picker(self): + """Open color picker dialog with alpha support.""" + color = QColorDialog.getColor( + self.current_color, + self, + "Select Pen Color", + QColorDialog.ShowAlphaChannel, # Enable alpha channel selection + ) + + if color.isValid(): + self.current_color = color + self._update_color_button() + self.pen_color_changed.emit(color) + logger.debug( + f"Pen color changed to {color.name()} with alpha {color.alpha()}" + ) + + def _on_class_selected(self, index: int): + """Handle object class selection.""" + class_data = self.class_combo.currentData() + + if class_data: + self.current_class = class_data + + # Update info label + info_text = ( + f"Class: {class_data['class_name']}\n" f"Color: {class_data['color']}" + ) + if class_data.get("description"): + info_text += f"\nDescription: {class_data['description']}" + + self.class_info_label.setText(info_text) + + # Update pen color to match class color with semi-transparency + class_color = QColor(class_data["color"]) + if class_color.isValid(): + # Add 50% alpha for semi-transparency + class_color.setAlpha(128) + self.current_color = class_color + self._update_color_button() + self.pen_color_changed.emit(class_color) + + self.class_selected.emit(class_data) + logger.debug(f"Selected class: {class_data['class_name']}") + else: + self.current_class = None + self.class_info_label.setText("No class selected") + + def _on_add_class(self): + """Handle adding a new object class.""" + # Get class name + class_name, ok = QInputDialog.getText( + self, "Add Object Class", "Enter class name:" + ) + + if not ok or not class_name.strip(): + return + + class_name = class_name.strip() + + # Check if class already exists + existing = self.db_manager.get_object_class_by_name(class_name) + if existing: + QMessageBox.warning( + self, "Class Exists", f"A class named '{class_name}' already exists." + ) + return + + # Get color + color = QColorDialog.getColor(self.current_color, self, "Select Class Color") + + if not color.isValid(): + return + + # Get optional description + description, ok = QInputDialog.getText( + self, "Class Description", "Enter class description (optional):" + ) + + if not ok: + description = None + + # Add to database + try: + class_id = self.db_manager.add_object_class( + class_name, color.name(), description.strip() if description else None + ) + + logger.info(f"Added new object class: {class_name} (ID: {class_id})") + + # Reload classes and select the new one + self._load_object_classes() + + # Find and select the newly added class + for i in range(self.class_combo.count()): + class_data = self.class_combo.itemData(i) + if class_data and class_data.get("id") == class_id: + self.class_combo.setCurrentIndex(i) + break + + QMessageBox.information( + self, "Success", f"Class '{class_name}' added successfully!" + ) + + except Exception as e: + logger.error(f"Error adding object class: {e}") + QMessageBox.critical( + self, "Error", f"Failed to add object class:\n{str(e)}" + ) + + def _on_clear_annotations(self): + """Handle clear annotations button.""" + reply = QMessageBox.question( + self, + "Clear Annotations", + "Are you sure you want to clear all annotations?", + QMessageBox.Yes | QMessageBox.No, + QMessageBox.No, + ) + + if reply == QMessageBox.Yes: + self.clear_annotations_requested.emit() + logger.debug("Clear annotations requested") + + def get_current_class(self) -> Optional[Dict]: + """Get currently selected object class.""" + return self.current_class + + def get_pen_color(self) -> QColor: + """Get current pen color.""" + return self.current_color + + def get_pen_width(self) -> int: + """Get current pen width.""" + return self.pen_width_spin.value() + + def is_pen_enabled(self) -> bool: + """Check if pen tool is enabled.""" + return self.pen_enabled