diff --git a/src/gui/tabs/annotation_tab.py b/src/gui/tabs/annotation_tab.py index 6ce6d57..ae16446 100644 --- a/src/gui/tabs/annotation_tab.py +++ b/src/gui/tabs/annotation_tab.py @@ -262,9 +262,9 @@ class AnnotationTab(QWidget): ) return - # Compute bounding box and polyline from annotations - bounds = self.annotation_canvas.compute_annotation_bounds() - if not bounds: + # Compute annotation parameters asbounding boxes and polylines from annotations + parameters = self.annotation_canvas.get_annotation_parameters() + if not parameters: QMessageBox.warning( self, "No Annotations", @@ -272,48 +272,56 @@ class AnnotationTab(QWidget): ) return - polyline = self.annotation_canvas.get_annotation_polyline() + # polyline = self.annotation_canvas.get_annotation_polyline() - try: - # Save annotation to database - annotation_id = self.db_manager.add_annotation( - image_id=self.current_image_id, - class_id=current_class["id"], - bbox=bounds, - annotator="manual", - segmentation_mask=polyline, - verified=False, - ) + for param in parameters: + bounds = param["bbox"] + polyline = param["polyline"] - logger.info( - f"Saved annotation (ID: {annotation_id}) for class '{current_class['class_name']}' " - f"with {len(polyline)} polyline points" - ) + try: + # Save annotation to database + annotation_id = self.db_manager.add_annotation( + image_id=self.current_image_id, + class_id=current_class["id"], + bbox=bounds, + annotator="manual", + segmentation_mask=polyline, + verified=False, + ) - QMessageBox.information( - self, - "Success", - f"Annotation saved successfully!\n\n" - f"Class: {current_class['class_name']}\n" - f"Bounding box: ({bounds[0]:.3f}, {bounds[1]:.3f}) to ({bounds[2]:.3f}, {bounds[3]:.3f})\n" - f"Polyline points: {len(polyline)}", - ) + logger.info( + f"Saved annotation (ID: {annotation_id}) for class '{current_class['class_name']}' " + f"Bounding box: ({bounds[0]:.3f}, {bounds[1]:.3f}) to ({bounds[2]:.3f}, {bounds[3]:.3f})\n" + f"with {len(polyline)} polyline points" + ) - # Optionally clear annotations after saving - reply = QMessageBox.question( - self, - "Clear Annotations", - "Do you want to clear the annotations to start a new one?", - QMessageBox.Yes | QMessageBox.No, - QMessageBox.Yes, - ) + # QMessageBox.information( + # self, + # "Success", + # f"Annotation saved successfully!\n\n" + # f"Class: {current_class['class_name']}\n" + # f"Bounding box: ({bounds[0]:.3f}, {bounds[1]:.3f}) to ({bounds[2]:.3f}, {bounds[3]:.3f})\n" + # f"Polyline points: {len(polyline)}", + # ) - if reply == QMessageBox.Yes: - self.annotation_canvas.clear_annotations() + except Exception as e: + logger.error(f"Failed to save annotation: {e}") + QMessageBox.critical( + self, "Error", f"Failed to save annotation:\n{str(e)}" + ) - except Exception as e: - logger.error(f"Failed to save annotation: {e}") - QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}") + # Optionally clear annotations after saving + reply = QMessageBox.question( + self, + "Clear Annotations", + "Do you want to clear the annotations to start a new one?", + QMessageBox.Yes | QMessageBox.No, + QMessageBox.Yes, + ) + + if reply == QMessageBox.Yes: + self.annotation_canvas.clear_annotations() + logger.info("Cleared annotations after saving") def _on_show_annotations(self): """Load and display saved annotations from database.""" @@ -348,6 +356,11 @@ class AnnotationTab(QWidget): # Draw the polyline self.annotation_canvas.draw_saved_polyline(polyline, color, width=3) + self.annotation_canvas.draw_saved_bbox( + [ann["x_min"], ann["y_min"], ann["x_max"], ann["y_max"]], + color, + width=3, + ) drawn_count += 1 logger.info(f"Displayed {drawn_count} saved annotations from database") diff --git a/src/gui/widgets/annotation_canvas_widget.py b/src/gui/widgets/annotation_canvas_widget.py index a2a57a6..9851e97 100644 --- a/src/gui/widgets/annotation_canvas_widget.py +++ b/src/gui/widgets/annotation_canvas_widget.py @@ -3,6 +3,8 @@ Annotation canvas widget for drawing annotations on images. Supports pen tool with color selection for manual annotation. """ +import numpy as np + from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea from PySide6.QtGui import ( QPixmap, @@ -15,12 +17,17 @@ from PySide6.QtGui import ( QPaintEvent, ) from PySide6.QtCore import Qt, QEvent, Signal, QPoint -from typing import List, Optional, Tuple -import numpy as np +from typing import Any, Dict, List, Optional, Tuple + +from scipy.ndimage import binary_dilation, label, binary_fill_holes, find_objects +from skimage.measure import find_contours from src.utils.image import Image, ImageLoadError from src.utils.logger import get_logger +# For debugging visualization +import pylab as plt + logger = get_logger(__name__) @@ -369,49 +376,149 @@ class AnnotationCanvasWidget(QWidget): """Get all drawn strokes with metadata.""" return self.all_strokes - def compute_annotation_bounds(self) -> Optional[Tuple[float, float, float, float]]: + # def get_annotation_bounds(self) -> Optional[Tuple[float, float, float, float]]: + # """ + # Compute bounding box that encompasses all annotation strokes. + + # Returns: + # Tuple of (x_min, y_min, x_max, y_max) in normalized coordinates (0-1), + # or None if no annotations exist. + # """ + # if not self.all_strokes: + # return None + + # # Find min/max across all strokes + # all_x = [] + # all_y = [] + + # for stroke in self.all_strokes: + # for x, y in stroke["points"]: + # all_x.append(x) + # all_y.append(y) + + # if not all_x: + # return None + + # x_min = min(all_x) + # y_min = min(all_y) + # x_max = max(all_x) + # y_max = max(all_y) + + # return (x_min, y_min, x_max, y_max) + + # def get_annotation_polyline(self) -> List[List[float]]: + # """ + # Get polyline coordinates representing all annotation strokes. + + # Returns: + # List of [x, y] coordinate pairs in normalized coordinates (0-1). + # """ + # polyline = [] + + # fig = plt.figure() + # ax1 = fig.add_subplot(411) + # ax2 = fig.add_subplot(412) + # ax3 = fig.add_subplot(413) + # ax4 = fig.add_subplot(414) + + # # Get np.arrays from annotation_pixmap accoriding to the color of the stroke + # qimage = self.annotation_pixmap.toImage() + # arr = np.ndarray( + # (qimage.height(), qimage.width(), 4), + # buffer=qimage.constBits(), + # strides=[qimage.bytesPerLine(), 4, 1], + # dtype=np.uint8, + # ) + # print(arr.shape, arr.dtype, arr.min(), arr.max()) + # arr = np.sum(arr, axis=2) + # ax1.imshow(arr) + + # arr_bin = arr > 0 + # ax2.imshow(arr_bin) + + # arr_bin = binary_fill_holes(arr_bin) + # ax3.imshow(arr_bin) + + # labels, _number_of_features = label( + # arr_bin, + # ) + + # ax4.imshow(labels) + + # objects = find_objects(labels) + # bounding_boxes = np.array( + # [[obj[0].start, obj[0].stop, obj[1].start, obj[1].stop] for obj in objects] + # ) / np.array([arr.shape[0], arr.shape[1]]) + + # print(objects) + # print(bounding_boxes) + # print(np.array([arr.shape[0], arr.shape[1]])) + + # polylines = find_contours(arr_bin, 0.5) + # for pl in polylines: + # ax1.plot(pl[:, 1], pl[:, 0], "k") + + # print(arr.shape, arr.dtype, arr.min(), arr.max()) + + # plt.show() + + # return polyline + + def get_annotation_parameters(self) -> Dict[str, Any]: """ - Compute bounding box that encompasses all annotation strokes. + Get all annotation parameters including bounding box and polyline. Returns: - Tuple of (x_min, y_min, x_max, y_max) in normalized coordinates (0-1), - or None if no annotations exist. + Dictionary containing: + - 'bbox': Bounding box coordinates (x_min, y_min, x_max, y_max) + - 'polyline': List of [x, y] coordinate pairs """ - if not self.all_strokes: + + # Get np.arrays from annotation_pixmap accoriding to the color of the stroke + qimage = self.annotation_pixmap.toImage() + arr = np.ndarray( + (qimage.height(), qimage.width(), 4), + buffer=qimage.constBits(), + strides=[qimage.bytesPerLine(), 4, 1], + dtype=np.uint8, + ) + arr = np.sum(arr, axis=2) + arr_bin = arr > 0 + arr_bin = binary_fill_holes(arr_bin) + + labels, _number_of_features = label( + arr_bin, + ) + if _number_of_features == 0: return None - # Find min/max across all strokes - all_x = [] - all_y = [] + objects = find_objects(labels) + w, h = arr.shape + bounding_boxes = [ + [obj[0].start / w, obj[1].start / h, obj[0].stop / w, obj[1].stop / h] + for obj in objects + ] - for stroke in self.all_strokes: - for x, y in stroke["points"]: - all_x.append(x) - all_y.append(y) + polylines = find_contours(arr_bin, 0.5) + params = [] + for i, pl in enumerate(polylines): + # pl is in [row, col] format from find_contours + # We need to normalize: row/height, col/width + # w = height (rows), h = width (cols) from line 510 + normalized_polyline = (pl[::-1] / np.array([w, h])).tolist() - if not all_x: - return None + logger.debug(f"Polyline {i}: {len(pl)} points") + logger.debug(f" w={w} (height), h={h} (width)") + logger.debug(f" First 3 normalized points: {normalized_polyline[:3]}") - x_min = min(all_x) - y_min = min(all_y) - x_max = max(all_x) - y_max = max(all_y) + params.append( + { + "bbox": bounding_boxes[i], + "polyline": normalized_polyline, + } + ) - return (x_min, y_min, x_max, y_max) - - def get_annotation_polyline(self) -> List[List[float]]: - """ - Get polyline coordinates representing all annotation strokes. - - Returns: - List of [x, y] coordinate pairs in normalized coordinates (0-1). - """ - polyline = [] - - for stroke in self.all_strokes: - polyline.extend(stroke["points"]) - - return polyline + return params def draw_saved_polyline( self, polyline: List[List[float]], color: str, width: int = 3 @@ -433,12 +540,22 @@ class AnnotationCanvasWidget(QWidget): return # Convert normalized coordinates to image coordinates + # Polyline is stored as [[y_norm, x_norm], ...] (row_norm, col_norm format) + img_width = self.original_pixmap.width() + img_height = self.original_pixmap.height() + + logger.debug(f"Loading polyline with {len(polyline)} points") + logger.debug(f" Image size: {img_width}x{img_height}") + logger.debug(f" First 3 normalized points from DB: {polyline[:3]}") + img_coords = [] - for x_norm, y_norm in polyline: - x = int(x_norm * self.original_pixmap.width()) - y = int(y_norm * self.original_pixmap.height()) + for y_norm, x_norm in polyline: + x = int(x_norm * img_width) + y = int(y_norm * img_height) img_coords.append((x, y)) + logger.debug(f" First 3 pixel coords: {img_coords[:3]}") + # Draw polyline on annotation pixmap painter = QPainter(self.annotation_pixmap) pen_color = QColor(color) @@ -465,6 +582,64 @@ class AnnotationCanvasWidget(QWidget): f"Drew saved polyline with {len(polyline)} points in color {color}" ) + def draw_saved_bbox(self, bbox: List[float], color: str, width: int = 3): + """ + Draw a bounding box from database coordinates onto the annotation canvas. + + Args: + bbox: Bounding box as [y_min_norm, x_min_norm, y_max_norm, x_max_norm] + in normalized coordinates (0-1) + color: Color hex string (e.g., '#FF0000') + width: Line width in pixels + """ + if not self.annotation_pixmap or not self.original_pixmap: + logger.warning("Cannot draw bounding box: no image loaded") + return + + if len(bbox) != 4: + logger.warning( + f"Invalid bounding box format: expected 4 values, got {len(bbox)}" + ) + return + + # Convert normalized coordinates to image coordinates + # bbox format: [y_min_norm, x_min_norm, y_max_norm, x_max_norm] + img_width = self.original_pixmap.width() + img_height = self.original_pixmap.height() + + y_min_norm, x_min_norm, y_max_norm, x_max_norm = bbox + x_min = int(x_min_norm * img_width) + y_min = int(y_min_norm * img_height) + x_max = int(x_max_norm * img_width) + y_max = int(y_max_norm * img_height) + + logger.debug(f"Drawing bounding box: {bbox}") + logger.debug(f" Image size: {img_width}x{img_height}") + logger.debug(f" Pixel coords: ({x_min}, {y_min}) to ({x_max}, {y_max})") + + # Draw bounding box on annotation pixmap + painter = QPainter(self.annotation_pixmap) + pen_color = QColor(color) + pen_color.setAlpha(128) # Add semi-transparency + pen = QPen(pen_color, width, Qt.SolidLine, Qt.SquareCap, Qt.MiterJoin) + painter.setPen(pen) + + # Draw rectangle + rect_width = x_max - x_min + rect_height = y_max - y_min + painter.drawRect(x_min, y_min, rect_width, rect_height) + + painter.end() + + # Store in all_strokes for consistency + self.all_strokes.append( + {"bbox": bbox, "color": color, "alpha": 128, "width": width} + ) + + # Update display + self._update_display() + logger.debug(f"Drew saved bounding box in color {color}") + def keyPressEvent(self, event: QKeyEvent): """Handle keyboard events for zooming.""" if event.key() in (Qt.Key_Plus, Qt.Key_Equal):