"""
Annotation canvas widget for drawing annotations on images.

Supports pen tool with color selection for manual annotation.
"""

from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea
from PySide6.QtGui import (
    QPixmap,
    QImage,
    QPainter,
    QPen,
    QColor,
    QKeyEvent,
    QMouseEvent,
    QPaintEvent,
)
from PySide6.QtCore import Qt, QEvent, Signal, QPoint
from typing import List, Optional, Tuple
import numpy as np

from src.utils.image import Image, ImageLoadError
from src.utils.logger import get_logger

logger = get_logger(__name__)


class AnnotationCanvasWidget(QWidget):
    """
    Widget for displaying images and drawing annotations with pen tool.

    Features:
        - Display images with zoom functionality
        - Pen tool for drawing annotations
        - Configurable pen color and width
        - Mouse-based drawing interface
        - Zoom in/out with mouse wheel and keyboard

    Coordinate spaces used by this class:
        - label coordinates: pixels inside ``canvas_label`` (zoomed, centered)
        - image coordinates: integer pixels of the unscaled image
        - normalized coordinates: image coordinates divided by image size

    Signals:
        zoom_changed: Emitted when zoom level changes (float zoom_scale)
        annotation_drawn: Emitted when a new stroke is completed (list of points)
    """

    zoom_changed = Signal(float)
    annotation_drawn = Signal(list)  # List of (x, y) points in normalized coordinates

    def __init__(self, parent=None):
        """Initialize the annotation canvas widget."""
        super().__init__(parent)

        self.current_image = None
        self.original_pixmap = None
        self.annotation_pixmap = None  # Overlay for annotations

        self.zoom_scale = 1.0
        self.zoom_min = 0.1
        self.zoom_max = 10.0
        self.zoom_step = 0.1  # Step used by zoom_in() / zoom_out()
        self.zoom_wheel_step = 0.15  # Step used per mouse-wheel notch

        # Drawing state
        self.is_drawing = False
        self.pen_enabled = False
        self.pen_color = QColor(255, 0, 0, 128)  # Default red with 50% alpha
        self.pen_width = 3
        self.current_stroke = []  # Points in current stroke (image coordinates)
        self.all_strokes = []  # All completed strokes (dicts, normalized points)

        self._setup_ui()

    def _setup_ui(self):
        """Setup user interface."""
        layout = QVBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)

        # Scroll area for canvas
        self.scroll_area = QScrollArea()
        self.scroll_area.setWidgetResizable(True)
        self.scroll_area.setMinimumHeight(400)

        self.canvas_label = QLabel("No image loaded")
        self.canvas_label.setAlignment(Qt.AlignCenter)
        self.canvas_label.setStyleSheet(
            "QLabel { background-color: #2b2b2b; color: #888; }"
        )
        self.canvas_label.setScaledContents(False)
        self.canvas_label.setMouseTracking(True)

        self.scroll_area.setWidget(self.canvas_label)
        # Wheel events are intercepted on the viewport so they zoom instead
        # of scrolling (see eventFilter).
        self.scroll_area.viewport().installEventFilter(self)

        layout.addWidget(self.scroll_area)
        self.setLayout(layout)

        self.setFocusPolicy(Qt.StrongFocus)

    def load_image(self, image: Image):
        """
        Load and display an image.

        Resets the zoom to 100% and discards any existing annotations.

        Args:
            image: Image object to display

        Raises:
            ImageLoadError: If the image data cannot be converted for display.
        """
        self.current_image = image
        self.zoom_scale = 1.0
        self.clear_annotations()
        self._display_image()
        logger.debug(
            f"Loaded image into annotation canvas: {image.width}x{image.height}"
        )

    def clear(self):
        """Clear the displayed image and all annotations."""
        self.current_image = None
        self.original_pixmap = None
        self.annotation_pixmap = None
        self.zoom_scale = 1.0
        self.clear_annotations()
        # Order matters: QLabel.setPixmap() clears any previous content, so
        # the placeholder text has to be set last to remain visible.
        self.canvas_label.setPixmap(QPixmap())
        self.canvas_label.setText("No image loaded")

    def clear_annotations(self):
        """Clear all drawn annotations."""
        self.all_strokes = []
        self.current_stroke = []
        self.is_drawing = False

        if self.annotation_pixmap:
            self.annotation_pixmap.fill(Qt.transparent)
            self._update_display()

    def _display_image(self):
        """
        Display the current image in the canvas.

        Builds ``original_pixmap`` from the current image, creates a matching
        transparent ``annotation_pixmap`` overlay and renders both.

        Raises:
            ImageLoadError: If any step of the conversion fails.
        """
        if self.current_image is None:
            return

        try:
            # Get RGB image data
            if self.current_image.channels == 3:
                image_data = self.current_image.get_rgb()
                height, width, _ = image_data.shape
            else:
                image_data = self.current_image.get_grayscale()
                height, width = image_data.shape

            image_data = np.ascontiguousarray(image_data)
            bytes_per_line = image_data.strides[0]

            # QImage wraps the numpy buffer; image_data stays referenced
            # until the pixmap has been created from it.
            qimage = QImage(
                image_data.data,
                width,
                height,
                bytes_per_line,
                self.current_image.qtimage_format,
            )

            self.original_pixmap = QPixmap.fromImage(qimage)

            # Create transparent overlay for annotations
            self.annotation_pixmap = QPixmap(self.original_pixmap.size())
            self.annotation_pixmap.fill(Qt.transparent)

            self._apply_zoom()

        except Exception as e:
            logger.error(f"Error displaying image: {e}")
            # Chain the cause so the original traceback is not lost.
            raise ImageLoadError(f"Failed to display image: {str(e)}") from e

    def _render_composite(self) -> bool:
        """
        Render image plus annotation overlay at the current zoom level.

        Returns:
            True if the label was updated, False if there is nothing to render.
        """
        if self.original_pixmap is None or self.annotation_pixmap is None:
            return False

        # Never request a zero-sized pixmap (tiny image at low zoom).
        scaled_width = max(1, int(self.original_pixmap.width() * self.zoom_scale))
        scaled_height = max(1, int(self.original_pixmap.height() * self.zoom_scale))
        mode = (
            Qt.SmoothTransformation
            if self.zoom_scale >= 1.0
            else Qt.FastTransformation
        )

        # Scale both image and annotations
        scaled_image = self.original_pixmap.scaled(
            scaled_width, scaled_height, Qt.KeepAspectRatio, mode
        )
        scaled_annotations = self.annotation_pixmap.scaled(
            scaled_width, scaled_height, Qt.KeepAspectRatio, mode
        )

        # Composite image and annotations. A size-constructed QPixmap has
        # uninitialized contents, so fill it before painting.
        combined = QPixmap(scaled_image.size())
        combined.fill(Qt.transparent)
        painter = QPainter(combined)
        painter.drawPixmap(0, 0, scaled_image)
        painter.drawPixmap(0, 0, scaled_annotations)
        painter.end()

        self.canvas_label.setPixmap(combined)
        self.canvas_label.setScaledContents(False)
        self.canvas_label.adjustSize()
        return True

    def _apply_zoom(self):
        """Apply current zoom level to the displayed image and notify listeners."""
        if self._render_composite():
            self.zoom_changed.emit(self.zoom_scale)

    def _update_display(self):
        """Update display after drawing (zoom is unchanged, so no signal)."""
        self._render_composite()

    def set_pen_enabled(self, enabled: bool):
        """Enable or disable pen tool."""
        self.pen_enabled = enabled
        if enabled:
            self.canvas_label.setCursor(Qt.CrossCursor)
        else:
            self.canvas_label.setCursor(Qt.ArrowCursor)

    def set_pen_color(self, color: QColor):
        """Set pen color."""
        self.pen_color = color

    def set_pen_width(self, width: int):
        """Set pen width (values below 1 are raised to 1)."""
        self.pen_width = max(1, width)

    def get_zoom_percentage(self) -> int:
        """Get current zoom level as percentage."""
        # round() rather than int(): 1.15 * 100 is 114.999... in floating point.
        return round(self.zoom_scale * 100)

    def _set_zoom(self, scale: float) -> None:
        """Clamp *scale* to [zoom_min, zoom_max] and apply it if it changed."""
        if self.original_pixmap is None:
            return

        # Clamping (instead of rejecting an overshooting step) keeps both
        # limits reachable despite accumulated floating-point error.
        clamped = min(self.zoom_max, max(self.zoom_min, scale))
        if clamped == self.zoom_scale:
            return

        self.zoom_scale = clamped
        self._apply_zoom()

    def zoom_in(self):
        """Zoom in on the image."""
        self._set_zoom(self.zoom_scale + self.zoom_step)

    def zoom_out(self):
        """Zoom out from the image."""
        self._set_zoom(self.zoom_scale - self.zoom_step)

    def reset_zoom(self):
        """Reset zoom to 100%."""
        if self.original_pixmap is None:
            return

        self.zoom_scale = 1.0
        self._apply_zoom()

    def _canvas_to_image_coords(self, pos: QPoint) -> Optional[Tuple[int, int]]:
        """
        Convert canvas coordinates to image coordinates.

        Accounts for zoom and for the centering offset of the pixmap
        inside the label.

        Args:
            pos: Position in ``canvas_label`` coordinates.

        Returns:
            (x, y) in image pixels, or None if there is no image or the
            position falls outside it.
        """
        if self.original_pixmap is None or self.canvas_label.pixmap() is None:
            return None

        # Get the displayed pixmap size (after zoom)
        displayed_pixmap = self.canvas_label.pixmap()
        displayed_width = displayed_pixmap.width()
        displayed_height = displayed_pixmap.height()

        # Calculate offset due to label centering (label might be larger than pixmap)
        offset_x = max(0, (self.canvas_label.width() - displayed_width) // 2)
        offset_y = max(0, (self.canvas_label.height() - displayed_height) // 2)

        # Adjust position for offset and convert to image coordinates
        x = (pos.x() - offset_x) / self.zoom_scale
        y = (pos.y() - offset_y) / self.zoom_scale

        # Check bounds
        if (
            0 <= x < self.original_pixmap.width()
            and 0 <= y < self.original_pixmap.height()
        ):
            return (int(x), int(y))

        return None

    def _image_to_normalized_coords(self, x: int, y: int) -> Tuple[float, float]:
        """Convert image coordinates to normalized coordinates (0-1)."""
        if self.original_pixmap is None:
            return (0.0, 0.0)

        norm_x = x / self.original_pixmap.width()
        norm_y = y / self.original_pixmap.height()
        return (norm_x, norm_y)

    def _event_to_label_pos(self, event: QMouseEvent) -> QPoint:
        """Map a mouse event's global position into canvas_label coordinates."""
        # globalPosition() is the Qt 6 replacement for the deprecated globalPos().
        return self.canvas_label.mapFromGlobal(event.globalPosition().toPoint())

    def mousePressEvent(self, event: QMouseEvent):
        """Handle mouse press events for drawing."""
        if not self.pen_enabled or self.annotation_pixmap is None:
            super().mousePressEvent(event)
            return

        if event.button() == Qt.LeftButton:
            # Get accurate position using global coordinates
            img_coords = self._canvas_to_image_coords(self._event_to_label_pos(event))

            # A stroke only starts when the press lands inside the image.
            if img_coords:
                self.is_drawing = True
                self.current_stroke = [img_coords]

    def mouseMoveEvent(self, event: QMouseEvent):
        """Handle mouse move events for drawing."""
        if (
            not self.is_drawing
            or not self.pen_enabled
            or self.annotation_pixmap is None
        ):
            super().mouseMoveEvent(event)
            return

        # Get accurate position using global coordinates
        img_coords = self._canvas_to_image_coords(self._event_to_label_pos(event))

        if img_coords and self.current_stroke:
            # Draw line from last point to current point
            painter = QPainter(self.annotation_pixmap)
            pen = QPen(
                self.pen_color, self.pen_width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin
            )
            painter.setPen(pen)

            last_x, last_y = self.current_stroke[-1]
            painter.drawLine(last_x, last_y, img_coords[0], img_coords[1])
            painter.end()

            self.current_stroke.append(img_coords)
            self._update_display()

    def mouseReleaseEvent(self, event: QMouseEvent):
        """Handle mouse release events to complete a stroke."""
        if not self.is_drawing or event.button() != Qt.LeftButton:
            super().mouseReleaseEvent(event)
            return

        self.is_drawing = False

        # A single point (click without movement) is not recorded as a stroke.
        if len(self.current_stroke) > 1:
            # Convert to normalized coordinates and save stroke
            normalized_stroke = [
                self._image_to_normalized_coords(x, y) for x, y in self.current_stroke
            ]
            self.all_strokes.append(
                {
                    "points": normalized_stroke,
                    "color": self.pen_color.name(),
                    "alpha": self.pen_color.alpha(),
                    "width": self.pen_width,
                }
            )

            # Emit signal with normalized coordinates
            self.annotation_drawn.emit(normalized_stroke)
            logger.debug(f"Completed stroke with {len(normalized_stroke)} points")

        self.current_stroke = []

    def get_all_strokes(self) -> List[dict]:
        """
        Get all drawn strokes with metadata.

        Returns:
            The internal list of stroke dicts with keys "points", "color",
            "alpha" and "width" (not a copy).
        """
        return self.all_strokes

    def compute_annotation_bounds(self) -> Optional[Tuple[float, float, float, float]]:
        """
        Compute bounding box that encompasses all annotation strokes.

        Returns:
            Tuple of (x_min, y_min, x_max, y_max) in normalized coordinates (0-1),
            or None if no annotations exist.
        """
        points = [pt for stroke in self.all_strokes for pt in stroke["points"]]
        if not points:
            return None

        xs, ys = zip(*points)
        return (min(xs), min(ys), max(xs), max(ys))

    def get_annotation_polyline(self) -> List[List[float]]:
        """
        Get polyline coordinates representing all annotation strokes.

        Points of all strokes are concatenated in stroke order.

        Returns:
            List of [x, y] coordinate pairs in normalized coordinates (0-1).
        """
        return [pt for stroke in self.all_strokes for pt in stroke["points"]]

    def draw_saved_polyline(
        self, polyline: List[List[float]], color: str, width: int = 3
    ):
        """
        Draw a polyline from database coordinates onto the annotation canvas.

        The polyline is drawn with a fixed alpha of 128 and is also appended
        to the stored strokes.

        Args:
            polyline: List of [x, y] coordinate pairs in normalized coordinates (0-1)
            color: Color hex string (e.g., '#FF0000')
            width: Line width in pixels
        """
        if not self.annotation_pixmap or not self.original_pixmap:
            logger.warning("Cannot draw polyline: no image loaded")
            return

        if len(polyline) < 2:
            logger.warning("Polyline has less than 2 points, cannot draw")
            return

        # Convert normalized coordinates to image coordinates
        img_width = self.original_pixmap.width()
        img_height = self.original_pixmap.height()
        img_coords = [
            (int(x_norm * img_width), int(y_norm * img_height))
            for x_norm, y_norm in polyline
        ]

        # Draw polyline on annotation pixmap
        painter = QPainter(self.annotation_pixmap)
        pen_color = QColor(color)
        pen_color.setAlpha(128)  # Add semi-transparency
        pen = QPen(pen_color, width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin)
        painter.setPen(pen)

        # Draw lines between consecutive points
        for (x1, y1), (x2, y2) in zip(img_coords, img_coords[1:]):
            painter.drawLine(x1, y1, x2, y2)

        painter.end()

        # Store in all_strokes for consistency
        self.all_strokes.append(
            {"points": polyline, "color": color, "alpha": 128, "width": width}
        )

        # Update display
        self._update_display()
        logger.debug(
            f"Drew saved polyline with {len(polyline)} points in color {color}"
        )

    def keyPressEvent(self, event: QKeyEvent):
        """Handle keyboard events for zooming (+/=, -, Ctrl+0)."""
        if event.key() in (Qt.Key_Plus, Qt.Key_Equal):
            self.zoom_in()
            event.accept()
        elif event.key() == Qt.Key_Minus:
            self.zoom_out()
            event.accept()
        elif event.key() == Qt.Key_0 and event.modifiers() == Qt.ControlModifier:
            self.reset_zoom()
            event.accept()
        else:
            super().keyPressEvent(event)

    def eventFilter(self, obj, event: QEvent) -> bool:
        """
        Event filter to capture wheel events for zooming.

        Vertical wheel movement over a loaded image zooms and is consumed.
        Everything else is forwarded to the default handling.
        """
        if event.type() == QEvent.Wheel and self.original_pixmap is not None:
            delta = event.angleDelta().y()
            # delta == 0 means no vertical movement (e.g. horizontal scroll):
            # leave that to the scroll area instead of zooming out.
            if delta != 0:
                step = self.zoom_wheel_step if delta > 0 else -self.zoom_wheel_step
                self._set_zoom(self.zoom_scale + step)
                return True

        return super().eventFilter(obj, event)