Adding pen tool for annotation

This commit is contained in:
2025-12-08 23:15:54 +02:00
parent f84dea0bff
commit fc22479621
6 changed files with 1079 additions and 54 deletions

View File

@@ -30,18 +30,48 @@ class DatabaseManager:
# Create directory if it doesn't exist # Create directory if it doesn't exist
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
# Read schema file and execute
schema_path = Path(__file__).parent / "schema.sql"
with open(schema_path, "r") as f:
schema_sql = f.read()
conn = self.get_connection() conn = self.get_connection()
try: try:
# Check if annotations table needs migration
self._migrate_annotations_table(conn)
# Read schema file and execute
schema_path = Path(__file__).parent / "schema.sql"
with open(schema_path, "r") as f:
schema_sql = f.read()
conn.executescript(schema_sql) conn.executescript(schema_sql)
conn.commit() conn.commit()
finally: finally:
conn.close() conn.close()
def _migrate_annotations_table(self, conn: sqlite3.Connection) -> None:
"""
Migrate annotations table from old schema (class_name) to new schema (class_id).
"""
cursor = conn.cursor()
# Check if annotations table exists
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'"
)
if not cursor.fetchone():
# Table doesn't exist yet, no migration needed
return
# Check if table has old schema (class_name column)
cursor.execute("PRAGMA table_info(annotations)")
columns = {row[1]: row for row in cursor.fetchall()}
if "class_name" in columns and "class_id" not in columns:
# Old schema detected, need to migrate
print("Migrating annotations table to new schema with class_id...")
# Drop old annotations table (assuming no critical data since this is a new feature)
cursor.execute("DROP TABLE IF EXISTS annotations")
conn.commit()
print("Old annotations table dropped, will be recreated with new schema")
def get_connection(self) -> sqlite3.Connection: def get_connection(self) -> sqlite3.Connection:
"""Get database connection with proper settings.""" """Get database connection with proper settings."""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
@@ -593,25 +623,38 @@ class DatabaseManager:
def add_annotation( def add_annotation(
self, self,
image_id: int, image_id: int,
class_name: str, class_id: int,
bbox: Tuple[float, float, float, float], bbox: Tuple[float, float, float, float],
annotator: str, annotator: str,
segmentation_mask: Optional[List[List[float]]] = None, segmentation_mask: Optional[List[List[float]]] = None,
verified: bool = False, verified: bool = False,
) -> int: ) -> int:
"""Add manual annotation.""" """
Add manual annotation.
Args:
image_id: ID of the image
class_id: ID of the object class (foreign key to object_classes)
bbox: Bounding box coordinates (normalized 0-1)
annotator: Name of person/tool creating annotation
segmentation_mask: Polygon coordinates for segmentation
verified: Whether annotation has been verified
Returns:
ID of the inserted annotation
"""
conn = self.get_connection() conn = self.get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
x_min, y_min, x_max, y_max = bbox x_min, y_min, x_max, y_max = bbox
cursor.execute( cursor.execute(
""" """
INSERT INTO annotations (image_id, class_name, x_min, y_min, x_max, y_max, segmentation_mask, annotator, verified) INSERT INTO annotations (image_id, class_id, x_min, y_min, x_max, y_max, segmentation_mask, annotator, verified)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""", """,
( (
image_id, image_id,
class_name, class_id,
x_min, x_min,
y_min, y_min,
x_max, x_max,
@@ -627,15 +670,178 @@ class DatabaseManager:
conn.close() conn.close()
def get_annotations_for_image(self, image_id: int) -> List[Dict]: def get_annotations_for_image(self, image_id: int) -> List[Dict]:
"""Get all annotations for an image.""" """
Get all annotations for an image with class information.
Args:
image_id: ID of the image
Returns:
List of annotation dictionaries with joined class information
"""
conn = self.get_connection() conn = self.get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute("SELECT * FROM annotations WHERE image_id = ?", (image_id,)) cursor.execute(
"""
SELECT
a.*,
c.class_name,
c.color as class_color,
c.description as class_description
FROM annotations a
JOIN object_classes c ON a.class_id = c.id
WHERE a.image_id = ?
ORDER BY a.created_at DESC
""",
(image_id,),
)
annotations = []
for row in cursor.fetchall():
ann = dict(row)
if ann.get("segmentation_mask"):
ann["segmentation_mask"] = json.loads(ann["segmentation_mask"])
annotations.append(ann)
return annotations
finally:
conn.close()
# ==================== Object Class Operations ====================
def get_object_classes(self) -> List[Dict]:
"""
Get all object classes.
Returns:
List of object class dictionaries
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM object_classes ORDER BY class_name")
return [dict(row) for row in cursor.fetchall()] return [dict(row) for row in cursor.fetchall()]
finally: finally:
conn.close() conn.close()
def get_object_class_by_id(self, class_id: int) -> Optional[Dict]:
"""Get object class by ID."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM object_classes WHERE id = ?", (class_id,))
row = cursor.fetchone()
return dict(row) if row else None
finally:
conn.close()
def get_object_class_by_name(self, class_name: str) -> Optional[Dict]:
"""Get object class by name."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"SELECT * FROM object_classes WHERE class_name = ?", (class_name,)
)
row = cursor.fetchone()
return dict(row) if row else None
finally:
conn.close()
def add_object_class(
self, class_name: str, color: str, description: Optional[str] = None
) -> int:
"""
Add a new object class.
Args:
class_name: Name of the object class
color: Hex color code (e.g., '#FF0000')
description: Optional description
Returns:
ID of the inserted object class
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO object_classes (class_name, color, description)
VALUES (?, ?, ?)
""",
(class_name, color, description),
)
conn.commit()
return cursor.lastrowid
except sqlite3.IntegrityError:
# Class already exists
existing = self.get_object_class_by_name(class_name)
return existing["id"] if existing else None
finally:
conn.close()
def update_object_class(
self,
class_id: int,
class_name: Optional[str] = None,
color: Optional[str] = None,
description: Optional[str] = None,
) -> bool:
"""
Update an object class.
Args:
class_id: ID of the class to update
class_name: New class name (optional)
color: New color (optional)
description: New description (optional)
Returns:
True if updated, False otherwise
"""
conn = self.get_connection()
try:
updates = {}
if class_name is not None:
updates["class_name"] = class_name
if color is not None:
updates["color"] = color
if description is not None:
updates["description"] = description
if not updates:
return False
set_clauses = [f"{key} = ?" for key in updates.keys()]
params = list(updates.values()) + [class_id]
query = f"UPDATE object_classes SET {', '.join(set_clauses)} WHERE id = ?"
cursor = conn.cursor()
cursor.execute(query, params)
conn.commit()
return cursor.rowcount > 0
finally:
conn.close()
def delete_object_class(self, class_id: int) -> bool:
"""
Delete an object class.
Args:
class_id: ID of the class to delete
Returns:
True if deleted, False otherwise
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM object_classes WHERE id = ?", (class_id,))
conn.commit()
return cursor.rowcount > 0
finally:
conn.close()
@staticmethod @staticmethod
def calculate_checksum(file_path: str) -> str: def calculate_checksum(file_path: str) -> str:
"""Calculate MD5 checksum of a file.""" """Calculate MD5 checksum of a file."""

View File

@@ -44,11 +44,27 @@ CREATE TABLE IF NOT EXISTS detections (
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
); );
-- Annotations table: stores manual annotations (future feature) -- Object classes table: stores annotation class definitions with colors
CREATE TABLE IF NOT EXISTS object_classes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
class_name TEXT NOT NULL UNIQUE,
color TEXT NOT NULL, -- Hex color code (e.g., '#FF0000')
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
description TEXT
);
-- Insert default object classes
INSERT OR IGNORE INTO object_classes (class_name, color, description) VALUES
('cell', '#FF0000', 'Cell object'),
('nucleus', '#00FF00', 'Cell nucleus'),
('mitochondria', '#0000FF', 'Mitochondria'),
('vesicle', '#FFFF00', 'Vesicle');
-- Annotations table: stores manual annotations
CREATE TABLE IF NOT EXISTS annotations ( CREATE TABLE IF NOT EXISTS annotations (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
image_id INTEGER NOT NULL, image_id INTEGER NOT NULL,
class_name TEXT NOT NULL, class_id INTEGER NOT NULL,
x_min REAL NOT NULL CHECK(x_min >= 0 AND x_min <= 1), x_min REAL NOT NULL CHECK(x_min >= 0 AND x_min <= 1),
y_min REAL NOT NULL CHECK(y_min >= 0 AND y_min <= 1), y_min REAL NOT NULL CHECK(y_min >= 0 AND y_min <= 1),
x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1), x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1),
@@ -57,7 +73,8 @@ CREATE TABLE IF NOT EXISTS annotations (
annotator TEXT, annotator TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
verified BOOLEAN DEFAULT 0, verified BOOLEAN DEFAULT 0,
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
FOREIGN KEY (class_id) REFERENCES object_classes (id) ON DELETE CASCADE
); );
-- Create indexes for performance optimization -- Create indexes for performance optimization
@@ -69,4 +86,6 @@ CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence);
CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path); CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path);
CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at); CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at);
CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id); CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id);
CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at); CREATE INDEX IF NOT EXISTS idx_annotations_class_id ON annotations(class_id);
CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at);
CREATE INDEX IF NOT EXISTS idx_object_classes_class_name ON object_classes(class_name);

View File

@@ -1,6 +1,6 @@
""" """
Annotation tab for the microscopy object detection application. Annotation tab for the microscopy object detection application.
Future feature for manual annotation. Manual annotation with pen tool and object class management.
""" """
from PySide6.QtWidgets import ( from PySide6.QtWidgets import (
@@ -21,13 +21,13 @@ from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager from src.utils.config_manager import ConfigManager
from src.utils.image import Image, ImageLoadError from src.utils.image import Image, ImageLoadError
from src.utils.logger import get_logger from src.utils.logger import get_logger
from src.gui.widgets import ImageDisplayWidget from src.gui.widgets import AnnotationCanvasWidget, AnnotationToolsWidget
logger = get_logger(__name__) logger = get_logger(__name__)
class AnnotationTab(QWidget): class AnnotationTab(QWidget):
"""Annotation tab placeholder (future feature).""" """Annotation tab for manual image annotation."""
def __init__( def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
@@ -37,6 +37,7 @@ class AnnotationTab(QWidget):
self.config_manager = config_manager self.config_manager = config_manager
self.current_image = None self.current_image = None
self.current_image_path = None self.current_image_path = None
self.current_image_id = None
self._setup_ui() self._setup_ui()
@@ -52,49 +53,52 @@ class AnnotationTab(QWidget):
self.left_splitter = QSplitter(Qt.Vertical) self.left_splitter = QSplitter(Qt.Vertical)
self.left_splitter.setHandleWidth(10) self.left_splitter.setHandleWidth(10)
# Image display section # Annotation canvas section
display_group = QGroupBox("Image Display") canvas_group = QGroupBox("Annotation Canvas")
display_layout = QVBoxLayout() canvas_layout = QVBoxLayout()
# Use the reusable ImageDisplayWidget # Use the AnnotationCanvasWidget
self.image_display_widget = ImageDisplayWidget() self.annotation_canvas = AnnotationCanvasWidget()
self.image_display_widget.zoom_changed.connect(self._on_zoom_changed) self.annotation_canvas.zoom_changed.connect(self._on_zoom_changed)
display_layout.addWidget(self.image_display_widget) self.annotation_canvas.annotation_drawn.connect(self._on_annotation_drawn)
canvas_layout.addWidget(self.annotation_canvas)
display_group.setLayout(display_layout) canvas_group.setLayout(canvas_layout)
self.left_splitter.addWidget(display_group) self.left_splitter.addWidget(canvas_group)
# Zoom controls info # Controls info
zoom_info = QLabel("Zoom: Mouse wheel or +/- keys to zoom in/out") controls_info = QLabel(
zoom_info.setStyleSheet("QLabel { color: #888; font-style: italic; }") "Zoom: Mouse wheel or +/- keys | Drawing: Enable pen and drag mouse"
self.left_splitter.addWidget(zoom_info) )
controls_info.setStyleSheet("QLabel { color: #888; font-style: italic; }")
self.left_splitter.addWidget(controls_info)
# } # }
# { Right splitter for annotation tools and controls # { Right splitter for annotation tools and controls
self.right_splitter = QSplitter(Qt.Vertical) self.right_splitter = QSplitter(Qt.Vertical)
self.right_splitter.setHandleWidth(10) self.right_splitter.setHandleWidth(10)
# Annotation tools section
self.annotation_tools = AnnotationToolsWidget(self.db_manager)
self.annotation_tools.pen_enabled_changed.connect(
self.annotation_canvas.set_pen_enabled
)
self.annotation_tools.pen_color_changed.connect(
self.annotation_canvas.set_pen_color
)
self.annotation_tools.pen_width_changed.connect(
self.annotation_canvas.set_pen_width
)
self.annotation_tools.class_selected.connect(self._on_class_selected)
self.annotation_tools.clear_annotations_requested.connect(
self._on_clear_annotations
)
self.right_splitter.addWidget(self.annotation_tools)
# Image loading section # Image loading section
load_group = QGroupBox("Image Loading") load_group = QGroupBox("Image Loading")
load_layout = QVBoxLayout() load_layout = QVBoxLayout()
# Future features info
info_group = QGroupBox("Annotation Tool (Future Feature)")
info_layout = QVBoxLayout()
info_label = QLabel(
"Full annotation functionality will be implemented in future version.\n\n"
"Planned Features:\n"
"- Drawing tools for bounding boxes\n"
"- Class label assignment\n"
"- Export annotations to YOLO format\n"
"- Annotation verification"
)
info_label.setWordWrap(True)
info_layout.addWidget(info_label)
info_group.setLayout(info_layout)
self.right_splitter.addWidget(info_group)
# Load image button # Load image button
button_layout = QHBoxLayout() button_layout = QHBoxLayout()
self.load_image_btn = QPushButton("Load Image") self.load_image_btn = QPushButton("Load Image")
@@ -158,13 +162,22 @@ class AnnotationTab(QWidget):
"annotation_tab/last_directory", str(Path(file_path).parent) "annotation_tab/last_directory", str(Path(file_path).parent)
) )
# Display image using the ImageDisplayWidget # Get or create image in database
self.image_display_widget.load_image(self.current_image) relative_path = str(Path(file_path).name) # Simplified for now
self.current_image_id = self.db_manager.get_or_create_image(
relative_path,
Path(file_path).name,
self.current_image.width,
self.current_image.height,
)
# Display image using the AnnotationCanvasWidget
self.annotation_canvas.load_image(self.current_image)
# Update info label # Update info label
self._update_image_info() self._update_image_info()
logger.info(f"Loaded image: {file_path}") logger.info(f"Loaded image: {file_path} (DB ID: {self.current_image_id})")
except ImageLoadError as e: except ImageLoadError as e:
logger.error(f"Failed to load image: {e}") logger.error(f"Failed to load image: {e}")
@@ -181,7 +194,7 @@ class AnnotationTab(QWidget):
self.image_info_label.setText("No image loaded") self.image_info_label.setText("No image loaded")
return return
zoom_percentage = self.image_display_widget.get_zoom_percentage() zoom_percentage = self.annotation_canvas.get_zoom_percentage()
info_text = ( info_text = (
f"File: {Path(self.current_image_path).name}\n" f"File: {Path(self.current_image_path).name}\n"
f"Size: {self.current_image.width}x{self.current_image.height} pixels\n" f"Size: {self.current_image.width}x{self.current_image.height} pixels\n"
@@ -194,9 +207,36 @@ class AnnotationTab(QWidget):
self.image_info_label.setText(info_text) self.image_info_label.setText(info_text)
def _on_zoom_changed(self, zoom_scale: float): def _on_zoom_changed(self, zoom_scale: float):
"""Handle zoom level changes from the image display widget.""" """Handle zoom level changes from the annotation canvas."""
self._update_image_info() self._update_image_info()
def _on_annotation_drawn(self, points: list):
"""Handle when an annotation stroke is drawn."""
current_class = self.annotation_tools.get_current_class()
if not current_class:
logger.warning("Annotation drawn but no object class selected")
QMessageBox.warning(
self,
"No Class Selected",
"Please select an object class before drawing annotations.",
)
return
logger.info(
f"Annotation drawn with {len(points)} points for class: {current_class['class_name']}"
)
# Future: Save annotation to database or export
def _on_class_selected(self, class_data: dict):
"""Handle when an object class is selected."""
logger.debug(f"Object class selected: {class_data['class_name']}")
def _on_clear_annotations(self):
"""Handle clearing all annotations."""
self.annotation_canvas.clear_annotations()
logger.info("Cleared all annotations")
def _restore_state(self): def _restore_state(self):
"""Restore splitter positions from settings.""" """Restore splitter positions from settings."""
settings = QSettings("microscopy_app", "object_detection") settings = QSettings("microscopy_app", "object_detection")

View File

@@ -1,5 +1,7 @@
"""GUI widgets for the microscopy object detection application.""" """GUI widgets for the microscopy object detection application."""
from src.gui.widgets.image_display_widget import ImageDisplayWidget from src.gui.widgets.image_display_widget import ImageDisplayWidget
from src.gui.widgets.annotation_canvas_widget import AnnotationCanvasWidget
from src.gui.widgets.annotation_tools_widget import AnnotationToolsWidget
__all__ = ["ImageDisplayWidget"] __all__ = ["ImageDisplayWidget", "AnnotationCanvasWidget", "AnnotationToolsWidget"]

View File

@@ -0,0 +1,406 @@
"""
Annotation canvas widget for drawing annotations on images.
Supports pen tool with color selection for manual annotation.
"""
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea
from PySide6.QtGui import (
QPixmap,
QImage,
QPainter,
QPen,
QColor,
QKeyEvent,
QMouseEvent,
QPaintEvent,
)
from PySide6.QtCore import Qt, QEvent, Signal, QPoint
from typing import List, Optional, Tuple
import numpy as np
from src.utils.image import Image, ImageLoadError
from src.utils.logger import get_logger
logger = get_logger(__name__)
class AnnotationCanvasWidget(QWidget):
"""
Widget for displaying images and drawing annotations with pen tool.
Features:
- Display images with zoom functionality
- Pen tool for drawing annotations
- Configurable pen color and width
- Mouse-based drawing interface
- Zoom in/out with mouse wheel and keyboard
Signals:
zoom_changed: Emitted when zoom level changes (float zoom_scale)
annotation_drawn: Emitted when a new stroke is completed (list of points)
"""
zoom_changed = Signal(float)
annotation_drawn = Signal(list) # List of (x, y) points in normalized coordinates
def __init__(self, parent=None):
"""Initialize the annotation canvas widget."""
super().__init__(parent)
self.current_image = None
self.original_pixmap = None
self.annotation_pixmap = None # Overlay for annotations
self.zoom_scale = 1.0
self.zoom_min = 0.1
self.zoom_max = 10.0
self.zoom_step = 0.1
self.zoom_wheel_step = 0.15
# Drawing state
self.is_drawing = False
self.pen_enabled = False
self.pen_color = QColor(255, 0, 0, 128) # Default red with 50% alpha
self.pen_width = 3
self.current_stroke = [] # Points in current stroke
self.all_strokes = [] # All completed strokes
self._setup_ui()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
layout.setContentsMargins(0, 0, 0, 0)
# Scroll area for canvas
self.scroll_area = QScrollArea()
self.scroll_area.setWidgetResizable(True)
self.scroll_area.setMinimumHeight(400)
self.canvas_label = QLabel("No image loaded")
self.canvas_label.setAlignment(Qt.AlignCenter)
self.canvas_label.setStyleSheet(
"QLabel { background-color: #2b2b2b; color: #888; }"
)
self.canvas_label.setScaledContents(False)
self.canvas_label.setMouseTracking(True)
self.scroll_area.setWidget(self.canvas_label)
self.scroll_area.viewport().installEventFilter(self)
layout.addWidget(self.scroll_area)
self.setLayout(layout)
self.setFocusPolicy(Qt.StrongFocus)
def load_image(self, image: Image):
"""
Load and display an image.
Args:
image: Image object to display
"""
self.current_image = image
self.zoom_scale = 1.0
self.clear_annotations()
self._display_image()
logger.debug(
f"Loaded image into annotation canvas: {image.width}x{image.height}"
)
def clear(self):
"""Clear the displayed image and all annotations."""
self.current_image = None
self.original_pixmap = None
self.annotation_pixmap = None
self.zoom_scale = 1.0
self.clear_annotations()
self.canvas_label.setText("No image loaded")
self.canvas_label.setPixmap(QPixmap())
def clear_annotations(self):
"""Clear all drawn annotations."""
self.all_strokes = []
self.current_stroke = []
self.is_drawing = False
if self.annotation_pixmap:
self.annotation_pixmap.fill(Qt.transparent)
self._update_display()
def _display_image(self):
"""Display the current image in the canvas."""
if self.current_image is None:
return
try:
# Get RGB image data
if self.current_image.channels == 3:
image_data = self.current_image.get_rgb()
height, width, channels = image_data.shape
else:
image_data = self.current_image.get_grayscale()
height, width = image_data.shape
image_data = np.ascontiguousarray(image_data)
bytes_per_line = image_data.strides[0]
qimage = QImage(
image_data.data,
width,
height,
bytes_per_line,
self.current_image.qtimage_format,
)
self.original_pixmap = QPixmap.fromImage(qimage)
# Create transparent overlay for annotations
self.annotation_pixmap = QPixmap(self.original_pixmap.size())
self.annotation_pixmap.fill(Qt.transparent)
self._apply_zoom()
except Exception as e:
logger.error(f"Error displaying image: {e}")
raise ImageLoadError(f"Failed to display image: {str(e)}")
def _apply_zoom(self):
"""Apply current zoom level to the displayed image."""
if self.original_pixmap is None:
return
scaled_width = int(self.original_pixmap.width() * self.zoom_scale)
scaled_height = int(self.original_pixmap.height() * self.zoom_scale)
# Scale both image and annotations
scaled_image = self.original_pixmap.scaled(
scaled_width,
scaled_height,
Qt.KeepAspectRatio,
(
Qt.SmoothTransformation
if self.zoom_scale >= 1.0
else Qt.FastTransformation
),
)
scaled_annotations = self.annotation_pixmap.scaled(
scaled_width,
scaled_height,
Qt.KeepAspectRatio,
(
Qt.SmoothTransformation
if self.zoom_scale >= 1.0
else Qt.FastTransformation
),
)
# Composite image and annotations
combined = QPixmap(scaled_image.size())
painter = QPainter(combined)
painter.drawPixmap(0, 0, scaled_image)
painter.drawPixmap(0, 0, scaled_annotations)
painter.end()
self.canvas_label.setPixmap(combined)
self.canvas_label.setScaledContents(False)
self.canvas_label.adjustSize()
self.zoom_changed.emit(self.zoom_scale)
def _update_display(self):
"""Update display after drawing."""
self._apply_zoom()
def set_pen_enabled(self, enabled: bool):
"""Enable or disable pen tool."""
self.pen_enabled = enabled
if enabled:
self.canvas_label.setCursor(Qt.CrossCursor)
else:
self.canvas_label.setCursor(Qt.ArrowCursor)
def set_pen_color(self, color: QColor):
"""Set pen color."""
self.pen_color = color
def set_pen_width(self, width: int):
"""Set pen width."""
self.pen_width = max(1, width)
def get_zoom_percentage(self) -> int:
"""Get current zoom level as percentage."""
return int(self.zoom_scale * 100)
def zoom_in(self):
"""Zoom in on the image."""
if self.original_pixmap is None:
return
new_scale = self.zoom_scale + self.zoom_step
if new_scale <= self.zoom_max:
self.zoom_scale = new_scale
self._apply_zoom()
def zoom_out(self):
"""Zoom out from the image."""
if self.original_pixmap is None:
return
new_scale = self.zoom_scale - self.zoom_step
if new_scale >= self.zoom_min:
self.zoom_scale = new_scale
self._apply_zoom()
def reset_zoom(self):
"""Reset zoom to 100%."""
if self.original_pixmap is None:
return
self.zoom_scale = 1.0
self._apply_zoom()
def _canvas_to_image_coords(self, pos: QPoint) -> Optional[Tuple[int, int]]:
"""Convert canvas coordinates to image coordinates, accounting for zoom and centering."""
if self.original_pixmap is None or self.canvas_label.pixmap() is None:
return None
# Get the displayed pixmap size (after zoom)
displayed_pixmap = self.canvas_label.pixmap()
displayed_width = displayed_pixmap.width()
displayed_height = displayed_pixmap.height()
# Calculate offset due to label centering (label might be larger than pixmap)
label_width = self.canvas_label.width()
label_height = self.canvas_label.height()
offset_x = max(0, (label_width - displayed_width) // 2)
offset_y = max(0, (label_height - displayed_height) // 2)
# Adjust position for offset and convert to image coordinates
x = (pos.x() - offset_x) / self.zoom_scale
y = (pos.y() - offset_y) / self.zoom_scale
# Check bounds
if (
0 <= x < self.original_pixmap.width()
and 0 <= y < self.original_pixmap.height()
):
return (int(x), int(y))
return None
def _image_to_normalized_coords(self, x: int, y: int) -> Tuple[float, float]:
"""Convert image coordinates to normalized coordinates (0-1)."""
if self.original_pixmap is None:
return (0.0, 0.0)
norm_x = x / self.original_pixmap.width()
norm_y = y / self.original_pixmap.height()
return (norm_x, norm_y)
def mousePressEvent(self, event: QMouseEvent):
"""Handle mouse press events for drawing."""
if not self.pen_enabled or self.annotation_pixmap is None:
super().mousePressEvent(event)
return
if event.button() == Qt.LeftButton:
# Get accurate position using global coordinates
label_pos = self.canvas_label.mapFromGlobal(event.globalPos())
img_coords = self._canvas_to_image_coords(label_pos)
if img_coords:
self.is_drawing = True
self.current_stroke = [img_coords]
def mouseMoveEvent(self, event: QMouseEvent):
"""Handle mouse move events for drawing."""
if (
not self.is_drawing
or not self.pen_enabled
or self.annotation_pixmap is None
):
super().mouseMoveEvent(event)
return
# Get accurate position using global coordinates
label_pos = self.canvas_label.mapFromGlobal(event.globalPos())
img_coords = self._canvas_to_image_coords(label_pos)
if img_coords and len(self.current_stroke) > 0:
# Draw line from last point to current point
painter = QPainter(self.annotation_pixmap)
pen = QPen(
self.pen_color, self.pen_width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin
)
painter.setPen(pen)
last_point = self.current_stroke[-1]
painter.drawLine(last_point[0], last_point[1], img_coords[0], img_coords[1])
painter.end()
self.current_stroke.append(img_coords)
self._update_display()
def mouseReleaseEvent(self, event: QMouseEvent):
"""Handle mouse release events to complete a stroke."""
if not self.is_drawing or event.button() != Qt.LeftButton:
super().mouseReleaseEvent(event)
return
self.is_drawing = False
if len(self.current_stroke) > 1:
# Convert to normalized coordinates and save stroke
normalized_stroke = [
self._image_to_normalized_coords(x, y) for x, y in self.current_stroke
]
self.all_strokes.append(
{
"points": normalized_stroke,
"color": self.pen_color.name(),
"alpha": self.pen_color.alpha(),
"width": self.pen_width,
}
)
# Emit signal with normalized coordinates
self.annotation_drawn.emit(normalized_stroke)
logger.debug(f"Completed stroke with {len(normalized_stroke)} points")
self.current_stroke = []
def get_all_strokes(self) -> List[dict]:
"""Get all drawn strokes with metadata."""
return self.all_strokes
def keyPressEvent(self, event: QKeyEvent):
"""Handle keyboard events for zooming."""
if event.key() in (Qt.Key_Plus, Qt.Key_Equal):
self.zoom_in()
event.accept()
elif event.key() == Qt.Key_Minus:
self.zoom_out()
event.accept()
elif event.key() == Qt.Key_0 and event.modifiers() == Qt.ControlModifier:
self.reset_zoom()
event.accept()
else:
super().keyPressEvent(event)
def eventFilter(self, obj, event: QEvent) -> bool:
"""Event filter to capture wheel events for zooming."""
if event.type() == QEvent.Wheel:
wheel_event = event
if self.original_pixmap is not None:
delta = wheel_event.angleDelta().y()
if delta > 0:
new_scale = self.zoom_scale + self.zoom_wheel_step
if new_scale <= self.zoom_max:
self.zoom_scale = new_scale
self._apply_zoom()
else:
new_scale = self.zoom_scale - self.zoom_wheel_step
if new_scale >= self.zoom_min:
self.zoom_scale = new_scale
self._apply_zoom()
return True
return super().eventFilter(obj, event)

View File

@@ -0,0 +1,352 @@
"""
Annotation tools widget for controlling annotation parameters.
Includes pen tool, color picker, class selection, and annotation management.
"""
from PySide6.QtWidgets import (
QWidget,
QVBoxLayout,
QHBoxLayout,
QLabel,
QGroupBox,
QPushButton,
QComboBox,
QSpinBox,
QColorDialog,
QInputDialog,
QMessageBox,
)
from PySide6.QtGui import QColor, QIcon, QPixmap, QPainter
from PySide6.QtCore import Qt, Signal
from typing import Optional, Dict
from src.database.db_manager import DatabaseManager
from src.utils.logger import get_logger
logger = get_logger(__name__)
class AnnotationToolsWidget(QWidget):
"""
Widget for annotation tool controls.
Features:
- Enable/disable pen tool
- Color selection for pen
- Object class selection
- Add new object classes
- Pen width control
- Clear annotations
Signals:
pen_enabled_changed: Emitted when pen tool is enabled/disabled (bool)
pen_color_changed: Emitted when pen color changes (QColor)
pen_width_changed: Emitted when pen width changes (int)
class_selected: Emitted when object class is selected (dict)
clear_annotations_requested: Emitted when clear button is pressed
"""
pen_enabled_changed = Signal(bool)
pen_color_changed = Signal(QColor)
pen_width_changed = Signal(int)
class_selected = Signal(dict)
clear_annotations_requested = Signal()
def __init__(self, db_manager: DatabaseManager, parent=None):
"""
Initialize annotation tools widget.
Args:
db_manager: Database manager instance
parent: Parent widget
"""
super().__init__(parent)
self.db_manager = db_manager
self.pen_enabled = False
self.current_color = QColor(255, 0, 0, 128) # Red with 50% alpha
self.current_class = None
self._setup_ui()
self._load_object_classes()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
# Pen Tool Group
pen_group = QGroupBox("Pen Tool")
pen_layout = QVBoxLayout()
# Enable/Disable pen
button_layout = QHBoxLayout()
self.pen_toggle_btn = QPushButton("Enable Pen")
self.pen_toggle_btn.setCheckable(True)
self.pen_toggle_btn.clicked.connect(self._on_pen_toggle)
button_layout.addWidget(self.pen_toggle_btn)
pen_layout.addLayout(button_layout)
# Pen width control
width_layout = QHBoxLayout()
width_layout.addWidget(QLabel("Pen Width:"))
self.pen_width_spin = QSpinBox()
self.pen_width_spin.setMinimum(1)
self.pen_width_spin.setMaximum(20)
self.pen_width_spin.setValue(3)
self.pen_width_spin.valueChanged.connect(self._on_pen_width_changed)
width_layout.addWidget(self.pen_width_spin)
width_layout.addStretch()
pen_layout.addLayout(width_layout)
# Color selection
color_layout = QHBoxLayout()
color_layout.addWidget(QLabel("Color:"))
self.color_btn = QPushButton()
self.color_btn.setFixedSize(40, 30)
self.color_btn.clicked.connect(self._on_color_picker)
self._update_color_button()
color_layout.addWidget(self.color_btn)
color_layout.addStretch()
pen_layout.addLayout(color_layout)
pen_group.setLayout(pen_layout)
layout.addWidget(pen_group)
# Object Class Group
class_group = QGroupBox("Object Class")
class_layout = QVBoxLayout()
# Class selection dropdown
self.class_combo = QComboBox()
self.class_combo.currentIndexChanged.connect(self._on_class_selected)
class_layout.addWidget(self.class_combo)
# Add class button
class_button_layout = QHBoxLayout()
self.add_class_btn = QPushButton("Add New Class")
self.add_class_btn.clicked.connect(self._on_add_class)
class_button_layout.addWidget(self.add_class_btn)
self.refresh_classes_btn = QPushButton("Refresh")
self.refresh_classes_btn.clicked.connect(self._load_object_classes)
class_button_layout.addWidget(self.refresh_classes_btn)
class_layout.addLayout(class_button_layout)
# Selected class info
self.class_info_label = QLabel("No class selected")
self.class_info_label.setWordWrap(True)
self.class_info_label.setStyleSheet(
"QLabel { color: #888; font-style: italic; }"
)
class_layout.addWidget(self.class_info_label)
class_group.setLayout(class_layout)
layout.addWidget(class_group)
# Actions Group
actions_group = QGroupBox("Actions")
actions_layout = QVBoxLayout()
self.clear_btn = QPushButton("Clear All Annotations")
self.clear_btn.clicked.connect(self._on_clear_annotations)
actions_layout.addWidget(self.clear_btn)
actions_group.setLayout(actions_layout)
layout.addWidget(actions_group)
layout.addStretch()
self.setLayout(layout)
def _update_color_button(self):
"""Update the color button appearance with current color."""
pixmap = QPixmap(40, 30)
pixmap.fill(self.current_color)
# Add border
painter = QPainter(pixmap)
painter.setPen(Qt.black)
painter.drawRect(0, 0, pixmap.width() - 1, pixmap.height() - 1)
painter.end()
self.color_btn.setIcon(QIcon(pixmap))
self.color_btn.setStyleSheet(f"background-color: {self.current_color.name()};")
def _load_object_classes(self):
"""Load object classes from database and populate combo box."""
try:
classes = self.db_manager.get_object_classes()
# Clear and repopulate combo box
self.class_combo.clear()
self.class_combo.addItem("-- Select Class --", None)
for cls in classes:
self.class_combo.addItem(cls["class_name"], cls)
logger.debug(f"Loaded {len(classes)} object classes")
except Exception as e:
logger.error(f"Error loading object classes: {e}")
QMessageBox.warning(
self, "Error", f"Failed to load object classes:\n{str(e)}"
)
def _on_pen_toggle(self, checked: bool):
"""Handle pen tool enable/disable."""
self.pen_enabled = checked
if checked:
self.pen_toggle_btn.setText("Disable Pen")
self.pen_toggle_btn.setStyleSheet(
"QPushButton { background-color: #4CAF50; }"
)
else:
self.pen_toggle_btn.setText("Enable Pen")
self.pen_toggle_btn.setStyleSheet("")
self.pen_enabled_changed.emit(self.pen_enabled)
logger.debug(f"Pen tool {'enabled' if checked else 'disabled'}")
def _on_pen_width_changed(self, width: int):
"""Handle pen width changes."""
self.pen_width_changed.emit(width)
logger.debug(f"Pen width changed to {width}")
def _on_color_picker(self):
"""Open color picker dialog with alpha support."""
color = QColorDialog.getColor(
self.current_color,
self,
"Select Pen Color",
QColorDialog.ShowAlphaChannel, # Enable alpha channel selection
)
if color.isValid():
self.current_color = color
self._update_color_button()
self.pen_color_changed.emit(color)
logger.debug(
f"Pen color changed to {color.name()} with alpha {color.alpha()}"
)
def _on_class_selected(self, index: int):
"""Handle object class selection."""
class_data = self.class_combo.currentData()
if class_data:
self.current_class = class_data
# Update info label
info_text = (
f"Class: {class_data['class_name']}\n" f"Color: {class_data['color']}"
)
if class_data.get("description"):
info_text += f"\nDescription: {class_data['description']}"
self.class_info_label.setText(info_text)
# Update pen color to match class color with semi-transparency
class_color = QColor(class_data["color"])
if class_color.isValid():
# Add 50% alpha for semi-transparency
class_color.setAlpha(128)
self.current_color = class_color
self._update_color_button()
self.pen_color_changed.emit(class_color)
self.class_selected.emit(class_data)
logger.debug(f"Selected class: {class_data['class_name']}")
else:
self.current_class = None
self.class_info_label.setText("No class selected")
def _on_add_class(self):
"""Handle adding a new object class."""
# Get class name
class_name, ok = QInputDialog.getText(
self, "Add Object Class", "Enter class name:"
)
if not ok or not class_name.strip():
return
class_name = class_name.strip()
# Check if class already exists
existing = self.db_manager.get_object_class_by_name(class_name)
if existing:
QMessageBox.warning(
self, "Class Exists", f"A class named '{class_name}' already exists."
)
return
# Get color
color = QColorDialog.getColor(self.current_color, self, "Select Class Color")
if not color.isValid():
return
# Get optional description
description, ok = QInputDialog.getText(
self, "Class Description", "Enter class description (optional):"
)
if not ok:
description = None
# Add to database
try:
class_id = self.db_manager.add_object_class(
class_name, color.name(), description.strip() if description else None
)
logger.info(f"Added new object class: {class_name} (ID: {class_id})")
# Reload classes and select the new one
self._load_object_classes()
# Find and select the newly added class
for i in range(self.class_combo.count()):
class_data = self.class_combo.itemData(i)
if class_data and class_data.get("id") == class_id:
self.class_combo.setCurrentIndex(i)
break
QMessageBox.information(
self, "Success", f"Class '{class_name}' added successfully!"
)
except Exception as e:
logger.error(f"Error adding object class: {e}")
QMessageBox.critical(
self, "Error", f"Failed to add object class:\n{str(e)}"
)
def _on_clear_annotations(self):
"""Handle clear annotations button."""
reply = QMessageBox.question(
self,
"Clear Annotations",
"Are you sure you want to clear all annotations?",
QMessageBox.Yes | QMessageBox.No,
QMessageBox.No,
)
if reply == QMessageBox.Yes:
self.clear_annotations_requested.emit()
logger.debug("Clear annotations requested")
def get_current_class(self) -> Optional[Dict]:
"""Get currently selected object class."""
return self.current_class
def get_pen_color(self) -> QColor:
"""Get current pen color."""
return self.current_color
def get_pen_width(self) -> int:
"""Get current pen width."""
return self.pen_width_spin.value()
def is_pen_enabled(self) -> bool:
"""Check if pen tool is enabled."""
return self.pen_enabled