"""
|
|
Annotation canvas widget for drawing annotations on images.
|
|
Supports pen tool with color selection for manual annotation.
|
|
"""
|
|
|
|
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea
|
|
from PySide6.QtGui import (
|
|
QPixmap,
|
|
QImage,
|
|
QPainter,
|
|
QPen,
|
|
QColor,
|
|
QKeyEvent,
|
|
QMouseEvent,
|
|
QPaintEvent,
|
|
)
|
|
from PySide6.QtCore import Qt, QEvent, Signal, QPoint
|
|
from typing import List, Optional, Tuple
|
|
import numpy as np
|
|
|
|
from src.utils.image import Image, ImageLoadError
|
|
from src.utils.logger import get_logger
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class AnnotationCanvasWidget(QWidget):
    """
    Widget for displaying images and drawing annotations with a pen tool.

    Features:
    - Display images with zoom functionality
    - Pen tool for drawing annotations
    - Configurable pen color and width
    - Mouse-based drawing interface
    - Zoom in/out with mouse wheel and keyboard

    Signals:
        zoom_changed: Emitted when the zoom level changes (float zoom_scale)
        annotation_drawn: Emitted when a new stroke is completed (list of points)
    """

    zoom_changed = Signal(float)
    annotation_drawn = Signal(list)  # List of (x, y) points in normalized coordinates

    def __init__(self, parent=None):
        """Initialize the annotation canvas widget."""
        super().__init__(parent)

        self.current_image = None
        self.original_pixmap = None
        self.annotation_pixmap = None  # Overlay for annotations
        self.zoom_scale = 1.0
        self.zoom_min = 0.1
        self.zoom_max = 10.0
        self.zoom_step = 0.1
        self.zoom_wheel_step = 0.15

        # Drawing state
        self.is_drawing = False
        self.pen_enabled = False
        self.pen_color = QColor(255, 0, 0, 128)  # Default red with 50% alpha
        self.pen_width = 3
        self.current_stroke = []  # Points in current stroke
        self.all_strokes = []  # All completed strokes

        self._setup_ui()

    def _setup_ui(self):
        """Set up the user interface."""
        layout = QVBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)

        # Scroll area for canvas
        self.scroll_area = QScrollArea()
        self.scroll_area.setWidgetResizable(True)
        self.scroll_area.setMinimumHeight(400)

        self.canvas_label = QLabel("No image loaded")
        self.canvas_label.setAlignment(Qt.AlignCenter)
        self.canvas_label.setStyleSheet(
            "QLabel { background-color: #2b2b2b; color: #888; }"
        )
        self.canvas_label.setScaledContents(False)
        self.canvas_label.setMouseTracking(True)

        self.scroll_area.setWidget(self.canvas_label)
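        # Wheel events over the scroll area are delivered to its viewport,
        # so the filter is installed there to intercept them for zooming.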
        self.scroll_area.viewport().installEventFilter(self)

        layout.addWidget(self.scroll_area)
        self.setLayout(layout)

        self.setFocusPolicy(Qt.StrongFocus)

    def load_image(self, image: Image):
        """
        Load and display an image.

        Args:
            image: Image object to display
        """
        self.current_image = image
        self.zoom_scale = 1.0
        self.clear_annotations()
        self._display_image()
        logger.debug(
            f"Loaded image into annotation canvas: {image.width}x{image.height}"
        )

    def clear(self):
        """Clear the displayed image and all annotations."""
        self.current_image = None
        self.original_pixmap = None
        self.annotation_pixmap = None
        self.zoom_scale = 1.0
        self.clear_annotations()
        # Reset the pixmap first and set the text last: each call clears the
        # other's content, so this order keeps the placeholder message visible.
        self.canvas_label.setPixmap(QPixmap())
        self.canvas_label.setText("No image loaded")

    def clear_annotations(self):
        """Clear all drawn annotations."""
        self.all_strokes = []
        self.current_stroke = []
        self.is_drawing = False
        if self.annotation_pixmap:
            self.annotation_pixmap.fill(Qt.transparent)
            self._update_display()

    def _display_image(self):
        """Display the current image in the canvas."""
        if self.current_image is None:
            return

        try:
            # Get RGB image data
            if self.current_image.channels == 3:
                image_data = self.current_image.get_rgb()
                height, width, channels = image_data.shape
            else:
                image_data = self.current_image.get_grayscale()
                height, width = image_data.shape

            image_data = np.ascontiguousarray(image_data)
            bytes_per_line = image_data.strides[0]

            qimage = QImage(
                image_data.data,
                width,
                height,
                bytes_per_line,
                self.current_image.qtimage_format,
            )

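            # QImage wraps the numpy buffer without copying it; fromImage()
            # below makes its own copy, so the array need not outlive this call.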
            self.original_pixmap = QPixmap.fromImage(qimage)

            # Create transparent overlay for annotations
            self.annotation_pixmap = QPixmap(self.original_pixmap.size())
            self.annotation_pixmap.fill(Qt.transparent)

            self._apply_zoom()

        except Exception as e:
            logger.error(f"Error displaying image: {e}")
            raise ImageLoadError(f"Failed to display image: {str(e)}")

    def _apply_zoom(self):
        """Apply current zoom level to the displayed image."""
        if self.original_pixmap is None:
            return

        scaled_width = int(self.original_pixmap.width() * self.zoom_scale)
        scaled_height = int(self.original_pixmap.height() * self.zoom_scale)

        # Scale both image and annotations
        scaled_image = self.original_pixmap.scaled(
            scaled_width,
            scaled_height,
            Qt.KeepAspectRatio,
            (
                Qt.SmoothTransformation
                if self.zoom_scale >= 1.0
                else Qt.FastTransformation
            ),
        )

        scaled_annotations = self.annotation_pixmap.scaled(
            scaled_width,
            scaled_height,
            Qt.KeepAspectRatio,
            (
                Qt.SmoothTransformation
                if self.zoom_scale >= 1.0
                else Qt.FastTransformation
            ),
        )

        # Composite image and annotations
        combined = QPixmap(scaled_image.size())
        painter = QPainter(combined)
        painter.drawPixmap(0, 0, scaled_image)
        painter.drawPixmap(0, 0, scaled_annotations)
        painter.end()

        self.canvas_label.setPixmap(combined)
        self.canvas_label.setScaledContents(False)
        self.canvas_label.adjustSize()

        self.zoom_changed.emit(self.zoom_scale)

    def _update_display(self):
        """Update display after drawing."""
        self._apply_zoom()

    def set_pen_enabled(self, enabled: bool):
        """Enable or disable pen tool."""
        self.pen_enabled = enabled
        if enabled:
            self.canvas_label.setCursor(Qt.CrossCursor)
        else:
            self.canvas_label.setCursor(Qt.ArrowCursor)

    def set_pen_color(self, color: QColor):
        """Set pen color."""
        self.pen_color = color

    def set_pen_width(self, width: int):
        """Set pen width."""
        self.pen_width = max(1, width)

    def get_zoom_percentage(self) -> int:
        """Get current zoom level as percentage."""
        return int(self.zoom_scale * 100)

    def zoom_in(self):
        """Zoom in on the image."""
        if self.original_pixmap is None:
            return
        new_scale = self.zoom_scale + self.zoom_step
        if new_scale <= self.zoom_max:
            self.zoom_scale = new_scale
            self._apply_zoom()

    def zoom_out(self):
        """Zoom out from the image."""
        if self.original_pixmap is None:
            return
        new_scale = self.zoom_scale - self.zoom_step
        if new_scale >= self.zoom_min:
            self.zoom_scale = new_scale
            self._apply_zoom()

    def reset_zoom(self):
        """Reset zoom to 100%."""
        if self.original_pixmap is None:
            return
        self.zoom_scale = 1.0
        self._apply_zoom()

    def _canvas_to_image_coords(self, pos: QPoint) -> Optional[Tuple[int, int]]:
        """Convert canvas coordinates to image coordinates, accounting for zoom and centering."""
        if self.original_pixmap is None or self.canvas_label.pixmap() is None:
            return None

        # Get the displayed pixmap size (after zoom)
        displayed_pixmap = self.canvas_label.pixmap()
        displayed_width = displayed_pixmap.width()
        displayed_height = displayed_pixmap.height()

        # Calculate offset due to label centering (label might be larger than pixmap)
        label_width = self.canvas_label.width()
        label_height = self.canvas_label.height()
        offset_x = max(0, (label_width - displayed_width) // 2)
        offset_y = max(0, (label_height - displayed_height) // 2)

        # Adjust position for offset and convert to image coordinates
        x = (pos.x() - offset_x) / self.zoom_scale
        y = (pos.y() - offset_y) / self.zoom_scale

        # Check bounds
        if (
            0 <= x < self.original_pixmap.width()
            and 0 <= y < self.original_pixmap.height()
        ):
            return (int(x), int(y))
        return None

    def _image_to_normalized_coords(self, x: int, y: int) -> Tuple[float, float]:
        """Convert image coordinates to normalized coordinates (0-1)."""
        if self.original_pixmap is None:
            return (0.0, 0.0)

        norm_x = x / self.original_pixmap.width()
        norm_y = y / self.original_pixmap.height()
        return (norm_x, norm_y)

    def mousePressEvent(self, event: QMouseEvent):
        """Handle mouse press events for drawing."""
        if not self.pen_enabled or self.annotation_pixmap is None:
            super().mousePressEvent(event)
            return

        if event.button() == Qt.LeftButton:
            # Map the global cursor position into the label's coordinate system
            label_pos = self.canvas_label.mapFromGlobal(event.globalPos())
            img_coords = self._canvas_to_image_coords(label_pos)

            if img_coords:
                self.is_drawing = True
                self.current_stroke = [img_coords]

    def mouseMoveEvent(self, event: QMouseEvent):
        """Handle mouse move events for drawing."""
        if (
            not self.is_drawing
            or not self.pen_enabled
            or self.annotation_pixmap is None
        ):
            super().mouseMoveEvent(event)
            return

        # Map the global cursor position into the label's coordinate system
        label_pos = self.canvas_label.mapFromGlobal(event.globalPos())
        img_coords = self._canvas_to_image_coords(label_pos)

        if img_coords and len(self.current_stroke) > 0:
            # Draw line from last point to current point
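            # Strokes are painted onto the overlay at native image resolution,
            # so they stay aligned with the image at any zoom level.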
            painter = QPainter(self.annotation_pixmap)
            pen = QPen(
                self.pen_color, self.pen_width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin
            )
            painter.setPen(pen)

            last_point = self.current_stroke[-1]
            painter.drawLine(last_point[0], last_point[1], img_coords[0], img_coords[1])
            painter.end()

            self.current_stroke.append(img_coords)
            self._update_display()

    def mouseReleaseEvent(self, event: QMouseEvent):
        """Handle mouse release events to complete a stroke."""
        if not self.is_drawing or event.button() != Qt.LeftButton:
            super().mouseReleaseEvent(event)
            return

        self.is_drawing = False

        if len(self.current_stroke) > 1:
            # Convert to normalized coordinates and save stroke
            normalized_stroke = [
                self._image_to_normalized_coords(x, y) for x, y in self.current_stroke
            ]
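            # QColor.name() returns "#RRGGBB" without the alpha channel,
            # so the alpha value is stored separately.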
            self.all_strokes.append(
                {
                    "points": normalized_stroke,
                    "color": self.pen_color.name(),
                    "alpha": self.pen_color.alpha(),
                    "width": self.pen_width,
                }
            )

            # Emit signal with normalized coordinates
            self.annotation_drawn.emit(normalized_stroke)
            logger.debug(f"Completed stroke with {len(normalized_stroke)} points")

        self.current_stroke = []

    def get_all_strokes(self) -> List[dict]:
        """Get all drawn strokes with metadata."""
        return self.all_strokes

    def compute_annotation_bounds(self) -> Optional[Tuple[float, float, float, float]]:
        """
        Compute bounding box that encompasses all annotation strokes.

        Returns:
            Tuple of (x_min, y_min, x_max, y_max) in normalized coordinates (0-1),
            or None if no annotations exist.
        """
        if not self.all_strokes:
            return None

        # Find min/max across all strokes
        all_x = []
        all_y = []

        for stroke in self.all_strokes:
            for x, y in stroke["points"]:
                all_x.append(x)
                all_y.append(y)

        if not all_x:
            return None

        x_min = min(all_x)
        y_min = min(all_y)
        x_max = max(all_x)
        y_max = max(all_y)

        return (x_min, y_min, x_max, y_max)

    def get_annotation_polyline(self) -> List[List[float]]:
        """
        Get polyline coordinates representing all annotation strokes.

        Returns:
            List of [x, y] coordinate pairs in normalized coordinates (0-1).
        """
        polyline = []

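        # Strokes are concatenated end to end; individual stroke boundaries
        # are not preserved in the flattened list.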
        for stroke in self.all_strokes:
            polyline.extend(stroke["points"])

        return polyline

    def draw_saved_polyline(
        self, polyline: List[List[float]], color: str, width: int = 3
    ):
        """
        Draw a polyline from database coordinates onto the annotation canvas.

        Args:
            polyline: List of [x, y] coordinate pairs in normalized coordinates (0-1)
            color: Color hex string (e.g., '#FF0000')
            width: Line width in pixels
        """
        if not self.annotation_pixmap or not self.original_pixmap:
            logger.warning("Cannot draw polyline: no image loaded")
            return

        if len(polyline) < 2:
            logger.warning("Polyline has fewer than 2 points, cannot draw")
            return

        # Convert normalized coordinates to image coordinates
        img_coords = []
        for x_norm, y_norm in polyline:
            x = int(x_norm * self.original_pixmap.width())
            y = int(y_norm * self.original_pixmap.height())
            img_coords.append((x, y))

        # Draw polyline on annotation pixmap
        painter = QPainter(self.annotation_pixmap)
        pen_color = QColor(color)
        pen_color.setAlpha(128)  # Add semi-transparency
        pen = QPen(pen_color, width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin)
        painter.setPen(pen)

        # Draw lines between consecutive points
        for i in range(len(img_coords) - 1):
            x1, y1 = img_coords[i]
            x2, y2 = img_coords[i + 1]
            painter.drawLine(x1, y1, x2, y2)

        painter.end()

        # Store in all_strokes for consistency
        self.all_strokes.append(
            {"points": polyline, "color": color, "alpha": 128, "width": width}
        )

        # Update display
        self._update_display()
        logger.debug(
            f"Drew saved polyline with {len(polyline)} points in color {color}"
        )

    def keyPressEvent(self, event: QKeyEvent):
        """Handle keyboard events for zooming."""
        if event.key() in (Qt.Key_Plus, Qt.Key_Equal):
            self.zoom_in()
            event.accept()
        elif event.key() == Qt.Key_Minus:
            self.zoom_out()
            event.accept()
        elif event.key() == Qt.Key_0 and event.modifiers() == Qt.ControlModifier:
            self.reset_zoom()
            event.accept()
        else:
            super().keyPressEvent(event)

    def eventFilter(self, obj, event: QEvent) -> bool:
        """Event filter to capture wheel events for zooming."""
        if event.type() == QEvent.Wheel:
            wheel_event = event
            if self.original_pixmap is not None:
                delta = wheel_event.angleDelta().y()

                if delta > 0:
                    new_scale = self.zoom_scale + self.zoom_wheel_step
                    if new_scale <= self.zoom_max:
                        self.zoom_scale = new_scale
                        self._apply_zoom()
                else:
                    new_scale = self.zoom_scale - self.zoom_wheel_step
                    if new_scale >= self.zoom_min:
                        self.zoom_scale = new_scale
                        self._apply_zoom()

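            # Returning True consumes the wheel event so the scroll area
            # does not also scroll while zooming.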
            return True

        return super().eventFilter(obj, event)
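

# --- Minimal usage sketch (illustrative only) --------------------------------
# The snippet below shows how this widget is typically wired up: enable the
# pen, pick a color, and listen for completed strokes via `annotation_drawn`.
# It assumes it is run inside this project so the `src.utils` imports at the
# top of the module resolve; loading an actual image via `load_image()` is
# omitted because it depends on the project's `Image` loader.
if __name__ == "__main__":
    import sys

    from PySide6.QtWidgets import QApplication

    app = QApplication(sys.argv)

    canvas = AnnotationCanvasWidget()
    canvas.set_pen_enabled(True)
    canvas.set_pen_color(QColor(0, 128, 255, 128))
    canvas.set_pen_width(4)

    # Each completed stroke arrives as a list of (x, y) points in the 0-1 range.
    canvas.annotation_drawn.connect(
        lambda points: print(f"Stroke finished with {len(points)} points")
    )

    canvas.resize(800, 600)
    canvas.show()
    sys.exit(app.exec())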