27 Commits

Author SHA1 Message Date
e5036c10cf Small fix 2025-12-16 18:03:56 +02:00
c7e388d9ae Updating progressbar 2025-12-16 17:20:25 +02:00
6b995e7325 upate 2025-12-16 13:24:20 +02:00
0e0741d323 Update on convert_grayscale_to_rgb_preserve_range, making it class method 2025-12-16 12:37:34 +02:00
dd99a0677c Updating image converter and aading simple script to visulaize segmentation 2025-12-16 11:27:38 +02:00
9c4c39fb39 Adding image converter 2025-12-12 23:52:34 +02:00
20a87c9040 Updating config 2025-12-12 21:51:12 +02:00
9f7d2be1ac Updating the base model preset 2025-12-11 23:27:02 +02:00
dbde07c0e8 Making training tab scrollable 2025-12-11 23:12:39 +02:00
b3c5a51dbb Using QPolygonF instead of drawLine 2025-12-11 17:14:07 +02:00
9a221acb63 Making image manipulations thru one class 2025-12-11 16:59:56 +02:00
32a6a122bd Fixing circular import 2025-12-11 16:06:39 +02:00
9ba44043ef Defining image extensions only in one place 2025-12-11 15:50:14 +02:00
8eb1cc8c86 Fixing grayscale conversion 2025-12-11 15:15:38 +02:00
e4ce882a18 Grayscale RGB conversion modified 2025-12-11 15:06:59 +02:00
6b6d6fad03 2Stage training fix 2025-12-11 12:50:34 +02:00
c0684a9c14 Implementing 2 stage training 2025-12-11 12:04:08 +02:00
221c80aa8c Small image showing fix 2025-12-11 11:20:20 +02:00
833b222fad Adding result shower 2025-12-10 16:55:28 +02:00
5370d31dce Merge pull request 'Update training' (#2) from training into main
Reviewed-on: #2
2025-12-10 15:47:00 +02:00
5d196c3a4a Update training 2025-12-10 15:46:26 +02:00
f719c7ec40 Merge pull request 'segmentation' (#1) from segmentation into main
Reviewed-on: #1
2025-12-10 12:08:54 +02:00
e6a5e74fa1 Adding feature to remove annotations 2025-12-10 00:19:59 +02:00
35e2398e95 Fixing bounding box drawing 2025-12-09 23:56:29 +02:00
c3d44ac945 Renaming Pen tool to polyline tool 2025-12-09 23:38:23 +02:00
dad5c2bf74 Updating 2025-12-09 22:44:23 +02:00
73cb698488 Saving state before replacing annotation tool 2025-12-09 22:00:56 +02:00
18 changed files with 3985 additions and 573 deletions

View File

@@ -12,12 +12,28 @@ image_repository:
models: models:
default_base_model: yolov8s-seg.pt default_base_model: yolov8s-seg.pt
models_directory: data/models models_directory: data/models
base_model_choices:
- yolov8s-seg.pt
- yolo11s-seg.pt
training: training:
default_epochs: 100 default_epochs: 100
default_batch_size: 16 default_batch_size: 16
default_imgsz: 640 default_imgsz: 1024
default_patience: 50 default_patience: 50
default_lr0: 0.01 default_lr0: 0.01
two_stage:
enabled: false
stage1:
epochs: 20
lr0: 0.0005
patience: 10
freeze: 10
stage2:
epochs: 150
lr0: 0.0003
patience: 30
last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml
last_dataset_dir: /home/martin/code/object_detection/data/datasets
detection: detection:
default_confidence: 0.25 default_confidence: 0.25
default_iou: 0.45 default_iou: 0.45

View File

@@ -10,6 +10,14 @@ from typing import List, Dict, Optional, Tuple, Any, Union
from pathlib import Path from pathlib import Path
import csv import csv
import hashlib import hashlib
import yaml
from src.utils.logger import get_logger
from src.utils.image import Image
IMAGE_EXTENSIONS = tuple(Image.SUPPORTED_EXTENSIONS)
logger = get_logger(__name__)
class DatabaseManager: class DatabaseManager:
@@ -443,6 +451,25 @@ class DatabaseManager:
filters["model_id"] = model_id filters["model_id"] = model_id
return self.get_detections(filters) return self.get_detections(filters)
def delete_detections_for_image(
self, image_id: int, model_id: Optional[int] = None
) -> int:
"""Delete detections tied to a specific image and optional model."""
conn = self.get_connection()
try:
cursor = conn.cursor()
if model_id is not None:
cursor.execute(
"DELETE FROM detections WHERE image_id = ? AND model_id = ?",
(image_id, model_id),
)
else:
cursor.execute("DELETE FROM detections WHERE image_id = ?", (image_id,))
conn.commit()
return cursor.rowcount
finally:
conn.close()
def delete_detections_for_model(self, model_id: int) -> int: def delete_detections_for_model(self, model_id: int) -> int:
"""Delete all detections for a specific model.""" """Delete all detections for a specific model."""
conn = self.get_connection() conn = self.get_connection()
@@ -706,6 +733,25 @@ class DatabaseManager:
finally: finally:
conn.close() conn.close()
def delete_annotation(self, annotation_id: int) -> bool:
"""
Delete a manual annotation by ID.
Args:
annotation_id: ID of the annotation to delete
Returns:
True if an annotation was deleted, False otherwise.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM annotations WHERE id = ?", (annotation_id,))
conn.commit()
return cursor.rowcount > 0
finally:
conn.close()
# ==================== Object Class Operations ==================== # ==================== Object Class Operations ====================
def get_object_classes(self) -> List[Dict]: def get_object_classes(self) -> List[Dict]:
@@ -842,6 +888,187 @@ class DatabaseManager:
finally: finally:
conn.close() conn.close()
# ==================== Dataset Utilities ====================
def compose_data_yaml(
self,
dataset_root: str,
output_path: Optional[str] = None,
splits: Optional[Dict[str, str]] = None,
) -> str:
"""
Compose a YOLO data.yaml file based on dataset folders and database metadata.
Args:
dataset_root: Base directory containing the dataset structure.
output_path: Optional output path; defaults to <dataset_root>/data.yaml.
splits: Optional mapping overriding train/val/test image directories (relative
to dataset_root or absolute paths).
Returns:
Path to the generated YAML file.
"""
dataset_root_path = Path(dataset_root).expanduser()
if not dataset_root_path.exists():
raise ValueError(f"Dataset root does not exist: {dataset_root_path}")
dataset_root_path = dataset_root_path.resolve()
split_map: Dict[str, str] = {key: "" for key in ("train", "val", "test")}
if splits:
for key, value in splits.items():
if key in split_map and value:
split_map[key] = value
inferred = self._infer_split_dirs(dataset_root_path)
for key in split_map:
if not split_map[key]:
split_map[key] = inferred.get(key, "")
for required in ("train", "val"):
if not split_map[required]:
raise ValueError(
"Unable to determine %s image directory under %s. Provide it "
"explicitly via the 'splits' argument."
% (required, dataset_root_path)
)
yaml_splits: Dict[str, str] = {}
for key, value in split_map.items():
if not value:
continue
yaml_splits[key] = self._normalize_split_value(value, dataset_root_path)
class_names = self._fetch_annotation_class_names()
if not class_names:
class_names = [cls["class_name"] for cls in self.get_object_classes()]
if not class_names:
raise ValueError("No object classes available to populate data.yaml")
names_map = {idx: name for idx, name in enumerate(class_names)}
payload: Dict[str, Any] = {
"path": dataset_root_path.as_posix(),
"train": yaml_splits["train"],
"val": yaml_splits["val"],
"names": names_map,
"nc": len(class_names),
}
if yaml_splits.get("test"):
payload["test"] = yaml_splits["test"]
output_path_obj = (
Path(output_path).expanduser()
if output_path
else dataset_root_path / "data.yaml"
)
output_path_obj.parent.mkdir(parents=True, exist_ok=True)
with open(output_path_obj, "w", encoding="utf-8") as handle:
yaml.safe_dump(payload, handle, sort_keys=False)
logger.info(f"Generated data.yaml at {output_path_obj}")
return output_path_obj.as_posix()
def _fetch_annotation_class_names(self) -> List[str]:
"""Return class names referenced by annotations (ordered by class ID)."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"""
SELECT DISTINCT c.id, c.class_name
FROM annotations a
JOIN object_classes c ON a.class_id = c.id
ORDER BY c.id
"""
)
rows = cursor.fetchall()
return [row["class_name"] for row in rows]
finally:
conn.close()
def _infer_split_dirs(self, dataset_root: Path) -> Dict[str, str]:
"""Infer train/val/test image directories relative to dataset_root."""
patterns = {
"train": [
"train/images",
"training/images",
"images/train",
"images/training",
"train",
"training",
],
"val": [
"val/images",
"validation/images",
"images/val",
"images/validation",
"val",
"validation",
],
"test": [
"test/images",
"testing/images",
"images/test",
"images/testing",
"test",
"testing",
],
}
inferred: Dict[str, str] = {key: "" for key in patterns}
for split_name, options in patterns.items():
for relative in options:
candidate = (dataset_root / relative).resolve()
if (
candidate.exists()
and candidate.is_dir()
and self._directory_has_images(candidate)
):
try:
inferred[split_name] = candidate.relative_to(
dataset_root
).as_posix()
except ValueError:
inferred[split_name] = candidate.as_posix()
break
return inferred
def _normalize_split_value(self, split_value: str, dataset_root: Path) -> str:
"""Validate and normalize a split directory to a YAML-friendly string."""
split_path = Path(split_value).expanduser()
if not split_path.is_absolute():
split_path = (dataset_root / split_path).resolve()
else:
split_path = split_path.resolve()
if not split_path.exists() or not split_path.is_dir():
raise ValueError(f"Split directory not found: {split_path}")
if not self._directory_has_images(split_path):
raise ValueError(f"No images found under {split_path}")
try:
return split_path.relative_to(dataset_root).as_posix()
except ValueError:
return split_path.as_posix()
@staticmethod
def _directory_has_images(directory: Path, max_checks: int = 2000) -> bool:
"""Return True if directory tree contains at least one image file."""
checked = 0
try:
for file_path in directory.rglob("*"):
if not file_path.is_file():
continue
if file_path.suffix.lower() in IMAGE_EXTENSIONS:
return True
checked += 1
if checked >= max_checks:
break
except Exception:
return False
return False
@staticmethod @staticmethod
def calculate_checksum(file_path: str) -> str: def calculate_checksum(file_path: str) -> str:
"""Calculate MD5 checksum of a file.""" """Calculate MD5 checksum of a file."""

View File

@@ -297,7 +297,9 @@ class MainWindow(QMainWindow):
# Save window state before closing # Save window state before closing
self._save_window_state() self._save_window_state()
# Save annotation tab state if it exists # Persist tab state and stop background work before exit
if hasattr(self, "training_tab"):
self.training_tab.shutdown()
if hasattr(self, "annotation_tab"): if hasattr(self, "annotation_tab"):
self.annotation_tab.save_state() self.annotation_tab.save_state()

View File

@@ -38,6 +38,9 @@ class AnnotationTab(QWidget):
self.current_image = None self.current_image = None
self.current_image_path = None self.current_image_path = None
self.current_image_id = None self.current_image_id = None
self.current_annotations = []
# IDs of annotations currently selected on the canvas (multi-select)
self.selected_annotation_ids = []
self._setup_ui() self._setup_ui()
@@ -61,6 +64,8 @@ class AnnotationTab(QWidget):
self.annotation_canvas = AnnotationCanvasWidget() self.annotation_canvas = AnnotationCanvasWidget()
self.annotation_canvas.zoom_changed.connect(self._on_zoom_changed) self.annotation_canvas.zoom_changed.connect(self._on_zoom_changed)
self.annotation_canvas.annotation_drawn.connect(self._on_annotation_drawn) self.annotation_canvas.annotation_drawn.connect(self._on_annotation_drawn)
# Selection of existing polylines (when tool is not in drawing mode)
self.annotation_canvas.annotation_selected.connect(self._on_annotation_selected)
canvas_layout.addWidget(self.annotation_canvas) canvas_layout.addWidget(self.annotation_canvas)
canvas_group.setLayout(canvas_layout) canvas_group.setLayout(canvas_layout)
@@ -80,24 +85,35 @@ class AnnotationTab(QWidget):
# Annotation tools section # Annotation tools section
self.annotation_tools = AnnotationToolsWidget(self.db_manager) self.annotation_tools = AnnotationToolsWidget(self.db_manager)
self.annotation_tools.pen_enabled_changed.connect( self.annotation_tools.polyline_enabled_changed.connect(
self.annotation_canvas.set_pen_enabled self.annotation_canvas.set_polyline_enabled
) )
self.annotation_tools.pen_color_changed.connect( self.annotation_tools.polyline_pen_color_changed.connect(
self.annotation_canvas.set_pen_color self.annotation_canvas.set_polyline_pen_color
) )
self.annotation_tools.pen_width_changed.connect( self.annotation_tools.polyline_pen_width_changed.connect(
self.annotation_canvas.set_pen_width self.annotation_canvas.set_polyline_pen_width
) )
# Show / hide bounding boxes
self.annotation_tools.show_bboxes_changed.connect(
self.annotation_canvas.set_show_bboxes
)
# RDP simplification controls
self.annotation_tools.simplify_on_finish_changed.connect(
self._on_simplify_on_finish_changed
)
self.annotation_tools.simplify_epsilon_changed.connect(
self._on_simplify_epsilon_changed
)
# Class selection and class-color changes
self.annotation_tools.class_selected.connect(self._on_class_selected) self.annotation_tools.class_selected.connect(self._on_class_selected)
self.annotation_tools.class_color_changed.connect(self._on_class_color_changed)
self.annotation_tools.clear_annotations_requested.connect( self.annotation_tools.clear_annotations_requested.connect(
self._on_clear_annotations self._on_clear_annotations
) )
self.annotation_tools.process_annotations_requested.connect( # Delete selected annotation on canvas
self._on_process_annotations self.annotation_tools.delete_selected_annotation_requested.connect(
) self._on_delete_selected_annotation
self.annotation_tools.show_annotations_requested.connect(
self._on_show_annotations
) )
self.right_splitter.addWidget(self.annotation_tools) self.right_splitter.addWidget(self.annotation_tools)
@@ -152,7 +168,7 @@ class AnnotationTab(QWidget):
self, self,
"Select Image", "Select Image",
start_dir, start_dir,
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)", "Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
) )
if not file_path: if not file_path:
@@ -180,6 +196,9 @@ class AnnotationTab(QWidget):
# Display image using the AnnotationCanvasWidget # Display image using the AnnotationCanvasWidget
self.annotation_canvas.load_image(self.current_image) self.annotation_canvas.load_image(self.current_image)
# Load and display any existing annotations for this image
self._load_annotations_for_current_image()
# Update info label # Update info label
self._update_image_info() self._update_image_info()
@@ -217,7 +236,22 @@ class AnnotationTab(QWidget):
self._update_image_info() self._update_image_info()
def _on_annotation_drawn(self, points: list): def _on_annotation_drawn(self, points: list):
"""Handle when an annotation stroke is drawn.""" """
Handle when an annotation stroke is drawn.
Saves the new annotation directly to the database and refreshes the
on-canvas display of annotations for the current image.
"""
# Ensure we have an image loaded and in the DB
if not self.current_image or not self.current_image_id:
logger.warning("Annotation drawn but no image loaded")
QMessageBox.warning(
self,
"No Image",
"Please load an image before drawing annotations.",
)
return
current_class = self.annotation_tools.get_current_class() current_class = self.annotation_tools.get_current_class()
if not current_class: if not current_class:
@@ -229,153 +263,260 @@ class AnnotationTab(QWidget):
) )
return return
logger.info( if not points:
f"Annotation drawn with {len(points)} points for class: {current_class['class_name']}" logger.warning("Annotation drawn with no points, ignoring")
) return
# Future: Save annotation to database or export
def _on_class_selected(self, class_data: dict): # points are [(x_norm, y_norm), ...]
"""Handle when an object class is selected.""" xs = [p[0] for p in points]
logger.debug(f"Object class selected: {class_data['class_name']}") ys = [p[1] for p in points]
x_min, x_max = min(xs), max(xs)
y_min, y_max = min(ys), max(ys)
# Store segmentation mask in [y_norm, x_norm] format to match DB
db_polyline = [[float(y), float(x)] for (x, y) in points]
try:
annotation_id = self.db_manager.add_annotation(
image_id=self.current_image_id,
class_id=current_class["id"],
bbox=(x_min, y_min, x_max, y_max),
annotator="manual",
segmentation_mask=db_polyline,
verified=False,
)
logger.info(
f"Saved annotation (ID: {annotation_id}) for class "
f"'{current_class['class_name']}' "
f"Bounding box: ({x_min:.3f}, {y_min:.3f}) to ({x_max:.3f}, {y_max:.3f})\n"
f"with {len(points)} polyline points"
)
# Reload annotations from DB and redraw (respecting current class filter)
self._load_annotations_for_current_image()
except Exception as e:
logger.error(f"Failed to save annotation: {e}")
QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}")
def _on_annotation_selected(self, annotation_ids):
"""
Handle selection of existing annotations on the canvas.
Args:
annotation_ids: List of selected annotation IDs, or None/empty if cleared.
"""
if not annotation_ids:
self.selected_annotation_ids = []
self.annotation_tools.set_has_selected_annotation(False)
logger.debug("Annotation selection cleared on canvas")
return
# Normalize to a unique, sorted list of integer IDs
ids = sorted({int(aid) for aid in annotation_ids if isinstance(aid, int)})
self.selected_annotation_ids = ids
self.annotation_tools.set_has_selected_annotation(bool(ids))
logger.debug(f"Annotations selected on canvas: IDs={ids}")
def _on_simplify_on_finish_changed(self, enabled: bool):
"""Update canvas simplify-on-finish flag from tools widget."""
self.annotation_canvas.simplify_on_finish = enabled
logger.debug(f"Annotation simplification on finish set to {enabled}")
def _on_simplify_epsilon_changed(self, epsilon: float):
"""Update canvas RDP epsilon from tools widget."""
self.annotation_canvas.simplify_epsilon = float(epsilon)
logger.debug(f"Annotation simplification epsilon set to {epsilon}")
def _on_class_color_changed(self):
"""
Handle changes to the selected object's class color.
When the user updates a class color in the tools widget, reload the
annotations for the current image so that all polylines are redrawn
using the updated per-class colors.
"""
if not self.current_image_id:
return
logger.debug(
f"Class color changed; reloading annotations for image ID {self.current_image_id}"
)
self._load_annotations_for_current_image()
def _on_class_selected(self, class_data):
"""
Handle when an object class is selected or cleared.
When a specific class is selected, only annotations of that class are drawn.
When the selection is cleared ("-- Select Class --"), all annotations are shown.
"""
if class_data:
logger.debug(f"Object class selected: {class_data['class_name']}")
else:
logger.debug(
'No class selected ("-- Select Class --"), showing all annotations'
)
# Changing the class filter invalidates any previous selection
self.selected_annotation_ids = []
self.annotation_tools.set_has_selected_annotation(False)
# Whenever the selection changes, update which annotations are visible
self._redraw_annotations_for_current_filter()
def _on_clear_annotations(self): def _on_clear_annotations(self):
"""Handle clearing all annotations.""" """Handle clearing all annotations."""
self.annotation_canvas.clear_annotations() self.annotation_canvas.clear_annotations()
# Clear in-memory state and selection, but keep DB entries unchanged
self.current_annotations = []
self.selected_annotation_ids = []
self.annotation_tools.set_has_selected_annotation(False)
logger.info("Cleared all annotations") logger.info("Cleared all annotations")
def _on_process_annotations(self): def _on_delete_selected_annotation(self):
"""Process annotations and save to database.""" """Handle deleting the currently selected annotation(s) (if any)."""
# Check if we have an image loaded if not self.selected_annotation_ids:
if not self.current_image or not self.current_image_id: QMessageBox.information(
QMessageBox.warning(
self, "No Image", "Please load an image before processing annotations."
)
return
# Get current class
current_class = self.annotation_tools.get_current_class()
if not current_class:
QMessageBox.warning(
self, self,
"No Class Selected", "No Selection",
"Please select an object class before processing annotations.", "No annotation is currently selected.",
) )
return return
# Compute annotation parameters asbounding boxes and polylines from annotations count = len(self.selected_annotation_ids)
parameters = self.annotation_canvas.get_annotation_parameters() if count == 1:
if not parameters: question = "Are you sure you want to delete the selected annotation?"
QMessageBox.warning( title = "Delete Annotation"
self, else:
"No Annotations", question = (
"Please draw some annotations before processing.", f"Are you sure you want to delete the {count} selected annotations?"
) )
return title = "Delete Annotations"
# polyline = self.annotation_canvas.get_annotation_polyline()
for param in parameters:
bounds = param["bbox"]
polyline = param["polyline"]
try:
# Save annotation to database
annotation_id = self.db_manager.add_annotation(
image_id=self.current_image_id,
class_id=current_class["id"],
bbox=bounds,
annotator="manual",
segmentation_mask=polyline,
verified=False,
)
logger.info(
f"Saved annotation (ID: {annotation_id}) for class '{current_class['class_name']}' "
f"Bounding box: ({bounds[0]:.3f}, {bounds[1]:.3f}) to ({bounds[2]:.3f}, {bounds[3]:.3f})\n"
f"with {len(polyline)} polyline points"
)
# QMessageBox.information(
# self,
# "Success",
# f"Annotation saved successfully!\n\n"
# f"Class: {current_class['class_name']}\n"
# f"Bounding box: ({bounds[0]:.3f}, {bounds[1]:.3f}) to ({bounds[2]:.3f}, {bounds[3]:.3f})\n"
# f"Polyline points: {len(polyline)}",
# )
except Exception as e:
logger.error(f"Failed to save annotation: {e}")
QMessageBox.critical(
self, "Error", f"Failed to save annotation:\n{str(e)}"
)
# Optionally clear annotations after saving
reply = QMessageBox.question( reply = QMessageBox.question(
self, self,
"Clear Annotations", title,
"Do you want to clear the annotations to start a new one?", question,
QMessageBox.Yes | QMessageBox.No, QMessageBox.Yes | QMessageBox.No,
QMessageBox.Yes, QMessageBox.No,
) )
if reply != QMessageBox.Yes:
return
if reply == QMessageBox.Yes: failed_ids = []
self.annotation_canvas.clear_annotations() try:
logger.info("Cleared annotations after saving") for ann_id in self.selected_annotation_ids:
try:
deleted = self.db_manager.delete_annotation(ann_id)
if not deleted:
failed_ids.append(ann_id)
except Exception as e:
logger.error(f"Failed to delete annotation ID {ann_id}: {e}")
failed_ids.append(ann_id)
def _on_show_annotations(self): if failed_ids:
"""Load and display saved annotations from database.""" QMessageBox.warning(
# Check if we have an image loaded self,
if not self.current_image or not self.current_image_id: "Partial Failure",
QMessageBox.warning( "Some annotations could not be deleted:\n"
self, "No Image", "Please load an image to view its annotations." + ", ".join(str(a) for a in failed_ids),
)
else:
logger.info(
f"Deleted {count} annotation(s): "
+ ", ".join(str(a) for a in self.selected_annotation_ids)
)
# Clear selection and reload annotations for the current image from DB
self.selected_annotation_ids = []
self.annotation_tools.set_has_selected_annotation(False)
self._load_annotations_for_current_image()
except Exception as e:
logger.error(f"Failed to delete annotations: {e}")
QMessageBox.critical(
self,
"Error",
f"Failed to delete annotations:\n{str(e)}",
) )
def _load_annotations_for_current_image(self):
"""
Load all annotations for the current image from the database and
redraw them on the canvas, honoring the currently selected class
filter (if any).
"""
if not self.current_image_id:
self.current_annotations = []
self.annotation_canvas.clear_annotations()
self.selected_annotation_ids = []
self.annotation_tools.set_has_selected_annotation(False)
return return
try: try:
# Clear current annotations self.current_annotations = self.db_manager.get_annotations_for_image(
self.annotation_canvas.clear_annotations()
# Retrieve annotations from database
annotations = self.db_manager.get_annotations_for_image(
self.current_image_id self.current_image_id
) )
# New annotations loaded; reset any selection
if not annotations: self.selected_annotation_ids = []
QMessageBox.information( self.annotation_tools.set_has_selected_annotation(False)
self, "No Annotations", "No saved annotations found for this image." self._redraw_annotations_for_current_filter()
)
return
# Draw each annotation's polyline
drawn_count = 0
for ann in annotations:
if ann.get("segmentation_mask"):
polyline = ann["segmentation_mask"]
color = ann.get("class_color", "#FF0000")
# Draw the polyline
self.annotation_canvas.draw_saved_polyline(polyline, color, width=3)
self.annotation_canvas.draw_saved_bbox(
[ann["x_min"], ann["y_min"], ann["x_max"], ann["y_max"]],
color,
width=3,
)
drawn_count += 1
logger.info(f"Displayed {drawn_count} saved annotations from database")
QMessageBox.information(
self,
"Annotations Loaded",
f"Successfully loaded and displayed {drawn_count} annotation(s).",
)
except Exception as e: except Exception as e:
logger.error(f"Failed to load annotations: {e}") logger.error(
QMessageBox.critical( f"Failed to load annotations for image {self.current_image_id}: {e}"
self, "Error", f"Failed to load annotations:\n{str(e)}"
) )
QMessageBox.critical(
self,
"Error",
f"Failed to load annotations for this image:\n{str(e)}",
)
def _redraw_annotations_for_current_filter(self):
"""
Redraw annotations for the current image, optionally filtered by the
currently selected object class.
"""
# Clear current on-canvas annotations but keep the image
self.annotation_canvas.clear_annotations()
if not self.current_annotations:
return
current_class = self.annotation_tools.get_current_class()
selected_class_id = current_class["id"] if current_class else None
drawn_count = 0
for ann in self.current_annotations:
# Filter by class if one is selected
if (
selected_class_id is not None
and ann.get("class_id") != selected_class_id
):
continue
if ann.get("segmentation_mask"):
polyline = ann["segmentation_mask"]
color = ann.get("class_color", "#FF0000")
self.annotation_canvas.draw_saved_polyline(
polyline,
color,
width=3,
annotation_id=ann["id"],
)
self.annotation_canvas.draw_saved_bbox(
[ann["x_min"], ann["y_min"], ann["x_max"], ann["y_max"]],
color,
width=3,
)
drawn_count += 1
logger.info(
f"Displayed {drawn_count} annotation(s) for current image with "
f"{'no class filter' if selected_class_id is None else f'class_id={selected_class_id}'}"
)
def _restore_state(self): def _restore_state(self):
"""Restore splitter positions from settings.""" """Restore splitter positions from settings."""

View File

@@ -20,12 +20,14 @@ from PySide6.QtWidgets import (
) )
from PySide6.QtCore import Qt, QThread, Signal from PySide6.QtCore import Qt, QThread, Signal
from pathlib import Path from pathlib import Path
from typing import Optional
from src.database.db_manager import DatabaseManager from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger from src.utils.logger import get_logger
from src.utils.file_utils import get_image_files from src.utils.file_utils import get_image_files
from src.model.inference import InferenceEngine from src.model.inference import InferenceEngine
from src.utils.image import Image
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -147,30 +149,66 @@ class DetectionTab(QWidget):
self.model_combo.currentIndexChanged.connect(self._on_model_changed) self.model_combo.currentIndexChanged.connect(self._on_model_changed)
def _load_models(self): def _load_models(self):
"""Load available models from database.""" """Load available models from database and local storage."""
try: try:
models = self.db_manager.get_models()
self.model_combo.clear() self.model_combo.clear()
models = self.db_manager.get_models()
has_models = False
if not models: known_paths = set()
self.model_combo.addItem("No models available", None)
self._set_buttons_enabled(False)
return
# Add base model option # Add base model option first (always available)
base_model = self.config_manager.get( base_model = self.config_manager.get(
"models.default_base_model", "yolov8s-seg.pt" "models.default_base_model", "yolov8s-seg.pt"
) )
self.model_combo.addItem( if base_model:
f"Base Model ({base_model})", {"id": 0, "path": base_model} base_data = {
) "id": 0,
"path": base_model,
"model_name": Path(base_model).stem or "Base Model",
"model_version": "pretrained",
"base_model": base_model,
"source": "base",
}
self.model_combo.addItem(f"Base Model ({base_model})", base_data)
known_paths.add(self._normalize_model_path(base_model))
has_models = True
# Add trained models # Add trained models from database
for model in models: for model in models:
display_name = f"{model['model_name']} v{model['model_version']}" display_name = f"{model['model_name']} v{model['model_version']}"
self.model_combo.addItem(display_name, model) model_data = {**model, "path": model.get("model_path")}
normalized = self._normalize_model_path(model_data.get("path"))
if normalized:
known_paths.add(normalized)
self.model_combo.addItem(display_name, model_data)
has_models = True
self._set_buttons_enabled(True) # Discover local model files not yet in the database
local_models = self._discover_local_models()
for model_path in local_models:
normalized = self._normalize_model_path(model_path)
if normalized in known_paths:
continue
display_name = f"Local Model ({Path(model_path).stem})"
model_data = {
"id": None,
"path": str(model_path),
"model_name": Path(model_path).stem,
"model_version": "local",
"base_model": Path(model_path).stem,
"source": "local",
}
self.model_combo.addItem(display_name, model_data)
known_paths.add(normalized)
has_models = True
if not has_models:
self.model_combo.addItem("No models available", None)
self._set_buttons_enabled(False)
else:
self._set_buttons_enabled(True)
except Exception as e: except Exception as e:
logger.error(f"Error loading models: {e}") logger.error(f"Error loading models: {e}")
@@ -199,7 +237,7 @@ class DetectionTab(QWidget):
self, self,
"Select Image", "Select Image",
start_dir, start_dir,
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)", "Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
) )
if not file_path: if not file_path:
@@ -249,25 +287,39 @@ class DetectionTab(QWidget):
QMessageBox.warning(self, "No Model", "Please select a model first.") QMessageBox.warning(self, "No Model", "Please select a model first.")
return return
model_path = model_data["path"] model_path = model_data.get("path")
model_id = model_data["id"] if not model_path:
QMessageBox.warning(
self, "Invalid Model", "Selected model is missing a file path."
)
return
# Ensure we have a valid model ID (create entry for base model if needed) if not Path(model_path).exists():
if model_id == 0: QMessageBox.critical(
# Create database entry for base model self,
base_model = self.config_manager.get( "Model Not Found",
"models.default_base_model", "yolov8s-seg.pt" f"The selected model file could not be found:\n{model_path}",
)
model_id = self.db_manager.add_model(
model_name="Base Model",
model_version="pretrained",
model_path=base_model,
base_model=base_model,
) )
return
model_id = model_data.get("id")
# Ensure we have a database entry for the selected model
if model_id in (None, 0):
model_id = self._ensure_model_record(model_data)
if not model_id:
QMessageBox.critical(
self,
"Model Registration Failed",
"Unable to register the selected model in the database.",
)
return
normalized_model_path = self._normalize_model_path(model_path) or model_path
# Create inference engine # Create inference engine
self.inference_engine = InferenceEngine( self.inference_engine = InferenceEngine(
model_path, self.db_manager, model_id normalized_model_path, self.db_manager, model_id
) )
# Get confidence threshold # Get confidence threshold
@@ -338,6 +390,76 @@ class DetectionTab(QWidget):
self.batch_btn.setEnabled(enabled) self.batch_btn.setEnabled(enabled)
self.model_combo.setEnabled(enabled) self.model_combo.setEnabled(enabled)
def _discover_local_models(self) -> list:
"""Scan the models directory for standalone .pt files."""
models_dir = self.config_manager.get_models_directory()
if not models_dir:
return []
models_path = Path(models_dir)
if not models_path.exists():
return []
try:
return sorted(
[p for p in models_path.rglob("*.pt") if p.is_file()],
key=lambda p: str(p).lower(),
)
except Exception as e:
logger.warning(f"Error discovering local models: {e}")
return []
def _normalize_model_path(self, path_value) -> str:
"""Return a normalized absolute path string for comparison."""
if not path_value:
return ""
try:
return str(Path(path_value).resolve())
except Exception:
return str(path_value)
def _ensure_model_record(self, model_data: dict) -> Optional[int]:
"""Ensure a database record exists for the selected model."""
model_path = model_data.get("path")
if not model_path:
return None
normalized_target = self._normalize_model_path(model_path)
try:
existing_models = self.db_manager.get_models()
for model in existing_models:
existing_path = model.get("model_path")
if not existing_path:
continue
normalized_existing = self._normalize_model_path(existing_path)
if (
normalized_existing == normalized_target
or existing_path == model_path
):
return model["id"]
model_name = (
model_data.get("model_name") or Path(model_path).stem or "Custom Model"
)
model_version = (
model_data.get("model_version") or model_data.get("source") or "local"
)
base_model = model_data.get(
"base_model",
self.config_manager.get("models.default_base_model", "yolov8s-seg.pt"),
)
return self.db_manager.add_model(
model_name=model_name,
model_version=model_version,
model_path=normalized_target,
base_model=base_model,
)
except Exception as e:
logger.error(f"Failed to ensure model record for {model_path}: {e}")
return None
def refresh(self): def refresh(self):
"""Refresh the tab.""" """Refresh the tab."""
self._load_models() self._load_models()

View File

@@ -1,15 +1,39 @@
""" """
Results tab for the microscopy object detection application. Results tab for browsing stored detections and visualizing overlays.
""" """
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox from pathlib import Path
from typing import Dict, List, Optional
from PySide6.QtWidgets import (
QWidget,
QVBoxLayout,
QHBoxLayout,
QLabel,
QGroupBox,
QPushButton,
QSplitter,
QTableWidget,
QTableWidgetItem,
QHeaderView,
QAbstractItemView,
QMessageBox,
QCheckBox,
)
from PySide6.QtCore import Qt
from src.database.db_manager import DatabaseManager from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
from src.utils.image import Image, ImageLoadError
from src.gui.widgets import AnnotationCanvasWidget
logger = get_logger(__name__)
class ResultsTab(QWidget): class ResultsTab(QWidget):
"""Results tab placeholder.""" """Results tab showing detection history and preview overlays."""
def __init__( def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
@@ -18,29 +42,398 @@ class ResultsTab(QWidget):
self.db_manager = db_manager self.db_manager = db_manager
self.config_manager = config_manager self.config_manager = config_manager
self.detection_summary: List[Dict] = []
self.current_selection: Optional[Dict] = None
self.current_image: Optional[Image] = None
self.current_detections: List[Dict] = []
self._image_path_cache: Dict[str, str] = {}
self._setup_ui() self._setup_ui()
self.refresh()
def _setup_ui(self): def _setup_ui(self):
"""Setup user interface.""" """Setup user interface."""
layout = QVBoxLayout() layout = QVBoxLayout()
group = QGroupBox("Results") # Splitter for list + preview
group_layout = QVBoxLayout() splitter = QSplitter(Qt.Horizontal)
label = QLabel(
"Results viewer will be implemented here.\n\n"
"Features:\n"
"- Detection history browser\n"
"- Advanced filtering\n"
"- Statistics dashboard\n"
"- Export functionality"
)
group_layout.addWidget(label)
group.setLayout(group_layout)
layout.addWidget(group) # Left pane: detection list
layout.addStretch() left_container = QWidget()
left_layout = QVBoxLayout()
left_layout.setContentsMargins(0, 0, 0, 0)
controls_layout = QHBoxLayout()
self.refresh_btn = QPushButton("Refresh")
self.refresh_btn.clicked.connect(self.refresh)
controls_layout.addWidget(self.refresh_btn)
controls_layout.addStretch()
left_layout.addLayout(controls_layout)
self.results_table = QTableWidget(0, 5)
self.results_table.setHorizontalHeaderLabels(
["Image", "Model", "Detections", "Classes", "Last Updated"]
)
self.results_table.horizontalHeader().setSectionResizeMode(
0, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
1, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
2, QHeaderView.ResizeToContents
)
self.results_table.horizontalHeader().setSectionResizeMode(
3, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
4, QHeaderView.ResizeToContents
)
self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows)
self.results_table.setSelectionMode(QAbstractItemView.SingleSelection)
self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
self.results_table.itemSelectionChanged.connect(self._on_result_selected)
left_layout.addWidget(self.results_table)
left_container.setLayout(left_layout)
# Right pane: preview canvas and controls
right_container = QWidget()
right_layout = QVBoxLayout()
right_layout.setContentsMargins(0, 0, 0, 0)
preview_group = QGroupBox("Detection Preview")
preview_layout = QVBoxLayout()
self.preview_canvas = AnnotationCanvasWidget()
self.preview_canvas.set_polyline_enabled(False)
self.preview_canvas.set_show_bboxes(True)
preview_layout.addWidget(self.preview_canvas)
toggles_layout = QHBoxLayout()
self.show_masks_checkbox = QCheckBox("Show Masks")
self.show_masks_checkbox.setChecked(False)
self.show_masks_checkbox.stateChanged.connect(self._apply_detection_overlays)
self.show_bboxes_checkbox = QCheckBox("Show Bounding Boxes")
self.show_bboxes_checkbox.setChecked(True)
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
self.show_confidence_checkbox = QCheckBox("Show Confidence")
self.show_confidence_checkbox.setChecked(False)
self.show_confidence_checkbox.stateChanged.connect(
self._apply_detection_overlays
)
toggles_layout.addWidget(self.show_masks_checkbox)
toggles_layout.addWidget(self.show_bboxes_checkbox)
toggles_layout.addWidget(self.show_confidence_checkbox)
toggles_layout.addStretch()
preview_layout.addLayout(toggles_layout)
self.summary_label = QLabel("Select a detection result to preview.")
self.summary_label.setWordWrap(True)
preview_layout.addWidget(self.summary_label)
preview_group.setLayout(preview_layout)
right_layout.addWidget(preview_group)
right_container.setLayout(right_layout)
splitter.addWidget(left_container)
splitter.addWidget(right_container)
splitter.setStretchFactor(0, 1)
splitter.setStretchFactor(1, 2)
layout.addWidget(splitter)
self.setLayout(layout) self.setLayout(layout)
def refresh(self): def refresh(self):
"""Refresh the tab.""" """Refresh the detection list and preview."""
pass self._load_detection_summary()
self._populate_results_table()
self.current_selection = None
self.current_image = None
self.current_detections = []
self.preview_canvas.clear()
self.summary_label.setText("Select a detection result to preview.")
def _load_detection_summary(self):
"""Load latest detection summaries grouped by image + model."""
try:
detections = self.db_manager.get_detections(limit=500)
summary_map: Dict[tuple, Dict] = {}
for det in detections:
key = (det["image_id"], det["model_id"])
metadata = det.get("metadata") or {}
entry = summary_map.setdefault(
key,
{
"image_id": det["image_id"],
"model_id": det["model_id"],
"image_path": det.get("image_path"),
"image_filename": det.get("image_filename")
or det.get("image_path"),
"model_name": det.get("model_name", ""),
"model_version": det.get("model_version", ""),
"last_detected": det.get("detected_at"),
"count": 0,
"classes": set(),
"source_path": metadata.get("source_path"),
"repository_root": metadata.get("repository_root"),
},
)
entry["count"] += 1
if det.get("detected_at") and (
not entry.get("last_detected")
or str(det.get("detected_at")) > str(entry.get("last_detected"))
):
entry["last_detected"] = det.get("detected_at")
if det.get("class_name"):
entry["classes"].add(det["class_name"])
if metadata.get("source_path") and not entry.get("source_path"):
entry["source_path"] = metadata.get("source_path")
if metadata.get("repository_root") and not entry.get("repository_root"):
entry["repository_root"] = metadata.get("repository_root")
self.detection_summary = sorted(
summary_map.values(),
key=lambda x: str(x.get("last_detected") or ""),
reverse=True,
)
except Exception as e:
logger.error(f"Failed to load detection summary: {e}")
QMessageBox.critical(
self,
"Error",
f"Failed to load detection results:\n{str(e)}",
)
self.detection_summary = []
def _populate_results_table(self):
"""Populate the table widget with detection summaries."""
self.results_table.setRowCount(len(self.detection_summary))
for row, entry in enumerate(self.detection_summary):
model_label = f"{entry['model_name']} {entry['model_version']}".strip()
class_list = (
", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
)
items = [
QTableWidgetItem(entry.get("image_filename", "")),
QTableWidgetItem(model_label),
QTableWidgetItem(str(entry.get("count", 0))),
QTableWidgetItem(class_list),
QTableWidgetItem(str(entry.get("last_detected") or "")),
]
for col, item in enumerate(items):
item.setData(Qt.UserRole, row)
self.results_table.setItem(row, col, item)
self.results_table.clearSelection()
def _on_result_selected(self):
"""Handle selection changes in the detection table."""
selected_items = self.results_table.selectedItems()
if not selected_items:
return
row = selected_items[0].data(Qt.UserRole)
if row is None or row >= len(self.detection_summary):
return
entry = self.detection_summary[row]
if (
self.current_selection
and self.current_selection.get("image_id") == entry["image_id"]
and self.current_selection.get("model_id") == entry["model_id"]
):
return
self.current_selection = entry
image_path = self._resolve_image_path(entry)
if not image_path:
QMessageBox.warning(
self,
"Image Not Found",
"Unable to locate the image file for this detection.",
)
return
try:
self.current_image = Image(image_path)
self.preview_canvas.load_image(self.current_image)
except ImageLoadError as e:
logger.error(f"Failed to load image '{image_path}': {e}")
QMessageBox.critical(
self,
"Image Error",
f"Failed to load image for preview:\n{str(e)}",
)
return
self._load_detections_for_selection(entry)
self._apply_detection_overlays()
self._update_summary_label(entry)
def _load_detections_for_selection(self, entry: Dict):
"""Load detection records for the selected image/model pair."""
self.current_detections = []
if not entry:
return
try:
filters = {"image_id": entry["image_id"], "model_id": entry["model_id"]}
self.current_detections = self.db_manager.get_detections(filters)
except Exception as e:
logger.error(f"Failed to load detections for preview: {e}")
QMessageBox.critical(
self,
"Error",
f"Failed to load detections for this image:\n{str(e)}",
)
self.current_detections = []
def _apply_detection_overlays(self):
"""Draw detections onto the preview canvas based on current toggles."""
self.preview_canvas.clear_annotations()
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
if not self.current_detections or not self.current_image:
return
for det in self.current_detections:
color = self._get_class_color(det.get("class_name"))
if self.show_masks_checkbox.isChecked() and det.get("segmentation_mask"):
mask_points = self._convert_mask(det["segmentation_mask"])
if mask_points:
self.preview_canvas.draw_saved_polyline(mask_points, color)
bbox = [
det.get("x_min"),
det.get("y_min"),
det.get("x_max"),
det.get("y_max"),
]
if all(v is not None for v in bbox):
label = None
if self.show_confidence_checkbox.isChecked():
confidence = det.get("confidence")
if confidence is not None:
label = f"{confidence:.2f}"
self.preview_canvas.draw_saved_bbox(bbox, color, label=label)
def _convert_mask(self, mask_points: List[List[float]]) -> List[List[float]]:
"""Convert stored [x, y] masks to [y, x] format for the canvas."""
converted = []
for point in mask_points:
if len(point) >= 2:
x, y = point[0], point[1]
converted.append([y, x])
return converted
def _toggle_bboxes(self):
"""Update bounding box visibility on the canvas."""
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
# Re-render to respect show/hide when toggled
self._apply_detection_overlays()
def _update_summary_label(self, entry: Dict):
"""Display textual summary for the selected detection run."""
classes = ", ".join(sorted(entry.get("classes", []))) or "-"
summary_text = (
f"Image: {entry.get('image_filename', 'unknown')}\n"
f"Model: {entry.get('model_name', '')} {entry.get('model_version', '')}\n"
f"Detections: {entry.get('count', 0)}\n"
f"Classes: {classes}\n"
f"Last Updated: {entry.get('last_detected', 'n/a')}"
)
self.summary_label.setText(summary_text)
def _resolve_image_path(self, entry: Dict) -> Optional[str]:
"""Resolve an image path using metadata, cache, and repository hints."""
relative_path = entry.get("image_path") if entry else None
cache_key = relative_path or entry.get("source_path")
if cache_key and cache_key in self._image_path_cache:
cached = Path(self._image_path_cache[cache_key])
if cached.exists():
return self._image_path_cache[cache_key]
del self._image_path_cache[cache_key]
candidates = []
source_path = entry.get("source_path") if entry else None
if source_path:
candidates.append(Path(source_path))
repo_roots = []
if entry.get("repository_root"):
repo_roots.append(entry["repository_root"])
config_repo = self.config_manager.get_image_repository_path()
if config_repo:
repo_roots.append(config_repo)
for root in repo_roots:
if relative_path:
candidates.append(Path(root) / relative_path)
if relative_path:
candidates.append(Path(relative_path))
for candidate in candidates:
try:
if candidate and candidate.exists():
resolved = str(candidate.resolve())
if cache_key:
self._image_path_cache[cache_key] = resolved
return resolved
except Exception:
continue
# Fallback: search by filename in known roots
filename = Path(relative_path).name if relative_path else None
if filename:
search_roots = [Path(root) for root in repo_roots if root]
if not search_roots:
search_roots = [Path("data")]
match = self._search_in_roots(filename, search_roots)
if match:
resolved = str(match.resolve())
if cache_key:
self._image_path_cache[cache_key] = resolved
return resolved
return None
def _search_in_roots(self, filename: str, roots: List[Path]) -> Optional[Path]:
"""Search for a file name within a list of root directories."""
for root in roots:
try:
if not root.exists():
continue
for candidate in root.rglob(filename):
return candidate
except Exception as e:
logger.debug(f"Error searching for {filename} in {root}: {e}")
return None
def _get_class_color(self, class_name: Optional[str]) -> str:
"""Return consistent color hex for a class name."""
if not class_name:
return "#FF6B6B"
color_map = self.config_manager.get_bbox_colors()
if class_name in color_map:
return color_map[class_name]
# Deterministic fallback color based on hash
palette = [
"#FF6B6B",
"#4ECDC4",
"#FFD166",
"#1D3557",
"#F4A261",
"#E76F51",
]
return palette[hash(class_name) % len(palette)]

File diff suppressed because it is too large Load Diff

View File

@@ -1,9 +1,10 @@
""" """
Annotation canvas widget for drawing annotations on images. Annotation canvas widget for drawing annotations on images.
Supports pen tool with color selection for manual annotation. Currently supports polyline drawing tool with color selection for manual annotation.
""" """
import numpy as np import numpy as np
import math
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea
from PySide6.QtGui import ( from PySide6.QtGui import (
@@ -15,29 +16,107 @@ from PySide6.QtGui import (
QKeyEvent, QKeyEvent,
QMouseEvent, QMouseEvent,
QPaintEvent, QPaintEvent,
QPolygonF,
) )
from PySide6.QtCore import Qt, QEvent, Signal, QPoint from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from scipy.ndimage import binary_dilation, label, binary_fill_holes, find_objects
from skimage.measure import find_contours
from src.utils.image import Image, ImageLoadError from src.utils.image import Image, ImageLoadError
from src.utils.logger import get_logger from src.utils.logger import get_logger
# For debugging visualization
import pylab as plt
logger = get_logger(__name__) logger = get_logger(__name__)
def perpendicular_distance(
point: Tuple[float, float],
start: Tuple[float, float],
end: Tuple[float, float],
) -> float:
"""Perpendicular distance from `point` to the line defined by `start`->`end`."""
(x, y), (x1, y1), (x2, y2) = point, start, end
dx = x2 - x1
dy = y2 - y1
if dx == 0.0 and dy == 0.0:
return math.hypot(x - x1, y - y1)
num = abs(dy * x - dx * y + x2 * y1 - y2 * x1)
den = math.hypot(dx, dy)
return num / den
def rdp(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float, float]]:
"""
Recursive Ramer-Douglas-Peucker (RDP) polyline simplification.
Args:
points: List of (x, y) points.
epsilon: Maximum allowed perpendicular distance in pixels.
Returns:
Simplified list of (x, y) points including first and last.
"""
if len(points) <= 2:
return list(points)
start = points[0]
end = points[-1]
max_dist = -1.0
index = -1
for i in range(1, len(points) - 1):
d = perpendicular_distance(points[i], start, end)
if d > max_dist:
max_dist = d
index = i
if max_dist > epsilon:
# Recursive split
left = rdp(points[: index + 1], epsilon)
right = rdp(points[index:], epsilon)
# Concatenate but avoid duplicate at split point
return left[:-1] + right
# Keep only start and end
return [start, end]
def simplify_polyline(
points: List[Tuple[float, float]], epsilon: float
) -> List[Tuple[float, float]]:
"""
Simplify a polyline with RDP while preserving closure semantics.
If the polyline is closed (first == last), the duplicate last point is removed
before simplification and then re-added after simplification.
"""
if not points:
return []
pts = [(float(x), float(y)) for x, y in points]
closed = False
if len(pts) >= 2 and pts[0] == pts[-1]:
closed = True
pts = pts[:-1] # remove duplicate last for simplification
if len(pts) <= 2:
simplified = list(pts)
else:
simplified = rdp(pts, epsilon)
if closed and simplified:
if simplified[0] != simplified[-1]:
simplified.append(simplified[0])
return simplified
class AnnotationCanvasWidget(QWidget): class AnnotationCanvasWidget(QWidget):
""" """
Widget for displaying images and drawing annotations with pen tool. Widget for displaying images and drawing annotations with zoom and drawing tools.
Features: Features:
- Display images with zoom functionality - Display images with zoom functionality
- Pen tool for drawing annotations - Polyline tool for drawing annotations
- Configurable pen color and width - Configurable pen color and width
- Mouse-based drawing interface - Mouse-based drawing interface
- Zoom in/out with mouse wheel and keyboard - Zoom in/out with mouse wheel and keyboard
@@ -49,6 +128,9 @@ class AnnotationCanvasWidget(QWidget):
zoom_changed = Signal(float) zoom_changed = Signal(float)
annotation_drawn = Signal(list) # List of (x, y) points in normalized coordinates annotation_drawn = Signal(list) # List of (x, y) points in normalized coordinates
# Emitted when the user selects an existing polyline on the canvas.
# Carries the associated annotation_id (int) or None if selection is cleared
annotation_selected = Signal(object)
def __init__(self, parent=None): def __init__(self, parent=None):
"""Initialize the annotation canvas widget.""" """Initialize the annotation canvas widget."""
@@ -63,13 +145,33 @@ class AnnotationCanvasWidget(QWidget):
self.zoom_step = 0.1 self.zoom_step = 0.1
self.zoom_wheel_step = 0.15 self.zoom_wheel_step = 0.15
# Drawing state # Drawing / interaction state
self.is_drawing = False self.is_drawing = False
self.pen_enabled = False self.polyline_enabled = False
self.pen_color = QColor(255, 0, 0, 128) # Default red with 50% alpha self.polyline_pen_color = QColor(255, 0, 0, 128) # Default red with 50% alpha
self.pen_width = 3 self.polyline_pen_width = 3
self.current_stroke = [] # Points in current stroke self.show_bboxes: bool = True # Control visibility of bounding boxes
self.all_strokes = [] # All completed strokes
# Current stroke and stored polylines (in image coordinates, pixel units)
self.current_stroke: List[Tuple[float, float]] = []
self.polylines: List[List[Tuple[float, float]]] = []
self.stroke_meta: List[Dict[str, Any]] = [] # per-polyline style (color, width)
# Optional DB annotation_id for each stored polyline (None for temporary / unsaved)
self.polyline_annotation_ids: List[Optional[int]] = []
# Indices in self.polylines of the currently selected polylines (multi-select)
self.selected_polyline_indices: List[int] = []
# Stored bounding boxes in normalized coordinates (x_min, y_min, x_max, y_max)
self.bboxes: List[List[float]] = []
self.bbox_meta: List[Dict[str, Any]] = [] # per-bbox style (color, width)
# Legacy collection of strokes in normalized coordinates (kept for API compatibility)
self.all_strokes: List[dict] = []
# RDP simplification parameters (in pixels)
self.simplify_on_finish: bool = True
self.simplify_epsilon: float = 2.0
self.sample_threshold: float = 2.0 # minimum movement to sample a new point
self._setup_ui() self._setup_ui()
@@ -128,6 +230,12 @@ class AnnotationCanvasWidget(QWidget):
"""Clear all drawn annotations.""" """Clear all drawn annotations."""
self.all_strokes = [] self.all_strokes = []
self.current_stroke = [] self.current_stroke = []
self.polylines = []
self.stroke_meta = []
self.polyline_annotation_ids = []
self.selected_polyline_indices = []
self.bboxes = []
self.bbox_meta = []
self.is_drawing = False self.is_drawing = False
if self.annotation_pixmap: if self.annotation_pixmap:
self.annotation_pixmap.fill(Qt.transparent) self.annotation_pixmap.fill(Qt.transparent)
@@ -139,10 +247,10 @@ class AnnotationCanvasWidget(QWidget):
return return
try: try:
# Get RGB image data # Get image data in a format compatible with Qt
if self.current_image.channels == 3: if self.current_image.channels in (3, 4):
image_data = self.current_image.get_rgb() image_data = self.current_image.get_rgb()
height, width, channels = image_data.shape height, width = image_data.shape[:2]
else: else:
image_data = self.current_image.get_grayscale() image_data = self.current_image.get_grayscale()
height, width = image_data.shape height, width = image_data.shape
@@ -156,7 +264,7 @@ class AnnotationCanvasWidget(QWidget):
height, height,
bytes_per_line, bytes_per_line,
self.current_image.qtimage_format, self.current_image.qtimage_format,
) ).copy() # Copy so Qt owns the buffer even after numpy array goes out of scope
self.original_pixmap = QPixmap.fromImage(qimage) self.original_pixmap = QPixmap.fromImage(qimage)
@@ -218,21 +326,21 @@ class AnnotationCanvasWidget(QWidget):
"""Update display after drawing.""" """Update display after drawing."""
self._apply_zoom() self._apply_zoom()
def set_pen_enabled(self, enabled: bool): def set_polyline_enabled(self, enabled: bool):
"""Enable or disable pen tool.""" """Enable or disable polyline tool."""
self.pen_enabled = enabled self.polyline_enabled = enabled
if enabled: if enabled:
self.canvas_label.setCursor(Qt.CrossCursor) self.canvas_label.setCursor(Qt.CrossCursor)
else: else:
self.canvas_label.setCursor(Qt.ArrowCursor) self.canvas_label.setCursor(Qt.ArrowCursor)
def set_pen_color(self, color: QColor): def set_polyline_pen_color(self, color: QColor):
"""Set pen color.""" """Set polyline pen color."""
self.pen_color = color self.polyline_pen_color = color
def set_pen_width(self, width: int): def set_polyline_pen_width(self, width: int):
"""Set pen width.""" """Set polyline pen width."""
self.pen_width = max(1, width) self.polyline_pen_width = max(1, width)
def get_zoom_percentage(self) -> int: def get_zoom_percentage(self) -> int:
"""Get current zoom level as percentage.""" """Get current zoom level as percentage."""
@@ -291,6 +399,41 @@ class AnnotationCanvasWidget(QWidget):
return (int(x), int(y)) return (int(x), int(y))
return None return None
def _find_polyline_at(
self, img_x: float, img_y: float, threshold_px: float = 5.0
) -> Optional[int]:
"""
Find index of polyline whose geometry is within threshold_px of (img_x, img_y).
Returns the index in self.polylines, or None if none is close enough.
"""
best_index: Optional[int] = None
best_dist: float = float("inf")
for idx, polyline in enumerate(self.polylines):
if len(polyline) < 2:
continue
# Quick bounding-box check to skip obviously distant polylines
xs = [p[0] for p in polyline]
ys = [p[1] for p in polyline]
if img_x < min(xs) - threshold_px or img_x > max(xs) + threshold_px:
continue
if img_y < min(ys) - threshold_px or img_y > max(ys) + threshold_px:
continue
# Precise distance to all segments
for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]):
d = perpendicular_distance(
(img_x, img_y), (float(x1), float(y1)), (float(x2), float(y2))
)
if d < best_dist:
best_dist = d
best_index = idx
if best_index is not None and best_dist <= threshold_px:
return best_index
return None
def _image_to_normalized_coords(self, x: int, y: int) -> Tuple[float, float]: def _image_to_normalized_coords(self, x: int, y: int) -> Tuple[float, float]:
"""Convert image coordinates to normalized coordinates (0-1).""" """Convert image coordinates to normalized coordinates (0-1)."""
if self.original_pixmap is None: if self.original_pixmap is None:
@@ -300,26 +443,192 @@ class AnnotationCanvasWidget(QWidget):
norm_y = y / self.original_pixmap.height() norm_y = y / self.original_pixmap.height()
return (norm_x, norm_y) return (norm_x, norm_y)
def _add_polyline(
self,
img_points: List[Tuple[float, float]],
color: QColor,
width: int,
annotation_id: Optional[int] = None,
):
"""Store a polyline in image coordinates and redraw annotations."""
if not img_points or len(img_points) < 2:
return
# Ensure all points are tuples of floats
normalized_points = [(float(x), float(y)) for x, y in img_points]
self.polylines.append(normalized_points)
self.stroke_meta.append({"color": QColor(color), "width": int(width)})
self.polyline_annotation_ids.append(annotation_id)
self._redraw_annotations()
def _redraw_annotations(self):
"""Redraw all stored polylines and (optionally) bounding boxes onto the annotation pixmap."""
if self.annotation_pixmap is None:
return
# Clear existing overlay
self.annotation_pixmap.fill(Qt.transparent)
painter = QPainter(self.annotation_pixmap)
# Draw polylines
for idx, (polyline, meta) in enumerate(zip(self.polylines, self.stroke_meta)):
pen_color: QColor = meta.get("color", self.polyline_pen_color)
width: int = meta.get("width", self.polyline_pen_width)
if idx in self.selected_polyline_indices:
# Highlight selected polylines in a distinct color / width
highlight_color = QColor(255, 255, 0, 200) # yellow, semi-opaque
pen = QPen(
highlight_color,
width + 1,
Qt.SolidLine,
Qt.RoundCap,
Qt.RoundJoin,
)
else:
pen = QPen(
pen_color,
width,
Qt.SolidLine,
Qt.RoundCap,
Qt.RoundJoin,
)
painter.setPen(pen)
# Use QPolygonF for efficient polygon rendering (single call vs N-1 calls)
# drawPolygon() automatically closes the shape, ensuring proper visual closure
polygon = QPolygonF([QPointF(x, y) for x, y in polyline])
painter.drawPolygon(polygon)
# Draw bounding boxes (dashed) if enabled
if self.show_bboxes and self.original_pixmap is not None and self.bboxes:
img_width = float(self.original_pixmap.width())
img_height = float(self.original_pixmap.height())
for bbox, meta in zip(self.bboxes, self.bbox_meta):
if len(bbox) != 4:
continue
x_min_norm, y_min_norm, x_max_norm, y_max_norm = bbox
x_min = int(x_min_norm * img_width)
y_min = int(y_min_norm * img_height)
x_max = int(x_max_norm * img_width)
y_max = int(y_max_norm * img_height)
rect_width = x_max - x_min
rect_height = y_max - y_min
pen_color: QColor = meta.get("color", QColor(255, 0, 0, 128))
width: int = meta.get("width", self.polyline_pen_width)
pen = QPen(
pen_color,
width,
Qt.DashLine,
Qt.SquareCap,
Qt.MiterJoin,
)
painter.setPen(pen)
painter.drawRect(x_min, y_min, rect_width, rect_height)
label_text = meta.get("label")
if label_text:
painter.save()
font = painter.font()
font.setPointSizeF(max(10.0, width + 4))
painter.setFont(font)
metrics = painter.fontMetrics()
text_width = metrics.horizontalAdvance(label_text)
text_height = metrics.height()
padding = 4
bg_width = text_width + padding * 2
bg_height = text_height + padding * 2
canvas_width = self.original_pixmap.width()
canvas_height = self.original_pixmap.height()
bg_x = max(0, min(x_min, canvas_width - bg_width))
bg_y = y_min - bg_height
if bg_y < 0:
bg_y = min(y_min, canvas_height - bg_height)
bg_y = max(0, bg_y)
background_rect = QRect(bg_x, bg_y, bg_width, bg_height)
background_color = QColor(pen_color)
background_color.setAlpha(220)
painter.fillRect(background_rect, background_color)
text_color = QColor(0, 0, 0)
if background_color.lightness() < 128:
text_color = QColor(255, 255, 255)
painter.setPen(text_color)
painter.drawText(
background_rect.adjusted(padding, padding, -padding, -padding),
Qt.AlignLeft | Qt.AlignVCenter,
label_text,
)
painter.restore()
painter.end()
self._update_display()
def mousePressEvent(self, event: QMouseEvent): def mousePressEvent(self, event: QMouseEvent):
"""Handle mouse press events for drawing.""" """Handle mouse press events for drawing and selecting polylines."""
if not self.pen_enabled or self.annotation_pixmap is None: if self.annotation_pixmap is None:
super().mousePressEvent(event) super().mousePressEvent(event)
return return
if event.button() == Qt.LeftButton: # Map click to image coordinates
# Get accurate position using global coordinates label_pos = self.canvas_label.mapFromGlobal(event.globalPos())
label_pos = self.canvas_label.mapFromGlobal(event.globalPos()) img_coords = self._canvas_to_image_coords(label_pos)
img_coords = self._canvas_to_image_coords(label_pos)
# Left button + drawing tool enabled -> start a new stroke
if event.button() == Qt.LeftButton and self.polyline_enabled:
if img_coords: if img_coords:
self.is_drawing = True self.is_drawing = True
self.current_stroke = [img_coords] self.current_stroke = [(float(img_coords[0]), float(img_coords[1]))]
return
# Left button + drawing tool disabled -> attempt selection of existing polyline
if event.button() == Qt.LeftButton and not self.polyline_enabled:
if img_coords:
idx = self._find_polyline_at(float(img_coords[0]), float(img_coords[1]))
if idx is not None:
if event.modifiers() & Qt.ShiftModifier:
# Multi-select mode: add to current selection (if not already selected)
if idx not in self.selected_polyline_indices:
self.selected_polyline_indices.append(idx)
else:
# Single-select mode: replace current selection
self.selected_polyline_indices = [idx]
# Build list of selected annotation IDs (ignore None entries)
selected_ids: List[int] = []
for sel_idx in self.selected_polyline_indices:
if 0 <= sel_idx < len(self.polyline_annotation_ids):
ann_id = self.polyline_annotation_ids[sel_idx]
if isinstance(ann_id, int):
selected_ids.append(ann_id)
if selected_ids:
self.annotation_selected.emit(selected_ids)
else:
# No valid DB-backed annotations in selection
self.annotation_selected.emit(None)
else:
# Clicked on empty space -> clear selection
self.selected_polyline_indices = []
self.annotation_selected.emit(None)
self._redraw_annotations()
return
# Fallback for other buttons / cases
super().mousePressEvent(event)
def mouseMoveEvent(self, event: QMouseEvent): def mouseMoveEvent(self, event: QMouseEvent):
"""Handle mouse move events for drawing.""" """Handle mouse move events for drawing."""
if ( if (
not self.is_drawing not self.is_drawing
or not self.pen_enabled or not self.polyline_enabled
or self.annotation_pixmap is None or self.annotation_pixmap is None
): ):
super().mouseMoveEvent(event) super().mouseMoveEvent(event)
@@ -330,18 +639,33 @@ class AnnotationCanvasWidget(QWidget):
img_coords = self._canvas_to_image_coords(label_pos) img_coords = self._canvas_to_image_coords(label_pos)
if img_coords and len(self.current_stroke) > 0: if img_coords and len(self.current_stroke) > 0:
# Draw line from last point to current point last_point = self.current_stroke[-1]
dx = img_coords[0] - last_point[0]
dy = img_coords[1] - last_point[1]
# Only sample a new point if we moved enough pixels
if math.hypot(dx, dy) < self.sample_threshold:
return
# Draw line from last point to current point for interactive feedback
painter = QPainter(self.annotation_pixmap) painter = QPainter(self.annotation_pixmap)
pen = QPen( pen = QPen(
self.pen_color, self.pen_width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin self.polyline_pen_color,
self.polyline_pen_width,
Qt.SolidLine,
Qt.RoundCap,
Qt.RoundJoin,
) )
painter.setPen(pen) painter.setPen(pen)
painter.drawLine(
last_point = self.current_stroke[-1] int(last_point[0]),
painter.drawLine(last_point[0], last_point[1], img_coords[0], img_coords[1]) int(last_point[1]),
int(img_coords[0]),
int(img_coords[1]),
)
painter.end() painter.end()
self.current_stroke.append(img_coords) self.current_stroke.append((float(img_coords[0]), float(img_coords[1])))
self._update_display() self._update_display()
def mouseReleaseEvent(self, event: QMouseEvent): def mouseReleaseEvent(self, event: QMouseEvent):
@@ -352,23 +676,44 @@ class AnnotationCanvasWidget(QWidget):
self.is_drawing = False self.is_drawing = False
if len(self.current_stroke) > 1: if len(self.current_stroke) > 1 and self.original_pixmap is not None:
# Convert to normalized coordinates and save stroke # Ensure the stroke is closed by connecting end -> start
normalized_stroke = [ raw_points = list(self.current_stroke)
self._image_to_normalized_coords(x, y) for x, y in self.current_stroke if raw_points[0] != raw_points[-1]:
] raw_points.append(raw_points[0])
self.all_strokes.append(
{
"points": normalized_stroke,
"color": self.pen_color.name(),
"alpha": self.pen_color.alpha(),
"width": self.pen_width,
}
)
# Emit signal with normalized coordinates # Optional RDP simplification (in image pixel space)
self.annotation_drawn.emit(normalized_stroke) if self.simplify_on_finish:
logger.debug(f"Completed stroke with {len(normalized_stroke)} points") simplified = simplify_polyline(raw_points, self.simplify_epsilon)
else:
simplified = raw_points
if len(simplified) >= 2:
# Store polyline and redraw all annotations
self._add_polyline(
simplified, self.polyline_pen_color, self.polyline_pen_width
)
# Convert to normalized coordinates for metadata + signal
normalized_stroke = [
self._image_to_normalized_coords(int(x), int(y))
for (x, y) in simplified
]
self.all_strokes.append(
{
"points": normalized_stroke,
"color": self.polyline_pen_color.name(),
"alpha": self.polyline_pen_color.alpha(),
"width": self.polyline_pen_width,
}
)
# Emit signal with normalized coordinates
self.annotation_drawn.emit(normalized_stroke)
logger.debug(
f"Completed stroke with {len(simplified)} points "
f"(normalized len={len(normalized_stroke)})"
)
self.current_stroke = [] self.current_stroke = []
@@ -376,152 +721,61 @@ class AnnotationCanvasWidget(QWidget):
"""Get all drawn strokes with metadata.""" """Get all drawn strokes with metadata."""
return self.all_strokes return self.all_strokes
# def get_annotation_bounds(self) -> Optional[Tuple[float, float, float, float]]: def get_annotation_parameters(self) -> Optional[List[Dict[str, Any]]]:
# """
# Compute bounding box that encompasses all annotation strokes.
# Returns:
# Tuple of (x_min, y_min, x_max, y_max) in normalized coordinates (0-1),
# or None if no annotations exist.
# """
# if not self.all_strokes:
# return None
# # Find min/max across all strokes
# all_x = []
# all_y = []
# for stroke in self.all_strokes:
# for x, y in stroke["points"]:
# all_x.append(x)
# all_y.append(y)
# if not all_x:
# return None
# x_min = min(all_x)
# y_min = min(all_y)
# x_max = max(all_x)
# y_max = max(all_y)
# return (x_min, y_min, x_max, y_max)
# def get_annotation_polyline(self) -> List[List[float]]:
# """
# Get polyline coordinates representing all annotation strokes.
# Returns:
# List of [x, y] coordinate pairs in normalized coordinates (0-1).
# """
# polyline = []
# fig = plt.figure()
# ax1 = fig.add_subplot(411)
# ax2 = fig.add_subplot(412)
# ax3 = fig.add_subplot(413)
# ax4 = fig.add_subplot(414)
# # Get np.arrays from annotation_pixmap accoriding to the color of the stroke
# qimage = self.annotation_pixmap.toImage()
# arr = np.ndarray(
# (qimage.height(), qimage.width(), 4),
# buffer=qimage.constBits(),
# strides=[qimage.bytesPerLine(), 4, 1],
# dtype=np.uint8,
# )
# print(arr.shape, arr.dtype, arr.min(), arr.max())
# arr = np.sum(arr, axis=2)
# ax1.imshow(arr)
# arr_bin = arr > 0
# ax2.imshow(arr_bin)
# arr_bin = binary_fill_holes(arr_bin)
# ax3.imshow(arr_bin)
# labels, _number_of_features = label(
# arr_bin,
# )
# ax4.imshow(labels)
# objects = find_objects(labels)
# bounding_boxes = np.array(
# [[obj[0].start, obj[0].stop, obj[1].start, obj[1].stop] for obj in objects]
# ) / np.array([arr.shape[0], arr.shape[1]])
# print(objects)
# print(bounding_boxes)
# print(np.array([arr.shape[0], arr.shape[1]]))
# polylines = find_contours(arr_bin, 0.5)
# for pl in polylines:
# ax1.plot(pl[:, 1], pl[:, 0], "k")
# print(arr.shape, arr.dtype, arr.min(), arr.max())
# plt.show()
# return polyline
def get_annotation_parameters(self) -> Dict[str, Any]:
""" """
Get all annotation parameters including bounding box and polyline. Get all annotation parameters including bounding box and polyline.
Returns: Returns:
Dictionary containing: List of dictionaries, each containing:
- 'bbox': Bounding box coordinates (x_min, y_min, x_max, y_max) - 'bbox': [x_min, y_min, x_max, y_max] in normalized image coordinates
- 'polyline': List of [x, y] coordinate pairs - 'polyline': List of [y_norm, x_norm] points describing the polygon
""" """
if self.original_pixmap is None or not self.polylines:
# Get np.arrays from annotation_pixmap accoriding to the color of the stroke
qimage = self.annotation_pixmap.toImage()
arr = np.ndarray(
(qimage.height(), qimage.width(), 4),
buffer=qimage.constBits(),
strides=[qimage.bytesPerLine(), 4, 1],
dtype=np.uint8,
)
arr = np.sum(arr, axis=2)
arr_bin = arr > 0
arr_bin = binary_fill_holes(arr_bin)
labels, _number_of_features = label(
arr_bin,
)
if _number_of_features == 0:
return None return None
objects = find_objects(labels) img_width = float(self.original_pixmap.width())
w, h = arr.shape img_height = float(self.original_pixmap.height())
bounding_boxes = [
[obj[0].start / w, obj[1].start / h, obj[0].stop / w, obj[1].stop / h]
for obj in objects
]
polylines = find_contours(arr_bin, 0.5) params: List[Dict[str, Any]] = []
params = []
for i, pl in enumerate(polylines):
# pl is in [row, col] format from find_contours
# We need to normalize: row/height, col/width
# w = height (rows), h = width (cols) from line 510
normalized_polyline = (pl[::-1] / np.array([w, h])).tolist()
logger.debug(f"Polyline {i}: {len(pl)} points") for idx, polyline in enumerate(self.polylines):
logger.debug(f" w={w} (height), h={h} (width)") if len(polyline) < 2:
logger.debug(f" First 3 normalized points: {normalized_polyline[:3]}") continue
xs = [p[0] for p in polyline]
ys = [p[1] for p in polyline]
x_min_norm = min(xs) / img_width
x_max_norm = max(xs) / img_width
y_min_norm = min(ys) / img_height
y_max_norm = max(ys) / img_height
# Store polyline as [y_norm, x_norm] to match DB convention and
# the expectations of draw_saved_polyline().
normalized_polyline = [
[y / img_height, x / img_width] for (x, y) in polyline
]
logger.debug(
f"Polyline {idx}: {len(polyline)} points, "
f"bbox=({x_min_norm:.3f}, {y_min_norm:.3f})-({x_max_norm:.3f}, {y_max_norm:.3f})"
)
params.append( params.append(
{ {
"bbox": bounding_boxes[i], "bbox": [x_min_norm, y_min_norm, x_max_norm, y_max_norm],
"polyline": normalized_polyline, "polyline": normalized_polyline,
} }
) )
return params return params or None
def draw_saved_polyline( def draw_saved_polyline(
self, polyline: List[List[float]], color: str, width: int = 3 self,
polyline: List[List[float]],
color: str,
width: int = 3,
annotation_id: Optional[int] = None,
): ):
""" """
Draw a polyline from database coordinates onto the annotation canvas. Draw a polyline from database coordinates onto the annotation canvas.
@@ -548,49 +802,44 @@ class AnnotationCanvasWidget(QWidget):
logger.debug(f" Image size: {img_width}x{img_height}") logger.debug(f" Image size: {img_width}x{img_height}")
logger.debug(f" First 3 normalized points from DB: {polyline[:3]}") logger.debug(f" First 3 normalized points from DB: {polyline[:3]}")
img_coords = [] img_coords: List[Tuple[float, float]] = []
for y_norm, x_norm in polyline: for y_norm, x_norm in polyline:
x = int(x_norm * img_width) x = float(x_norm * img_width)
y = int(y_norm * img_height) y = float(y_norm * img_height)
img_coords.append((x, y)) img_coords.append((x, y))
logger.debug(f" First 3 pixel coords: {img_coords[:3]}") logger.debug(f" First 3 pixel coords: {img_coords[:3]}")
# Draw polyline on annotation pixmap # Store and redraw using common pipeline
painter = QPainter(self.annotation_pixmap)
pen_color = QColor(color) pen_color = QColor(color)
pen_color.setAlpha(128) # Add semi-transparency pen_color.setAlpha(128) # Add semi-transparency
pen = QPen(pen_color, width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin) self._add_polyline(img_coords, pen_color, width, annotation_id=annotation_id)
painter.setPen(pen)
# Draw lines between consecutive points # Store in all_strokes for consistency (uses normalized coordinates)
for i in range(len(img_coords) - 1):
x1, y1 = img_coords[i]
x2, y2 = img_coords[i + 1]
painter.drawLine(x1, y1, x2, y2)
painter.end()
# Store in all_strokes for consistency
self.all_strokes.append( self.all_strokes.append(
{"points": polyline, "color": color, "alpha": 128, "width": width} {"points": polyline, "color": color, "alpha": 128, "width": width}
) )
# Update display
self._update_display()
logger.debug( logger.debug(
f"Drew saved polyline with {len(polyline)} points in color {color}" f"Drew saved polyline with {len(polyline)} points in color {color}"
) )
def draw_saved_bbox(self, bbox: List[float], color: str, width: int = 3): def draw_saved_bbox(
self,
bbox: List[float],
color: str,
width: int = 3,
label: Optional[str] = None,
):
""" """
Draw a bounding box from database coordinates onto the annotation canvas. Draw a bounding box from database coordinates onto the annotation canvas.
Args: Args:
bbox: Bounding box as [y_min_norm, x_min_norm, y_max_norm, x_max_norm] bbox: Bounding box as [x_min_norm, y_min_norm, x_max_norm, y_max_norm]
in normalized coordinates (0-1) in normalized coordinates (0-1)
color: Color hex string (e.g., '#FF0000') color: Color hex string (e.g., '#FF0000')
width: Line width in pixels width: Line width in pixels
label: Optional text label to render near the bounding box
""" """
if not self.annotation_pixmap or not self.original_pixmap: if not self.annotation_pixmap or not self.original_pixmap:
logger.warning("Cannot draw bounding box: no image loaded") logger.warning("Cannot draw bounding box: no image loaded")
@@ -602,12 +851,11 @@ class AnnotationCanvasWidget(QWidget):
) )
return return
# Convert normalized coordinates to image coordinates # Convert normalized coordinates to image coordinates (for logging/debug)
# bbox format: [y_min_norm, x_min_norm, y_max_norm, x_max_norm]
img_width = self.original_pixmap.width() img_width = self.original_pixmap.width()
img_height = self.original_pixmap.height() img_height = self.original_pixmap.height()
y_min_norm, x_min_norm, y_max_norm, x_max_norm = bbox x_min_norm, y_min_norm, x_max_norm, y_max_norm = bbox
x_min = int(x_min_norm * img_width) x_min = int(x_min_norm * img_width)
y_min = int(y_min_norm * img_height) y_min = int(y_min_norm * img_height)
x_max = int(x_max_norm * img_width) x_max = int(x_max_norm * img_width)
@@ -617,29 +865,35 @@ class AnnotationCanvasWidget(QWidget):
logger.debug(f" Image size: {img_width}x{img_height}") logger.debug(f" Image size: {img_width}x{img_height}")
logger.debug(f" Pixel coords: ({x_min}, {y_min}) to ({x_max}, {y_max})") logger.debug(f" Pixel coords: ({x_min}, {y_min}) to ({x_max}, {y_max})")
# Draw bounding box on annotation pixmap # Store bounding box (normalized) and its style; actual drawing happens
painter = QPainter(self.annotation_pixmap) # in _redraw_annotations() together with all polylines.
pen_color = QColor(color) pen_color = QColor(color)
pen_color.setAlpha(128) # Add semi-transparency pen_color.setAlpha(128) # Add semi-transparency
pen = QPen(pen_color, width, Qt.SolidLine, Qt.SquareCap, Qt.MiterJoin) self.bboxes.append(
painter.setPen(pen) [float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)]
)
# Draw rectangle self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label})
rect_width = x_max - x_min
rect_height = y_max - y_min
painter.drawRect(x_min, y_min, rect_width, rect_height)
painter.end()
# Store in all_strokes for consistency # Store in all_strokes for consistency
self.all_strokes.append( self.all_strokes.append(
{"bbox": bbox, "color": color, "alpha": 128, "width": width} {"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label}
) )
# Update display # Redraw overlay (polylines + all bounding boxes)
self._update_display() self._redraw_annotations()
logger.debug(f"Drew saved bounding box in color {color}") logger.debug(f"Drew saved bounding box in color {color}")
def set_show_bboxes(self, show: bool):
"""
Enable or disable drawing of bounding boxes.
Args:
show: If True, draw bounding boxes; if False, hide them.
"""
self.show_bboxes = bool(show)
logger.debug(f"Set show_bboxes to {self.show_bboxes}")
self._redraw_annotations()
def keyPressEvent(self, event: QKeyEvent): def keyPressEvent(self, event: QKeyEvent):
"""Handle keyboard events for zooming.""" """Handle keyboard events for zooming."""
if event.key() in (Qt.Key_Plus, Qt.Key_Equal): if event.key() in (Qt.Key_Plus, Qt.Key_Equal):

View File

@@ -1,6 +1,6 @@
""" """
Annotation tools widget for controlling annotation parameters. Annotation tools widget for controlling annotation parameters.
Includes pen tool, color picker, class selection, and annotation management. Includes polyline tool, color picker, class selection, and annotation management.
""" """
from PySide6.QtWidgets import ( from PySide6.QtWidgets import (
@@ -12,6 +12,8 @@ from PySide6.QtWidgets import (
QPushButton, QPushButton,
QComboBox, QComboBox,
QSpinBox, QSpinBox,
QDoubleSpinBox,
QCheckBox,
QColorDialog, QColorDialog,
QInputDialog, QInputDialog,
QMessageBox, QMessageBox,
@@ -31,28 +33,33 @@ class AnnotationToolsWidget(QWidget):
Widget for annotation tool controls. Widget for annotation tool controls.
Features: Features:
- Enable/disable pen tool - Enable/disable polyline tool
- Color selection for pen - Color selection for polyline pen
- Object class selection - Object class selection
- Add new object classes - Add new object classes
- Pen width control - Pen width control
- Clear annotations - Clear annotations
Signals: Signals:
pen_enabled_changed: Emitted when pen tool is enabled/disabled (bool) polyline_enabled_changed: Emitted when polyline tool is enabled/disabled (bool)
pen_color_changed: Emitted when pen color changes (QColor) polyline_pen_color_changed: Emitted when polyline pen color changes (QColor)
pen_width_changed: Emitted when pen width changes (int) polyline_pen_width_changed: Emitted when polyline pen width changes (int)
class_selected: Emitted when object class is selected (dict) class_selected: Emitted when object class is selected (dict)
clear_annotations_requested: Emitted when clear button is pressed clear_annotations_requested: Emitted when clear button is pressed
""" """
pen_enabled_changed = Signal(bool) polyline_enabled_changed = Signal(bool)
pen_color_changed = Signal(QColor) polyline_pen_color_changed = Signal(QColor)
pen_width_changed = Signal(int) polyline_pen_width_changed = Signal(int)
simplify_on_finish_changed = Signal(bool)
simplify_epsilon_changed = Signal(float)
# Toggle visibility of bounding boxes on the canvas
show_bboxes_changed = Signal(bool)
class_selected = Signal(dict) class_selected = Signal(dict)
class_color_changed = Signal()
clear_annotations_requested = Signal() clear_annotations_requested = Signal()
process_annotations_requested = Signal() # Request deletion of the currently selected annotation on the canvas
show_annotations_requested = Signal() delete_selected_annotation_requested = Signal()
def __init__(self, db_manager: DatabaseManager, parent=None): def __init__(self, db_manager: DatabaseManager, parent=None):
""" """
@@ -64,7 +71,7 @@ class AnnotationToolsWidget(QWidget):
""" """
super().__init__(parent) super().__init__(parent)
self.db_manager = db_manager self.db_manager = db_manager
self.pen_enabled = False self.polyline_enabled = False
self.current_color = QColor(255, 0, 0, 128) # Red with 50% alpha self.current_color = QColor(255, 0, 0, 128) # Red with 50% alpha
self.current_class = None self.current_class = None
@@ -75,43 +82,51 @@ class AnnotationToolsWidget(QWidget):
"""Setup user interface.""" """Setup user interface."""
layout = QVBoxLayout() layout = QVBoxLayout()
# Pen Tool Group # Polyline Tool Group
pen_group = QGroupBox("Pen Tool") polyline_group = QGroupBox("Polyline Tool")
pen_layout = QVBoxLayout() polyline_layout = QVBoxLayout()
# Enable/Disable pen # Enable/Disable polyline tool
button_layout = QHBoxLayout() button_layout = QHBoxLayout()
self.pen_toggle_btn = QPushButton("Enable Pen") self.polyline_toggle_btn = QPushButton("Start Drawing Polyline")
self.pen_toggle_btn.setCheckable(True) self.polyline_toggle_btn.setCheckable(True)
self.pen_toggle_btn.clicked.connect(self._on_pen_toggle) self.polyline_toggle_btn.clicked.connect(self._on_polyline_toggle)
button_layout.addWidget(self.pen_toggle_btn) button_layout.addWidget(self.polyline_toggle_btn)
pen_layout.addLayout(button_layout) polyline_layout.addLayout(button_layout)
# Pen width control # Polyline pen width control
width_layout = QHBoxLayout() width_layout = QHBoxLayout()
width_layout.addWidget(QLabel("Pen Width:")) width_layout.addWidget(QLabel("Pen Width:"))
self.pen_width_spin = QSpinBox() self.polyline_pen_width_spin = QSpinBox()
self.pen_width_spin.setMinimum(1) self.polyline_pen_width_spin.setMinimum(1)
self.pen_width_spin.setMaximum(20) self.polyline_pen_width_spin.setMaximum(20)
self.pen_width_spin.setValue(3) self.polyline_pen_width_spin.setValue(3)
self.pen_width_spin.valueChanged.connect(self._on_pen_width_changed) self.polyline_pen_width_spin.valueChanged.connect(
width_layout.addWidget(self.pen_width_spin) self._on_polyline_pen_width_changed
)
width_layout.addWidget(self.polyline_pen_width_spin)
width_layout.addStretch() width_layout.addStretch()
pen_layout.addLayout(width_layout) polyline_layout.addLayout(width_layout)
# Color selection # Simplification controls (RDP)
color_layout = QHBoxLayout() simplify_layout = QHBoxLayout()
color_layout.addWidget(QLabel("Color:")) self.simplify_checkbox = QCheckBox("Simplify on finish")
self.color_btn = QPushButton() self.simplify_checkbox.setChecked(True)
self.color_btn.setFixedSize(40, 30) self.simplify_checkbox.stateChanged.connect(self._on_simplify_toggle)
self.color_btn.clicked.connect(self._on_color_picker) simplify_layout.addWidget(self.simplify_checkbox)
self._update_color_button()
color_layout.addWidget(self.color_btn)
color_layout.addStretch()
pen_layout.addLayout(color_layout)
pen_group.setLayout(pen_layout) simplify_layout.addWidget(QLabel("epsilon (px):"))
layout.addWidget(pen_group) self.eps_spin = QDoubleSpinBox()
self.eps_spin.setRange(0.0, 1000.0)
self.eps_spin.setSingleStep(0.5)
self.eps_spin.setValue(2.0)
self.eps_spin.valueChanged.connect(self._on_eps_change)
simplify_layout.addWidget(self.eps_spin)
simplify_layout.addStretch()
polyline_layout.addLayout(simplify_layout)
polyline_group.setLayout(polyline_layout)
layout.addWidget(polyline_group)
# Object Class Group # Object Class Group
class_group = QGroupBox("Object Class") class_group = QGroupBox("Object Class")
@@ -122,7 +137,7 @@ class AnnotationToolsWidget(QWidget):
self.class_combo.currentIndexChanged.connect(self._on_class_selected) self.class_combo.currentIndexChanged.connect(self._on_class_selected)
class_layout.addWidget(self.class_combo) class_layout.addWidget(self.class_combo)
# Add class button # Add / manage classes
class_button_layout = QHBoxLayout() class_button_layout = QHBoxLayout()
self.add_class_btn = QPushButton("Add New Class") self.add_class_btn = QPushButton("Add New Class")
self.add_class_btn.clicked.connect(self._on_add_class) self.add_class_btn.clicked.connect(self._on_add_class)
@@ -133,6 +148,17 @@ class AnnotationToolsWidget(QWidget):
class_button_layout.addWidget(self.refresh_classes_btn) class_button_layout.addWidget(self.refresh_classes_btn)
class_layout.addLayout(class_button_layout) class_layout.addLayout(class_button_layout)
# Class color (associated with selected object class)
color_layout = QHBoxLayout()
color_layout.addWidget(QLabel("Class Color:"))
self.color_btn = QPushButton()
self.color_btn.setFixedSize(40, 30)
self.color_btn.clicked.connect(self._on_color_picker)
self._update_color_button()
color_layout.addWidget(self.color_btn)
color_layout.addStretch()
class_layout.addLayout(color_layout)
# Selected class info # Selected class info
self.class_info_label = QLabel("No class selected") self.class_info_label = QLabel("No class selected")
self.class_info_label.setWordWrap(True) self.class_info_label.setWordWrap(True)
@@ -148,24 +174,22 @@ class AnnotationToolsWidget(QWidget):
actions_group = QGroupBox("Actions") actions_group = QGroupBox("Actions")
actions_layout = QVBoxLayout() actions_layout = QVBoxLayout()
self.process_btn = QPushButton("Process Annotations") # Show / hide bounding boxes
self.process_btn.clicked.connect(self._on_process_annotations) self.show_bboxes_checkbox = QCheckBox("Show bounding boxes")
self.process_btn.setStyleSheet( self.show_bboxes_checkbox.setChecked(True)
"QPushButton { background-color: #2196F3; color: white; font-weight: bold; }" self.show_bboxes_checkbox.stateChanged.connect(self._on_show_bboxes_toggle)
) actions_layout.addWidget(self.show_bboxes_checkbox)
actions_layout.addWidget(self.process_btn)
self.show_btn = QPushButton("Show Saved Annotations")
self.show_btn.clicked.connect(self._on_show_annotations)
self.show_btn.setStyleSheet(
"QPushButton { background-color: #4CAF50; color: white; }"
)
actions_layout.addWidget(self.show_btn)
self.clear_btn = QPushButton("Clear All Annotations") self.clear_btn = QPushButton("Clear All Annotations")
self.clear_btn.clicked.connect(self._on_clear_annotations) self.clear_btn.clicked.connect(self._on_clear_annotations)
actions_layout.addWidget(self.clear_btn) actions_layout.addWidget(self.clear_btn)
# Delete currently selected annotation (enabled when a selection exists)
self.delete_selected_btn = QPushButton("Delete Selected Annotation")
self.delete_selected_btn.clicked.connect(self._on_delete_selected_annotation)
self.delete_selected_btn.setEnabled(False)
actions_layout.addWidget(self.delete_selected_btn)
actions_group.setLayout(actions_layout) actions_group.setLayout(actions_layout)
layout.addWidget(actions_group) layout.addWidget(actions_group)
@@ -193,7 +217,7 @@ class AnnotationToolsWidget(QWidget):
# Clear and repopulate combo box # Clear and repopulate combo box
self.class_combo.clear() self.class_combo.clear()
self.class_combo.addItem("-- Select Class --", None) self.class_combo.addItem("-- Select Class / Show All --", None)
for cls in classes: for cls in classes:
self.class_combo.addItem(cls["class_name"], cls) self.class_combo.addItem(cls["class_name"], cls)
@@ -206,46 +230,115 @@ class AnnotationToolsWidget(QWidget):
self, "Error", f"Failed to load object classes:\n{str(e)}" self, "Error", f"Failed to load object classes:\n{str(e)}"
) )
def _on_pen_toggle(self, checked: bool): def _on_polyline_toggle(self, checked: bool):
"""Handle pen tool enable/disable.""" """Handle polyline tool enable/disable."""
self.pen_enabled = checked self.polyline_enabled = checked
if checked: if checked:
self.pen_toggle_btn.setText("Disable Pen") self.polyline_toggle_btn.setText("Stop Drawing Polyline")
self.pen_toggle_btn.setStyleSheet( self.polyline_toggle_btn.setStyleSheet(
"QPushButton { background-color: #4CAF50; }" "QPushButton { background-color: #4CAF50; }"
) )
else: else:
self.pen_toggle_btn.setText("Enable Pen") self.polyline_toggle_btn.setText("Start Drawing Polyline")
self.pen_toggle_btn.setStyleSheet("") self.polyline_toggle_btn.setStyleSheet("")
self.pen_enabled_changed.emit(self.pen_enabled) self.polyline_enabled_changed.emit(self.polyline_enabled)
logger.debug(f"Pen tool {'enabled' if checked else 'disabled'}") logger.debug(f"Polyline tool {'enabled' if checked else 'disabled'}")
def _on_pen_width_changed(self, width: int): def _on_polyline_pen_width_changed(self, width: int):
"""Handle pen width changes.""" """Handle polyline pen width changes."""
self.pen_width_changed.emit(width) self.polyline_pen_width_changed.emit(width)
logger.debug(f"Pen width changed to {width}") logger.debug(f"Polyline pen width changed to {width}")
def _on_simplify_toggle(self, state: int):
"""Handle simplify-on-finish checkbox toggle."""
enabled = bool(state)
self.simplify_on_finish_changed.emit(enabled)
logger.debug(f"Simplify on finish set to {enabled}")
def _on_eps_change(self, val: float):
"""Handle epsilon (RDP tolerance) value changes."""
epsilon = float(val)
self.simplify_epsilon_changed.emit(epsilon)
logger.debug(f"Simplification epsilon changed to {epsilon}")
def _on_show_bboxes_toggle(self, state: int):
"""Handle 'Show bounding boxes' checkbox toggle."""
show = bool(state)
self.show_bboxes_changed.emit(show)
logger.debug(f"Show bounding boxes set to {show}")
def _on_color_picker(self): def _on_color_picker(self):
"""Open color picker dialog with alpha support.""" """Open color picker dialog and update the selected object's class color."""
if not self.current_class:
QMessageBox.warning(
self,
"No Class Selected",
"Please select an object class before changing its color.",
)
return
# Use current class color (without alpha) as the base
base_color = QColor(self.current_class.get("color", self.current_color.name()))
color = QColorDialog.getColor( color = QColorDialog.getColor(
self.current_color, base_color,
self, self,
"Select Pen Color", "Select Class Color",
QColorDialog.ShowAlphaChannel, # Enable alpha channel selection QColorDialog.ShowAlphaChannel, # Allow alpha in UI, but store RGB in DB
) )
if color.isValid(): if not color.isValid():
self.current_color = color return
self._update_color_button()
self.pen_color_changed.emit(color) # Normalize to opaque RGB for storage
logger.debug( new_color = QColor(color)
f"Pen color changed to {color.name()} with alpha {color.alpha()}" new_color.setAlpha(255)
hex_color = new_color.name()
try:
# Update in database
self.db_manager.update_object_class(
class_id=self.current_class["id"], color=hex_color
) )
except Exception as e:
logger.error(f"Failed to update class color in database: {e}")
QMessageBox.critical(
self,
"Error",
f"Failed to update class color in database:\n{str(e)}",
)
return
# Update local class data and combo box item data
self.current_class["color"] = hex_color
current_index = self.class_combo.currentIndex()
if current_index >= 0:
self.class_combo.setItemData(current_index, dict(self.current_class))
# Update info label text
info_text = f"Class: {self.current_class['class_name']}\nColor: {hex_color}"
if self.current_class.get("description"):
info_text += f"\nDescription: {self.current_class['description']}"
self.class_info_label.setText(info_text)
# Use semi-transparent version for polyline pen / button preview
class_color = QColor(hex_color)
class_color.setAlpha(128)
self.current_color = class_color
self._update_color_button()
self.polyline_pen_color_changed.emit(class_color)
logger.debug(
f"Updated class '{self.current_class['class_name']}' color to "
f"{hex_color} (polyline pen alpha={class_color.alpha()})"
)
# Notify listeners (e.g., AnnotationTab) so they can reload/redraw
self.class_color_changed.emit()
def _on_class_selected(self, index: int): def _on_class_selected(self, index: int):
"""Handle object class selection.""" """Handle object class selection (including '-- Select Class --')."""
class_data = self.class_combo.currentData() class_data = self.class_combo.currentData()
if class_data: if class_data:
@@ -260,20 +353,23 @@ class AnnotationToolsWidget(QWidget):
self.class_info_label.setText(info_text) self.class_info_label.setText(info_text)
# Update pen color to match class color with semi-transparency # Update polyline pen color to match class color with semi-transparency
class_color = QColor(class_data["color"]) class_color = QColor(class_data["color"])
if class_color.isValid(): if class_color.isValid():
# Add 50% alpha for semi-transparency # Add 50% alpha for semi-transparency
class_color.setAlpha(128) class_color.setAlpha(128)
self.current_color = class_color self.current_color = class_color
self._update_color_button() self._update_color_button()
self.pen_color_changed.emit(class_color) self.polyline_pen_color_changed.emit(class_color)
self.class_selected.emit(class_data) self.class_selected.emit(class_data)
logger.debug(f"Selected class: {class_data['class_name']}") logger.debug(f"Selected class: {class_data['class_name']}")
else: else:
# "-- Select Class --" chosen: clear current class and show all annotations
self.current_class = None self.current_class = None
self.class_info_label.setText("No class selected") self.class_info_label.setText("No class selected")
self.class_selected.emit(None)
logger.debug("Class selection cleared: showing annotations for all classes")
def _on_add_class(self): def _on_add_class(self):
"""Handle adding a new object class.""" """Handle adding a new object class."""
@@ -351,36 +447,32 @@ class AnnotationToolsWidget(QWidget):
self.clear_annotations_requested.emit() self.clear_annotations_requested.emit()
logger.debug("Clear annotations requested") logger.debug("Clear annotations requested")
def _on_process_annotations(self): def _on_delete_selected_annotation(self):
"""Handle process annotations button.""" """Handle delete selected annotation button."""
if not self.current_class: self.delete_selected_annotation_requested.emit()
QMessageBox.warning( logger.debug("Delete selected annotation requested")
self,
"No Class Selected",
"Please select an object class before processing annotations.",
)
return
self.process_annotations_requested.emit() def set_has_selected_annotation(self, has_selection: bool):
logger.debug("Process annotations requested") """
Enable/disable actions that require a selected annotation.
def _on_show_annotations(self): Args:
"""Handle show annotations button.""" has_selection: True if an annotation is currently selected on the canvas.
self.show_annotations_requested.emit() """
logger.debug("Show annotations requested") self.delete_selected_btn.setEnabled(bool(has_selection))
def get_current_class(self) -> Optional[Dict]: def get_current_class(self) -> Optional[Dict]:
"""Get currently selected object class.""" """Get currently selected object class."""
return self.current_class return self.current_class
def get_pen_color(self) -> QColor: def get_polyline_pen_color(self) -> QColor:
"""Get current pen color.""" """Get current polyline pen color."""
return self.current_color return self.current_color
def get_pen_width(self) -> int: def get_polyline_pen_width(self) -> int:
"""Get current pen width.""" """Get current polyline pen width."""
return self.pen_width_spin.value() return self.polyline_pen_width_spin.value()
def is_pen_enabled(self) -> bool: def is_polyline_enabled(self) -> bool:
"""Check if pen tool is enabled.""" """Check if polyline tool is enabled."""
return self.pen_enabled return self.polyline_enabled

View File

@@ -137,7 +137,7 @@ class ImageDisplayWidget(QWidget):
height, height,
bytes_per_line, bytes_per_line,
self.current_image.qtimage_format, self.current_image.qtimage_format,
) ).copy() # Copy to ensure Qt owns its memory after this scope
# Convert to pixmap # Convert to pixmap
pixmap = QPixmap.fromImage(qimage) pixmap = QPixmap.fromImage(qimage)

View File

@@ -5,12 +5,12 @@ Handles detection inference and result storage.
from typing import List, Dict, Optional, Callable from typing import List, Dict, Optional, Callable
from pathlib import Path from pathlib import Path
from PIL import Image
import cv2 import cv2
import numpy as np import numpy as np
from src.model.yolo_wrapper import YOLOWrapper from src.model.yolo_wrapper import YOLOWrapper
from src.database.db_manager import DatabaseManager from src.database.db_manager import DatabaseManager
from src.utils.image import Image
from src.utils.logger import get_logger from src.utils.logger import get_logger
from src.utils.file_utils import get_relative_path from src.utils.file_utils import get_relative_path
@@ -42,6 +42,7 @@ class InferenceEngine:
relative_path: str, relative_path: str,
conf: float = 0.25, conf: float = 0.25,
save_to_db: bool = True, save_to_db: bool = True,
repository_root: Optional[str] = None,
) -> Dict: ) -> Dict:
""" """
Detect objects in a single image. Detect objects in a single image.
@@ -51,49 +52,79 @@ class InferenceEngine:
relative_path: Relative path from repository root relative_path: Relative path from repository root
conf: Confidence threshold conf: Confidence threshold
save_to_db: Whether to save results to database save_to_db: Whether to save results to database
repository_root: Base directory used to compute relative_path (if known)
Returns: Returns:
Dictionary with detection results Dictionary with detection results
""" """
try: try:
# Normalize storage path (fall back to absolute path when repo root is unknown)
stored_relative_path = relative_path
if not repository_root:
stored_relative_path = str(Path(image_path).resolve())
# Get image dimensions # Get image dimensions
img = Image.open(image_path) img = Image(image_path)
width, height = img.size width = img.width
img.close() height = img.height
# Perform detection # Perform detection
detections = self.yolo.predict(image_path, conf=conf) detections = self.yolo.predict(image_path, conf=conf)
# Add/get image in database # Add/get image in database
image_id = self.db_manager.get_or_create_image( image_id = self.db_manager.get_or_create_image(
relative_path=relative_path, relative_path=stored_relative_path,
filename=Path(image_path).name, filename=Path(image_path).name,
width=width, width=width,
height=height, height=height,
) )
# Save detections to database inserted_count = 0
if save_to_db and detections: deleted_count = 0
detection_records = []
for det in detections:
# Use normalized bbox from detection
bbox_normalized = det[
"bbox_normalized"
] # [x_min, y_min, x_max, y_max]
record = { # Save detections to database, replacing any previous results for this image/model
"image_id": image_id, if save_to_db:
"model_id": self.model_id, deleted_count = self.db_manager.delete_detections_for_image(
"class_name": det["class_name"], image_id, self.model_id
"bbox": tuple(bbox_normalized), )
"confidence": det["confidence"], if detections:
"segmentation_mask": det.get("segmentation_mask"), detection_records = []
"metadata": {"class_id": det["class_id"]}, for det in detections:
} # Use normalized bbox from detection
detection_records.append(record) bbox_normalized = det[
"bbox_normalized"
] # [x_min, y_min, x_max, y_max]
self.db_manager.add_detections_batch(detection_records) metadata = {
logger.info(f"Saved {len(detection_records)} detections to database") "class_id": det["class_id"],
"source_path": str(Path(image_path).resolve()),
}
if repository_root:
metadata["repository_root"] = str(
Path(repository_root).resolve()
)
record = {
"image_id": image_id,
"model_id": self.model_id,
"class_name": det["class_name"],
"bbox": tuple(bbox_normalized),
"confidence": det["confidence"],
"segmentation_mask": det.get("segmentation_mask"),
"metadata": metadata,
}
detection_records.append(record)
inserted_count = self.db_manager.add_detections_batch(
detection_records
)
logger.info(
f"Saved {inserted_count} detections to database (replaced {deleted_count})"
)
else:
logger.info(
f"Detection run removed {deleted_count} stale entries but produced no new detections"
)
return { return {
"success": True, "success": True,
@@ -142,7 +173,12 @@ class InferenceEngine:
rel_path = get_relative_path(image_path, repository_root) rel_path = get_relative_path(image_path, repository_root)
# Perform detection # Perform detection
result = self.detect_single(image_path, rel_path, conf) result = self.detect_single(
image_path,
rel_path,
conf=conf,
repository_root=repository_root,
)
results.append(result) results.append(result)
# Update progress # Update progress

View File

@@ -7,6 +7,9 @@ from ultralytics import YOLO
from pathlib import Path from pathlib import Path
from typing import Optional, List, Dict, Callable, Any from typing import Optional, List, Dict, Callable, Any
import torch import torch
import tempfile
import os
from src.utils.image import Image
from src.utils.logger import get_logger from src.utils.logger import get_logger
@@ -55,6 +58,7 @@ class YOLOWrapper:
save_dir: str = "data/models", save_dir: str = "data/models",
name: str = "custom_model", name: str = "custom_model",
resume: bool = False, resume: bool = False,
callbacks: Optional[Dict[str, Callable]] = None,
**kwargs, **kwargs,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
@@ -69,13 +73,15 @@ class YOLOWrapper:
save_dir: Directory to save trained model save_dir: Directory to save trained model
name: Name for the training run name: Name for the training run
resume: Resume training from last checkpoint resume: Resume training from last checkpoint
callbacks: Optional Ultralytics callback dictionary
**kwargs: Additional training arguments **kwargs: Additional training arguments
Returns: Returns:
Dictionary with training results Dictionary with training results
""" """
if self.model is None: if self.model is None:
self.load_model() if not self.load_model():
raise RuntimeError(f"Failed to load model from {self.model_path}")
try: try:
logger.info(f"Starting training: {name}") logger.info(f"Starting training: {name}")
@@ -117,7 +123,8 @@ class YOLOWrapper:
Dictionary with validation metrics Dictionary with validation metrics
""" """
if self.model is None: if self.model is None:
self.load_model() if not self.load_model():
raise RuntimeError(f"Failed to load model from {self.model_path}")
try: try:
logger.info(f"Starting validation on {split} split") logger.info(f"Starting validation on {split} split")
@@ -158,12 +165,15 @@ class YOLOWrapper:
List of detection dictionaries List of detection dictionaries
""" """
if self.model is None: if self.model is None:
self.load_model() if not self.load_model():
raise RuntimeError(f"Failed to load model from {self.model_path}")
prepared_source, cleanup_path = self._prepare_source(source)
try: try:
logger.info(f"Running inference on {source}") logger.info(f"Running inference on {source}")
results = self.model.predict( results = self.model.predict(
source=source, source=prepared_source,
conf=conf, conf=conf,
iou=iou, iou=iou,
save=save, save=save,
@@ -180,6 +190,14 @@ class YOLOWrapper:
except Exception as e: except Exception as e:
logger.error(f"Error during inference: {e}") logger.error(f"Error during inference: {e}")
raise raise
finally:
if 0: # cleanup_path:
try:
os.remove(cleanup_path)
except OSError as cleanup_error:
logger.warning(
f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
)
def export( def export(
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
@@ -196,7 +214,8 @@ class YOLOWrapper:
Path to exported model Path to exported model
""" """
if self.model is None: if self.model is None:
self.load_model() if not self.load_model():
raise RuntimeError(f"Failed to load model from {self.model_path}")
try: try:
logger.info(f"Exporting model to {format} format") logger.info(f"Exporting model to {format} format")
@@ -208,6 +227,38 @@ class YOLOWrapper:
logger.error(f"Error exporting model: {e}") logger.error(f"Error exporting model: {e}")
raise raise
def _prepare_source(self, source):
"""Convert single-channel images to RGB temporarily for inference."""
cleanup_path = None
if isinstance(source, (str, Path)):
source_path = Path(source)
if source_path.is_file():
try:
img_obj = Image(source_path)
pil_img = img_obj.pil_image
if len(pil_img.getbands()) == 1:
rgb_img = img_obj.convert_grayscale_to_rgb_preserve_range()
else:
rgb_img = pil_img.convert("RGB")
suffix = source_path.suffix or ".png"
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
tmp_path = tmp.name
tmp.close()
rgb_img.save(tmp_path)
cleanup_path = tmp_path
logger.info(
f"Converted image {source_path} to RGB for inference at {tmp_path}"
)
return tmp_path, cleanup_path
except Exception as convert_error:
logger.warning(
f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"
)
return source, cleanup_path
def _format_training_results(self, results) -> Dict[str, Any]: def _format_training_results(self, results) -> Dict[str, Any]:
"""Format training results into dictionary.""" """Format training results into dictionary."""
try: try:

View File

@@ -7,6 +7,7 @@ import yaml
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from src.utils.logger import get_logger from src.utils.logger import get_logger
from src.utils.image import Image
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -46,18 +47,15 @@ class ConfigManager:
"database": {"path": "data/detections.db"}, "database": {"path": "data/detections.db"},
"image_repository": { "image_repository": {
"base_path": "", "base_path": "",
"allowed_extensions": [ "allowed_extensions": Image.SUPPORTED_EXTENSIONS,
".jpg",
".jpeg",
".png",
".tif",
".tiff",
".bmp",
],
}, },
"models": { "models": {
"default_base_model": "yolov8s-seg.pt", "default_base_model": "yolov8s-seg.pt",
"models_directory": "data/models", "models_directory": "data/models",
"base_model_choices": [
"yolov8s-seg.pt",
"yolov11s-seg.pt",
],
}, },
"training": { "training": {
"default_epochs": 100, "default_epochs": 100,
@@ -65,6 +63,20 @@ class ConfigManager:
"default_imgsz": 640, "default_imgsz": 640,
"default_patience": 50, "default_patience": 50,
"default_lr0": 0.01, "default_lr0": 0.01,
"two_stage": {
"enabled": False,
"stage1": {
"epochs": 20,
"lr0": 0.0005,
"patience": 10,
"freeze": 10,
},
"stage2": {
"epochs": 150,
"lr0": 0.0003,
"patience": 30,
},
},
}, },
"detection": { "detection": {
"default_confidence": 0.25, "default_confidence": 0.25,
@@ -214,5 +226,5 @@ class ConfigManager:
def get_allowed_extensions(self) -> list: def get_allowed_extensions(self) -> list:
"""Get list of allowed image file extensions.""" """Get list of allowed image file extensions."""
return self.get( return self.get(
"image_repository.allowed_extensions", [".jpg", ".jpeg", ".png"] "image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS
) )

View File

@@ -28,7 +28,9 @@ def get_image_files(
List of absolute paths to image files List of absolute paths to image files
""" """
if allowed_extensions is None: if allowed_extensions is None:
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"] from src.utils.image import Image
allowed_extensions = Image.SUPPORTED_EXTENSIONS
# Normalize extensions to lowercase # Normalize extensions to lowercase
allowed_extensions = [ext.lower() for ext in allowed_extensions] allowed_extensions = [ext.lower() for ext in allowed_extensions]
@@ -204,7 +206,9 @@ def is_image_file(
True if file is an image True if file is an image
""" """
if allowed_extensions is None: if allowed_extensions is None:
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"] from src.utils.image import Image
allowed_extensions = Image.SUPPORTED_EXTENSIONS
extension = Path(file_path).suffix.lower() extension = Path(file_path).suffix.lower()
return extension in [ext.lower() for ext in allowed_extensions] return extension in [ext.lower() for ext in allowed_extensions]

View File

@@ -277,6 +277,38 @@ class Image:
""" """
return self._channels >= 3 return self._channels >= 3
def convert_grayscale_to_rgb_preserve_range(
self,
) -> PILImage.Image:
"""Convert a single-channel PIL image to RGB while preserving dynamic range.
Returns:
PIL Image in RGB mode with intensities normalized to 0-255.
"""
if self._channels == 3:
return self.pil_image
grayscale = self.data
if grayscale.ndim == 3:
grayscale = grayscale[:, :, 0]
original_dtype = grayscale.dtype
grayscale = grayscale.astype(np.float32)
if grayscale.size == 0:
return PILImage.new("RGB", self.shape, color=(0, 0, 0))
if np.issubdtype(original_dtype, np.integer):
denom = float(max(np.iinfo(original_dtype).max, 1))
else:
max_val = float(grayscale.max())
denom = max(max_val, 1.0)
grayscale = np.clip(grayscale / denom, 0.0, 1.0)
grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8)
rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
return PILImage.fromarray(rgb_arr, mode="RGB")
def __repr__(self) -> str: def __repr__(self) -> str:
"""String representation of the Image object.""" """String representation of the Image object."""
return ( return (

View File

@@ -0,0 +1,160 @@
import numpy as np
from roifile import ImagejRoi
from tifffile import TiffFile, TiffWriter
from pathlib import Path
class UT:
"""
Docstring for UT
Operetta files along with rois drawn in ImageJ
"""
def __init__(self, roifile_fn: Path, no_labels: bool):
self.roifile_fn = roifile_fn
print("is file", self.roifile_fn.is_file())
self.rois = None
if no_labels:
self.rois = ImagejRoi.fromfile(self.roifile_fn)
self.stem = self.roifile_fn.stem.split("Roi-")[1]
else:
self.roifile_fn = roifile_fn / roifile_fn.parts[-1]
self.stem = self.roifile_fn.stem
print(self.roifile_fn)
print(self.stem)
self.image, self.image_props = self._load_images()
def _load_images(self):
"""Loading sequence of tif files
array sequence is CZYX
"""
print("Loading images:", self.roifile_fn.parent, self.stem)
fns = list(self.roifile_fn.parent.glob(f"{self.stem.lower()}*.tif*"))
stems = [fn.stem.split(self.stem)[-1] for fn in fns]
n_ch = len(set([stem.split("-ch")[-1].split("t")[0] for stem in stems]))
n_p = len(set([stem.split("-")[0] for stem in stems]))
n_t = len(set([stem.split("t")[1] for stem in stems]))
with TiffFile(fns[0]) as tif:
img = tif.asarray()
w, h = img.shape
dtype = img.dtype
self.image_props = {
"channels": n_ch,
"planes": n_p,
"tiles": n_t,
"width": w,
"height": h,
"dtype": dtype,
}
print("Image props", self.image_props)
image_stack = np.zeros((n_ch, n_p, w, h), dtype=dtype)
for fn in fns:
with TiffFile(fn) as tif:
img = tif.asarray()
stem = fn.stem.split(self.stem)[-1]
ch = int(stem.split("-ch")[-1].split("t")[0])
p = int(stem.split("-")[0].split("p")[1])
t = int(stem.split("t")[1])
print(fn.stem, "ch", ch, "p", p, "t", t)
image_stack[ch - 1, p - 1] = img
print(image_stack.shape)
return image_stack, self.image_props
@property
def width(self):
return self.image_props["width"]
@property
def height(self):
return self.image_props["height"]
@property
def nchannels(self):
return self.image_props["channels"]
@property
def nplanes(self):
return self.image_props["planes"]
def export_rois(
self,
path: Path,
subfolder: str = "labels",
class_index: int = 0,
):
"""Export rois to a file"""
with open(path / subfolder / f"{self.stem}.txt", "w") as f:
for i, roi in enumerate(self.rois):
rc = roi.subpixel_coordinates
if rc is None:
print(
f"No coordinates: {self.roifile_fn}, element {i}, out of {len(self.rois)}"
)
continue
xmn, ymn = rc.min(axis=0)
xmx, ymx = rc.max(axis=0)
xc = (xmn + xmx) / 2
yc = (ymn + ymx) / 2
bw = xmx - xmn
bh = ymx - ymn
coords = f"{xc/self.width} {yc/self.height} {bw/self.width} {bh/self.height} "
for x, y in rc:
coords += f"{x/self.width} {y/self.height} "
f.write(f"{class_index} {coords}\n")
return
def export_image(
self,
path: Path,
subfolder: str = "images",
plane_mode: str = "max projection",
channel: int = 0,
):
"""Export image to a file"""
if plane_mode == "max projection":
self.image = np.max(self.image[channel], axis=0)
print(self.image.shape)
print(path / subfolder / f"{self.stem}.tif")
with TiffWriter(path / subfolder / f"{self.stem}.tif") as tif:
tif.write(self.image)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input", nargs="*", type=Path)
parser.add_argument("-o", "--output", type=Path)
parser.add_argument(
"--no-labels",
action="store_false",
help="Source does not have labels, export only images",
)
args = parser.parse_args()
for path in args.input:
print("Path:", path)
if not args.no_labels:
print("No labels")
ut = UT(path, args.no_labels)
ut.export_image(args.output, plane_mode="max projection", channel=0)
else:
for rfn in Path(path).glob("*.zip"):
print("Roi FN:", rfn)
ut = UT(rfn, args.no_labels)
ut.export_rois(args.output, class_index=0)
ut.export_image(args.output, plane_mode="max projection", channel=0)
print()

184
tests/show_yolo_seg.py Normal file
View File

@@ -0,0 +1,184 @@
#!/usr/bin/env python3
"""
show_yolo_seg.py
Usage:
python show_yolo_seg.py /path/to/image.jpg /path/to/labels.txt
Supports:
- Segmentation polygons: "class x1 y1 x2 y2 ... xn yn"
- YOLO bbox lines as fallback: "class x_center y_center width height"
Coordinates can be normalized [0..1] or absolute pixels (auto-detected).
"""
import sys
import cv2
import numpy as np
import matplotlib.pyplot as plt
import argparse
from pathlib import Path
import random
def parse_label_line(line):
parts = line.strip().split()
if not parts:
return None
cls = int(float(parts[0]))
coords = [float(x) for x in parts[1:]]
return cls, coords
def coords_are_normalized(coords):
# If every coordinate is between 0 and 1 (inclusive-ish), assume normalized
if not coords:
return False
return max(coords) <= 1.001
def yolo_bbox_to_xyxy(coords, img_w, img_h):
# coords: [xc, yc, w, h] normalized or absolute
xc, yc, w, h = coords[:4]
if max(coords) <= 1.001:
xc *= img_w
yc *= img_h
w *= img_w
h *= img_h
x1 = int(round(xc - w / 2))
y1 = int(round(yc - h / 2))
x2 = int(round(xc + w / 2))
y2 = int(round(yc + h / 2))
return x1, y1, x2, y2
def poly_to_pts(coords, img_w, img_h):
# coords: [x1 y1 x2 y2 ...] either normalized or absolute
if coords_are_normalized(coords[4:]):
coords = [
coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))
]
pts = np.array(coords, dtype=np.int32).reshape(-1, 2)
return pts
def random_color_for_class(cls):
random.seed(cls) # deterministic per class
return tuple(int(x) for x in np.array([random.randint(0, 255) for _ in range(3)]))
def draw_annotations(img, labels, alpha=0.4, draw_bbox_for_poly=True):
# img: BGR numpy array
overlay = img.copy()
h, w = img.shape[:2]
for cls, coords in labels:
if not coords:
continue
# polygon case (>=6 coordinates)
if len(coords) >= 6:
color = random_color_for_class(cls)
x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords[:4], w, h)
cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
pts = poly_to_pts(coords[4:], w, h)
# fill on overlay
cv2.fillPoly(overlay, [pts], color)
# outline on base image
cv2.polylines(img, [pts], isClosed=True, color=color, thickness=2)
# put class text at first point
x, y = int(pts[0, 0]), int(pts[0, 1]) - 6
cv2.putText(
img,
str(cls),
(x, max(6, y)),
cv2.FONT_HERSHEY_SIMPLEX,
0.6,
(255, 255, 255),
2,
cv2.LINE_AA,
)
# YOLO bbox case (4 coords)
elif len(coords) == 4:
x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords, w, h)
color = random_color_for_class(cls)
cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
cv2.putText(
img,
str(cls),
(x1, max(6, y1 - 4)),
cv2.FONT_HERSHEY_SIMPLEX,
0.6,
(255, 255, 255),
2,
cv2.LINE_AA,
)
else:
# Unknown / invalid format, skip
continue
# blend overlay for filled polygons
cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img)
return img
def load_labels_file(label_path):
labels = []
with open(label_path, "r") as f:
for raw in f:
line = raw.strip()
if not line:
continue
parsed = parse_label_line(line)
if parsed:
labels.append(parsed)
return labels
def main():
parser = argparse.ArgumentParser(
description="Show YOLO segmentation / polygon annotations"
)
parser.add_argument("image", type=str, help="Path to image file")
parser.add_argument("labels", type=str, help="Path to YOLO label file (polygons)")
parser.add_argument(
"--alpha", type=float, default=0.4, help="Polygon fill alpha (0..1)"
)
parser.add_argument(
"--no-bbox", action="store_true", help="Don't draw bounding boxes for polygons"
)
args = parser.parse_args()
img_path = Path(args.image)
lbl_path = Path(args.labels)
if not img_path.exists():
print("Image not found:", img_path)
sys.exit(1)
if not lbl_path.exists():
print("Label file not found:", lbl_path)
sys.exit(1)
img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
if img is None:
print("Could not load image:", img_path)
sys.exit(1)
labels = load_labels_file(str(lbl_path))
if not labels:
print("No labels parsed from", lbl_path)
# continue and just show image
out = draw_annotations(
img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox)
)
# Convert BGR -> RGB for matplotlib display
out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
plt.imshow(out_rgb)
plt.axis("off")
plt.title(f"{img_path.name} ({lbl_path.name})")
plt.show()
if __name__ == "__main__":
main()

View File

@@ -27,7 +27,7 @@ class TestImage:
def test_supported_extensions(self): def test_supported_extensions(self):
"""Test that supported extensions are correctly defined.""" """Test that supported extensions are correctly defined."""
expected_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"] expected_extensions = Image.SUPPORTED_EXTENSIONS
assert Image.SUPPORTED_EXTENSIONS == expected_extensions assert Image.SUPPORTED_EXTENSIONS == expected_extensions
def test_image_properties(self, tmp_path): def test_image_properties(self, tmp_path):