Adding pen tool for annotation

This commit is contained in:
2025-12-08 23:15:54 +02:00
parent f84dea0bff
commit fc22479621
6 changed files with 1079 additions and 54 deletions

View File

@@ -1,5 +1,7 @@
"""GUI widgets for the microscopy object detection application."""
from src.gui.widgets.image_display_widget import ImageDisplayWidget
from src.gui.widgets.annotation_canvas_widget import AnnotationCanvasWidget
from src.gui.widgets.annotation_tools_widget import AnnotationToolsWidget
__all__ = ["ImageDisplayWidget"]
__all__ = ["ImageDisplayWidget", "AnnotationCanvasWidget", "AnnotationToolsWidget"]

View File

@@ -0,0 +1,406 @@
"""
Annotation canvas widget for drawing annotations on images.
Supports pen tool with color selection for manual annotation.
"""
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea
from PySide6.QtGui import (
QPixmap,
QImage,
QPainter,
QPen,
QColor,
QKeyEvent,
QMouseEvent,
QPaintEvent,
)
from PySide6.QtCore import Qt, QEvent, Signal, QPoint
from typing import List, Optional, Tuple
import numpy as np
from src.utils.image import Image, ImageLoadError
from src.utils.logger import get_logger
logger = get_logger(__name__)
class AnnotationCanvasWidget(QWidget):
"""
Widget for displaying images and drawing annotations with pen tool.
Features:
- Display images with zoom functionality
- Pen tool for drawing annotations
- Configurable pen color and width
- Mouse-based drawing interface
- Zoom in/out with mouse wheel and keyboard
Signals:
zoom_changed: Emitted when zoom level changes (float zoom_scale)
annotation_drawn: Emitted when a new stroke is completed (list of points)
"""
zoom_changed = Signal(float)
annotation_drawn = Signal(list) # List of (x, y) points in normalized coordinates
def __init__(self, parent=None):
"""Initialize the annotation canvas widget."""
super().__init__(parent)
self.current_image = None
self.original_pixmap = None
self.annotation_pixmap = None # Overlay for annotations
self.zoom_scale = 1.0
self.zoom_min = 0.1
self.zoom_max = 10.0
self.zoom_step = 0.1
self.zoom_wheel_step = 0.15
# Drawing state
self.is_drawing = False
self.pen_enabled = False
self.pen_color = QColor(255, 0, 0, 128) # Default red with 50% alpha
self.pen_width = 3
self.current_stroke = [] # Points in current stroke
self.all_strokes = [] # All completed strokes
self._setup_ui()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
layout.setContentsMargins(0, 0, 0, 0)
# Scroll area for canvas
self.scroll_area = QScrollArea()
self.scroll_area.setWidgetResizable(True)
self.scroll_area.setMinimumHeight(400)
self.canvas_label = QLabel("No image loaded")
self.canvas_label.setAlignment(Qt.AlignCenter)
self.canvas_label.setStyleSheet(
"QLabel { background-color: #2b2b2b; color: #888; }"
)
self.canvas_label.setScaledContents(False)
self.canvas_label.setMouseTracking(True)
self.scroll_area.setWidget(self.canvas_label)
self.scroll_area.viewport().installEventFilter(self)
layout.addWidget(self.scroll_area)
self.setLayout(layout)
self.setFocusPolicy(Qt.StrongFocus)
def load_image(self, image: Image):
"""
Load and display an image.
Args:
image: Image object to display
"""
self.current_image = image
self.zoom_scale = 1.0
self.clear_annotations()
self._display_image()
logger.debug(
f"Loaded image into annotation canvas: {image.width}x{image.height}"
)
def clear(self):
"""Clear the displayed image and all annotations."""
self.current_image = None
self.original_pixmap = None
self.annotation_pixmap = None
self.zoom_scale = 1.0
self.clear_annotations()
self.canvas_label.setText("No image loaded")
self.canvas_label.setPixmap(QPixmap())
def clear_annotations(self):
"""Clear all drawn annotations."""
self.all_strokes = []
self.current_stroke = []
self.is_drawing = False
if self.annotation_pixmap:
self.annotation_pixmap.fill(Qt.transparent)
self._update_display()
def _display_image(self):
"""Display the current image in the canvas."""
if self.current_image is None:
return
try:
# Get RGB image data
if self.current_image.channels == 3:
image_data = self.current_image.get_rgb()
height, width, channels = image_data.shape
else:
image_data = self.current_image.get_grayscale()
height, width = image_data.shape
image_data = np.ascontiguousarray(image_data)
bytes_per_line = image_data.strides[0]
qimage = QImage(
image_data.data,
width,
height,
bytes_per_line,
self.current_image.qtimage_format,
)
self.original_pixmap = QPixmap.fromImage(qimage)
# Create transparent overlay for annotations
self.annotation_pixmap = QPixmap(self.original_pixmap.size())
self.annotation_pixmap.fill(Qt.transparent)
self._apply_zoom()
except Exception as e:
logger.error(f"Error displaying image: {e}")
raise ImageLoadError(f"Failed to display image: {str(e)}")
def _apply_zoom(self):
"""Apply current zoom level to the displayed image."""
if self.original_pixmap is None:
return
scaled_width = int(self.original_pixmap.width() * self.zoom_scale)
scaled_height = int(self.original_pixmap.height() * self.zoom_scale)
# Scale both image and annotations
scaled_image = self.original_pixmap.scaled(
scaled_width,
scaled_height,
Qt.KeepAspectRatio,
(
Qt.SmoothTransformation
if self.zoom_scale >= 1.0
else Qt.FastTransformation
),
)
scaled_annotations = self.annotation_pixmap.scaled(
scaled_width,
scaled_height,
Qt.KeepAspectRatio,
(
Qt.SmoothTransformation
if self.zoom_scale >= 1.0
else Qt.FastTransformation
),
)
# Composite image and annotations
combined = QPixmap(scaled_image.size())
painter = QPainter(combined)
painter.drawPixmap(0, 0, scaled_image)
painter.drawPixmap(0, 0, scaled_annotations)
painter.end()
self.canvas_label.setPixmap(combined)
self.canvas_label.setScaledContents(False)
self.canvas_label.adjustSize()
self.zoom_changed.emit(self.zoom_scale)
def _update_display(self):
"""Update display after drawing."""
self._apply_zoom()
def set_pen_enabled(self, enabled: bool):
"""Enable or disable pen tool."""
self.pen_enabled = enabled
if enabled:
self.canvas_label.setCursor(Qt.CrossCursor)
else:
self.canvas_label.setCursor(Qt.ArrowCursor)
def set_pen_color(self, color: QColor):
"""Set pen color."""
self.pen_color = color
def set_pen_width(self, width: int):
"""Set pen width."""
self.pen_width = max(1, width)
def get_zoom_percentage(self) -> int:
"""Get current zoom level as percentage."""
return int(self.zoom_scale * 100)
def zoom_in(self):
"""Zoom in on the image."""
if self.original_pixmap is None:
return
new_scale = self.zoom_scale + self.zoom_step
if new_scale <= self.zoom_max:
self.zoom_scale = new_scale
self._apply_zoom()
def zoom_out(self):
"""Zoom out from the image."""
if self.original_pixmap is None:
return
new_scale = self.zoom_scale - self.zoom_step
if new_scale >= self.zoom_min:
self.zoom_scale = new_scale
self._apply_zoom()
def reset_zoom(self):
"""Reset zoom to 100%."""
if self.original_pixmap is None:
return
self.zoom_scale = 1.0
self._apply_zoom()
def _canvas_to_image_coords(self, pos: QPoint) -> Optional[Tuple[int, int]]:
"""Convert canvas coordinates to image coordinates, accounting for zoom and centering."""
if self.original_pixmap is None or self.canvas_label.pixmap() is None:
return None
# Get the displayed pixmap size (after zoom)
displayed_pixmap = self.canvas_label.pixmap()
displayed_width = displayed_pixmap.width()
displayed_height = displayed_pixmap.height()
# Calculate offset due to label centering (label might be larger than pixmap)
label_width = self.canvas_label.width()
label_height = self.canvas_label.height()
offset_x = max(0, (label_width - displayed_width) // 2)
offset_y = max(0, (label_height - displayed_height) // 2)
# Adjust position for offset and convert to image coordinates
x = (pos.x() - offset_x) / self.zoom_scale
y = (pos.y() - offset_y) / self.zoom_scale
# Check bounds
if (
0 <= x < self.original_pixmap.width()
and 0 <= y < self.original_pixmap.height()
):
return (int(x), int(y))
return None
def _image_to_normalized_coords(self, x: int, y: int) -> Tuple[float, float]:
"""Convert image coordinates to normalized coordinates (0-1)."""
if self.original_pixmap is None:
return (0.0, 0.0)
norm_x = x / self.original_pixmap.width()
norm_y = y / self.original_pixmap.height()
return (norm_x, norm_y)
def mousePressEvent(self, event: QMouseEvent):
"""Handle mouse press events for drawing."""
if not self.pen_enabled or self.annotation_pixmap is None:
super().mousePressEvent(event)
return
if event.button() == Qt.LeftButton:
# Get accurate position using global coordinates
label_pos = self.canvas_label.mapFromGlobal(event.globalPos())
img_coords = self._canvas_to_image_coords(label_pos)
if img_coords:
self.is_drawing = True
self.current_stroke = [img_coords]
def mouseMoveEvent(self, event: QMouseEvent):
"""Handle mouse move events for drawing."""
if (
not self.is_drawing
or not self.pen_enabled
or self.annotation_pixmap is None
):
super().mouseMoveEvent(event)
return
# Get accurate position using global coordinates
label_pos = self.canvas_label.mapFromGlobal(event.globalPos())
img_coords = self._canvas_to_image_coords(label_pos)
if img_coords and len(self.current_stroke) > 0:
# Draw line from last point to current point
painter = QPainter(self.annotation_pixmap)
pen = QPen(
self.pen_color, self.pen_width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin
)
painter.setPen(pen)
last_point = self.current_stroke[-1]
painter.drawLine(last_point[0], last_point[1], img_coords[0], img_coords[1])
painter.end()
self.current_stroke.append(img_coords)
self._update_display()
def mouseReleaseEvent(self, event: QMouseEvent):
"""Handle mouse release events to complete a stroke."""
if not self.is_drawing or event.button() != Qt.LeftButton:
super().mouseReleaseEvent(event)
return
self.is_drawing = False
if len(self.current_stroke) > 1:
# Convert to normalized coordinates and save stroke
normalized_stroke = [
self._image_to_normalized_coords(x, y) for x, y in self.current_stroke
]
self.all_strokes.append(
{
"points": normalized_stroke,
"color": self.pen_color.name(),
"alpha": self.pen_color.alpha(),
"width": self.pen_width,
}
)
# Emit signal with normalized coordinates
self.annotation_drawn.emit(normalized_stroke)
logger.debug(f"Completed stroke with {len(normalized_stroke)} points")
self.current_stroke = []
def get_all_strokes(self) -> List[dict]:
"""Get all drawn strokes with metadata."""
return self.all_strokes
def keyPressEvent(self, event: QKeyEvent):
"""Handle keyboard events for zooming."""
if event.key() in (Qt.Key_Plus, Qt.Key_Equal):
self.zoom_in()
event.accept()
elif event.key() == Qt.Key_Minus:
self.zoom_out()
event.accept()
elif event.key() == Qt.Key_0 and event.modifiers() == Qt.ControlModifier:
self.reset_zoom()
event.accept()
else:
super().keyPressEvent(event)
def eventFilter(self, obj, event: QEvent) -> bool:
"""Event filter to capture wheel events for zooming."""
if event.type() == QEvent.Wheel:
wheel_event = event
if self.original_pixmap is not None:
delta = wheel_event.angleDelta().y()
if delta > 0:
new_scale = self.zoom_scale + self.zoom_wheel_step
if new_scale <= self.zoom_max:
self.zoom_scale = new_scale
self._apply_zoom()
else:
new_scale = self.zoom_scale - self.zoom_wheel_step
if new_scale >= self.zoom_min:
self.zoom_scale = new_scale
self._apply_zoom()
return True
return super().eventFilter(obj, event)

View File

@@ -0,0 +1,352 @@
"""
Annotation tools widget for controlling annotation parameters.
Includes pen tool, color picker, class selection, and annotation management.
"""
from PySide6.QtWidgets import (
QWidget,
QVBoxLayout,
QHBoxLayout,
QLabel,
QGroupBox,
QPushButton,
QComboBox,
QSpinBox,
QColorDialog,
QInputDialog,
QMessageBox,
)
from PySide6.QtGui import QColor, QIcon, QPixmap, QPainter
from PySide6.QtCore import Qt, Signal
from typing import Optional, Dict
from src.database.db_manager import DatabaseManager
from src.utils.logger import get_logger
logger = get_logger(__name__)
class AnnotationToolsWidget(QWidget):
"""
Widget for annotation tool controls.
Features:
- Enable/disable pen tool
- Color selection for pen
- Object class selection
- Add new object classes
- Pen width control
- Clear annotations
Signals:
pen_enabled_changed: Emitted when pen tool is enabled/disabled (bool)
pen_color_changed: Emitted when pen color changes (QColor)
pen_width_changed: Emitted when pen width changes (int)
class_selected: Emitted when object class is selected (dict)
clear_annotations_requested: Emitted when clear button is pressed
"""
pen_enabled_changed = Signal(bool)
pen_color_changed = Signal(QColor)
pen_width_changed = Signal(int)
class_selected = Signal(dict)
clear_annotations_requested = Signal()
def __init__(self, db_manager: DatabaseManager, parent=None):
"""
Initialize annotation tools widget.
Args:
db_manager: Database manager instance
parent: Parent widget
"""
super().__init__(parent)
self.db_manager = db_manager
self.pen_enabled = False
self.current_color = QColor(255, 0, 0, 128) # Red with 50% alpha
self.current_class = None
self._setup_ui()
self._load_object_classes()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
# Pen Tool Group
pen_group = QGroupBox("Pen Tool")
pen_layout = QVBoxLayout()
# Enable/Disable pen
button_layout = QHBoxLayout()
self.pen_toggle_btn = QPushButton("Enable Pen")
self.pen_toggle_btn.setCheckable(True)
self.pen_toggle_btn.clicked.connect(self._on_pen_toggle)
button_layout.addWidget(self.pen_toggle_btn)
pen_layout.addLayout(button_layout)
# Pen width control
width_layout = QHBoxLayout()
width_layout.addWidget(QLabel("Pen Width:"))
self.pen_width_spin = QSpinBox()
self.pen_width_spin.setMinimum(1)
self.pen_width_spin.setMaximum(20)
self.pen_width_spin.setValue(3)
self.pen_width_spin.valueChanged.connect(self._on_pen_width_changed)
width_layout.addWidget(self.pen_width_spin)
width_layout.addStretch()
pen_layout.addLayout(width_layout)
# Color selection
color_layout = QHBoxLayout()
color_layout.addWidget(QLabel("Color:"))
self.color_btn = QPushButton()
self.color_btn.setFixedSize(40, 30)
self.color_btn.clicked.connect(self._on_color_picker)
self._update_color_button()
color_layout.addWidget(self.color_btn)
color_layout.addStretch()
pen_layout.addLayout(color_layout)
pen_group.setLayout(pen_layout)
layout.addWidget(pen_group)
# Object Class Group
class_group = QGroupBox("Object Class")
class_layout = QVBoxLayout()
# Class selection dropdown
self.class_combo = QComboBox()
self.class_combo.currentIndexChanged.connect(self._on_class_selected)
class_layout.addWidget(self.class_combo)
# Add class button
class_button_layout = QHBoxLayout()
self.add_class_btn = QPushButton("Add New Class")
self.add_class_btn.clicked.connect(self._on_add_class)
class_button_layout.addWidget(self.add_class_btn)
self.refresh_classes_btn = QPushButton("Refresh")
self.refresh_classes_btn.clicked.connect(self._load_object_classes)
class_button_layout.addWidget(self.refresh_classes_btn)
class_layout.addLayout(class_button_layout)
# Selected class info
self.class_info_label = QLabel("No class selected")
self.class_info_label.setWordWrap(True)
self.class_info_label.setStyleSheet(
"QLabel { color: #888; font-style: italic; }"
)
class_layout.addWidget(self.class_info_label)
class_group.setLayout(class_layout)
layout.addWidget(class_group)
# Actions Group
actions_group = QGroupBox("Actions")
actions_layout = QVBoxLayout()
self.clear_btn = QPushButton("Clear All Annotations")
self.clear_btn.clicked.connect(self._on_clear_annotations)
actions_layout.addWidget(self.clear_btn)
actions_group.setLayout(actions_layout)
layout.addWidget(actions_group)
layout.addStretch()
self.setLayout(layout)
def _update_color_button(self):
"""Update the color button appearance with current color."""
pixmap = QPixmap(40, 30)
pixmap.fill(self.current_color)
# Add border
painter = QPainter(pixmap)
painter.setPen(Qt.black)
painter.drawRect(0, 0, pixmap.width() - 1, pixmap.height() - 1)
painter.end()
self.color_btn.setIcon(QIcon(pixmap))
self.color_btn.setStyleSheet(f"background-color: {self.current_color.name()};")
def _load_object_classes(self):
"""Load object classes from database and populate combo box."""
try:
classes = self.db_manager.get_object_classes()
# Clear and repopulate combo box
self.class_combo.clear()
self.class_combo.addItem("-- Select Class --", None)
for cls in classes:
self.class_combo.addItem(cls["class_name"], cls)
logger.debug(f"Loaded {len(classes)} object classes")
except Exception as e:
logger.error(f"Error loading object classes: {e}")
QMessageBox.warning(
self, "Error", f"Failed to load object classes:\n{str(e)}"
)
def _on_pen_toggle(self, checked: bool):
"""Handle pen tool enable/disable."""
self.pen_enabled = checked
if checked:
self.pen_toggle_btn.setText("Disable Pen")
self.pen_toggle_btn.setStyleSheet(
"QPushButton { background-color: #4CAF50; }"
)
else:
self.pen_toggle_btn.setText("Enable Pen")
self.pen_toggle_btn.setStyleSheet("")
self.pen_enabled_changed.emit(self.pen_enabled)
logger.debug(f"Pen tool {'enabled' if checked else 'disabled'}")
def _on_pen_width_changed(self, width: int):
"""Handle pen width changes."""
self.pen_width_changed.emit(width)
logger.debug(f"Pen width changed to {width}")
def _on_color_picker(self):
"""Open color picker dialog with alpha support."""
color = QColorDialog.getColor(
self.current_color,
self,
"Select Pen Color",
QColorDialog.ShowAlphaChannel, # Enable alpha channel selection
)
if color.isValid():
self.current_color = color
self._update_color_button()
self.pen_color_changed.emit(color)
logger.debug(
f"Pen color changed to {color.name()} with alpha {color.alpha()}"
)
def _on_class_selected(self, index: int):
"""Handle object class selection."""
class_data = self.class_combo.currentData()
if class_data:
self.current_class = class_data
# Update info label
info_text = (
f"Class: {class_data['class_name']}\n" f"Color: {class_data['color']}"
)
if class_data.get("description"):
info_text += f"\nDescription: {class_data['description']}"
self.class_info_label.setText(info_text)
# Update pen color to match class color with semi-transparency
class_color = QColor(class_data["color"])
if class_color.isValid():
# Add 50% alpha for semi-transparency
class_color.setAlpha(128)
self.current_color = class_color
self._update_color_button()
self.pen_color_changed.emit(class_color)
self.class_selected.emit(class_data)
logger.debug(f"Selected class: {class_data['class_name']}")
else:
self.current_class = None
self.class_info_label.setText("No class selected")
def _on_add_class(self):
"""Handle adding a new object class."""
# Get class name
class_name, ok = QInputDialog.getText(
self, "Add Object Class", "Enter class name:"
)
if not ok or not class_name.strip():
return
class_name = class_name.strip()
# Check if class already exists
existing = self.db_manager.get_object_class_by_name(class_name)
if existing:
QMessageBox.warning(
self, "Class Exists", f"A class named '{class_name}' already exists."
)
return
# Get color
color = QColorDialog.getColor(self.current_color, self, "Select Class Color")
if not color.isValid():
return
# Get optional description
description, ok = QInputDialog.getText(
self, "Class Description", "Enter class description (optional):"
)
if not ok:
description = None
# Add to database
try:
class_id = self.db_manager.add_object_class(
class_name, color.name(), description.strip() if description else None
)
logger.info(f"Added new object class: {class_name} (ID: {class_id})")
# Reload classes and select the new one
self._load_object_classes()
# Find and select the newly added class
for i in range(self.class_combo.count()):
class_data = self.class_combo.itemData(i)
if class_data and class_data.get("id") == class_id:
self.class_combo.setCurrentIndex(i)
break
QMessageBox.information(
self, "Success", f"Class '{class_name}' added successfully!"
)
except Exception as e:
logger.error(f"Error adding object class: {e}")
QMessageBox.critical(
self, "Error", f"Failed to add object class:\n{str(e)}"
)
def _on_clear_annotations(self):
"""Handle clear annotations button."""
reply = QMessageBox.question(
self,
"Clear Annotations",
"Are you sure you want to clear all annotations?",
QMessageBox.Yes | QMessageBox.No,
QMessageBox.No,
)
if reply == QMessageBox.Yes:
self.clear_annotations_requested.emit()
logger.debug("Clear annotations requested")
def get_current_class(self) -> Optional[Dict]:
"""Get currently selected object class."""
return self.current_class
def get_pen_color(self) -> QColor:
"""Get current pen color."""
return self.current_color
def get_pen_width(self) -> int:
"""Get current pen width."""
return self.pen_width_spin.value()
def is_pen_enabled(self) -> bool:
"""Check if pen tool is enabled."""
return self.pen_enabled