From 310e0b228575c8593a6e2802588c2dec12b539be Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Fri, 5 Dec 2025 15:51:16 +0200 Subject: [PATCH 01/13] Making it an installable package and switching to segmentation mode --- ARCHITECTURE.md | 25 +++-- BUILD.md | 178 +++++++++++++++++++++++++++++++ LICENSE | 21 ++++ MANIFEST.in | 37 +++++++ QUICKSTART.md | 17 ++- README.md | 62 ++++++++--- config/app_config.yaml | 2 +- main.py | 5 +- pyproject.toml | 103 ++++++++++++++++++ setup.py | 56 ++++++++++ src/__init__.py | 19 ++++ src/cli.py | 61 +++++++++++ src/database/db_manager.py | 47 ++++++-- src/database/models.py | 8 +- src/database/schema.sql | 2 + src/gui/dialogs/config_dialog.py | 4 +- src/gui/tabs/detection_tab.py | 4 +- src/model/inference.py | 37 +++++-- src/model/yolo_wrapper.py | 33 +++++- src/utils/config_manager.py | 2 +- 20 files changed, 667 insertions(+), 56 deletions(-) create mode 100644 BUILD.md create mode 100644 LICENSE create mode 100644 MANIFEST.in create mode 100644 pyproject.toml create mode 100644 setup.py create mode 100644 src/cli.py diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 4391c8e..ddc8784 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -2,11 +2,11 @@ ## Project Overview -A desktop application for detecting organelles and membrane branching structures in microscopy images using YOLOv8s, with comprehensive training, validation, and visualization capabilities. +A desktop application for detecting and segmenting organelles and membrane branching structures in microscopy images using YOLOv8s-seg, with comprehensive training, validation, and visualization capabilities including pixel-accurate segmentation masks. 
## Technology Stack -- **ML Framework**: Ultralytics YOLOv8 (YOLOv8s.pt model) +- **ML Framework**: Ultralytics YOLOv8 (YOLOv8s-seg.pt segmentation model) - **GUI Framework**: PySide6 (Qt6 for Python) - **Visualization**: pyqtgraph - **Database**: SQLite3 @@ -110,6 +110,7 @@ erDiagram float x_max float y_max float confidence + text segmentation_mask datetime detected_at json metadata } @@ -122,6 +123,7 @@ erDiagram float y_min float x_max float y_max + text segmentation_mask string annotator datetime created_at boolean verified @@ -139,7 +141,7 @@ Stores information about trained models and their versions. | model_name | TEXT | NOT NULL | User-friendly model name | | model_version | TEXT | NOT NULL | Version string (e.g., "v1.0") | | model_path | TEXT | NOT NULL | Path to model weights file | -| base_model | TEXT | NOT NULL | Base model used (e.g., "yolov8s.pt") | +| base_model | TEXT | NOT NULL | Base model used (e.g., "yolov8s-seg.pt") | | created_at | TIMESTAMP | DEFAULT CURRENT_TIMESTAMP | Model creation timestamp | | training_params | JSON | | Training hyperparameters | | metrics | JSON | | Validation metrics (mAP, precision, recall) | @@ -159,7 +161,7 @@ Stores metadata about microscopy images. | checksum | TEXT | | MD5 hash for integrity verification | #### **detections** table -Stores object detection results. +Stores object detection results with optional segmentation masks. | Column | Type | Constraints | Description | |--------|------|-------------|-------------| @@ -172,11 +174,12 @@ Stores object detection results. | x_max | REAL | NOT NULL | Bounding box right coordinate (normalized 0-1) | | y_max | REAL | NOT NULL | Bounding box bottom coordinate (normalized 0-1) | | confidence | REAL | NOT NULL | Detection confidence score (0-1) | +| segmentation_mask | TEXT | | JSON array of polygon coordinates [[x1,y1], [x2,y2], ...] 
(normalized 0-1) | | detected_at | TIMESTAMP | DEFAULT CURRENT_TIMESTAMP | When detection was performed | | metadata | JSON | | Additional metadata (processing time, etc.) | #### **annotations** table -Stores manual annotations for training data (future feature). +Stores manual annotations for training data with optional segmentation masks (future feature). | Column | Type | Constraints | Description | |--------|------|-------------|-------------| @@ -187,6 +190,7 @@ Stores manual annotations for training data (future feature). | y_min | REAL | NOT NULL | Bounding box top coordinate (normalized) | | x_max | REAL | NOT NULL | Bounding box right coordinate (normalized) | | y_max | REAL | NOT NULL | Bounding box bottom coordinate (normalized) | +| segmentation_mask | TEXT | | JSON array of polygon coordinates [[x1,y1], [x2,y2], ...] (normalized 0-1) | | annotator | TEXT | | Name of person who created annotation | | created_at | TIMESTAMP | DEFAULT CURRENT_TIMESTAMP | Annotation timestamp | | verified | BOOLEAN | DEFAULT 0 | Whether annotation is verified | @@ -245,8 +249,9 @@ graph TB ### Key Components #### 1. 
**YOLO Wrapper** ([`src/model/yolo_wrapper.py`](src/model/yolo_wrapper.py)) -Encapsulates YOLOv8 operations: -- Load pre-trained YOLOv8s model +Encapsulates YOLOv8-seg operations: +- Load pre-trained YOLOv8s-seg segmentation model +- Extract pixel-accurate segmentation masks - Fine-tune on custom microscopy dataset - Export trained models - Provide training progress callbacks @@ -255,10 +260,10 @@ Encapsulates YOLOv8 operations: **Key Methods:** ```python class YOLOWrapper: - def __init__(self, model_path: str = "yolov8s.pt") + def __init__(self, model_path: str = "yolov8s-seg.pt") def train(self, data_yaml: str, epochs: int, callbacks: dict) def validate(self, data_yaml: str) -> dict - def predict(self, image_path: str, conf: float) -> list + def predict(self, image_path: str, conf: float) -> list # Returns detections with segmentation masks def export_model(self, format: str, output_path: str) ``` @@ -435,7 +440,7 @@ image_repository: allowed_extensions: [".jpg", ".jpeg", ".png", ".tif", ".tiff"] models: - default_base_model: "yolov8s.pt" + default_base_model: "yolov8s-seg.pt" models_directory: "data/models" training: diff --git a/BUILD.md b/BUILD.md new file mode 100644 index 0000000..ab19583 --- /dev/null +++ b/BUILD.md @@ -0,0 +1,178 @@ +# Building and Publishing Guide + +This guide explains how to build and publish the microscopy-object-detection package. + +## Prerequisites + +```bash +pip install build twine +``` + +## Building the Package + +### 1. Clean Previous Builds + +```bash +rm -rf build/ dist/ *.egg-info +``` + +### 2. Build Distribution Archives + +```bash +python -m build +``` + +This will create both wheel (`.whl`) and source distribution (`.tar.gz`) in the `dist/` directory. + +### 3. Verify the Build + +```bash +ls dist/ +# Should show: +# microscopy_object_detection-1.0.0-py3-none-any.whl +# microscopy_object_detection-1.0.0.tar.gz +``` + +## Testing the Package Locally + +### Install in Development Mode + +```bash +pip install -e . 
+``` + +### Install from Built Package + +```bash +pip install dist/microscopy_object_detection-1.0.0-py3-none-any.whl +``` + +### Test the Installation + +```bash +# Test CLI +microscopy-detect --version + +# Test GUI launcher +microscopy-detect-gui +``` + +## Publishing to PyPI + +### 1. Configure PyPI Credentials + +Create or update `~/.pypirc`: + +```ini +[pypi] +username = __token__ +password = pypi-YOUR-API-TOKEN-HERE +``` + +### 2. Upload to Test PyPI (Recommended First) + +```bash +python -m twine upload --repository testpypi dist/* +``` + +Then test installation: + +```bash +pip install --index-url https://test.pypi.org/simple/ microscopy-object-detection +``` + +### 3. Upload to PyPI + +```bash +python -m twine upload dist/* +``` + +## Version Management + +Update version in multiple files: +- `setup.py`: Update `version` parameter +- `pyproject.toml`: Update `version` field +- `src/__init__.py`: Update `__version__` variable + +## Git Tags + +After publishing, tag the release: + +```bash +git tag -a v1.0.0 -m "Release version 1.0.0" +git push origin v1.0.0 +``` + +## Package Structure + +The built package includes: +- All Python source files in `src/` +- Configuration files in `config/` +- Database schema file (`src/database/schema.sql`) +- Documentation files (README.md, LICENSE, etc.) 
+- Entry points for CLI and GUI + +## Troubleshooting + +### Import Errors +If you get import errors, ensure: +- All `__init__.py` files are present +- Package structure follows the setup configuration +- Dependencies are listed in `requirements.txt` + +### Missing Files +If files are missing in the built package: +- Check `MANIFEST.in` includes the required patterns +- Check `pyproject.toml` package-data configuration +- Rebuild with `python -m build --no-isolation` for debugging + +### Version Conflicts +If version conflicts occur: +- Ensure version is consistent across all files +- Clear build artifacts and rebuild +- Check for cached installations: `pip list | grep microscopy` + +## CI/CD Integration + +### GitHub Actions Example + +```yaml +name: Build and Publish + +on: + release: + types: [created] + +jobs: + deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: '3.8' + - name: Install dependencies + run: | + pip install build twine + - name: Build package + run: python -m build + - name: Publish to PyPI + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} + run: twine upload dist/* +``` + +## Best Practices + +1. **Version Bumping**: Use semantic versioning (MAJOR.MINOR.PATCH) +2. **Testing**: Always test on Test PyPI before publishing to PyPI +3. **Documentation**: Update README.md and CHANGELOG.md for each release +4. **Git Tags**: Tag releases in git for easy reference +5. 
**Dependencies**: Keep requirements.txt updated and specify version ranges + +## Resources + +- [Python Packaging Guide](https://packaging.python.org/) +- [setuptools Documentation](https://setuptools.pypa.io/) +- [PyPI Publishing Guide](https://packaging.python.org/tutorials/packaging-projects/) \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f7df838 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Your Name + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..6246c45 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,37 @@ +# Include documentation files +include README.md +include LICENSE +include ARCHITECTURE.md +include IMPLEMENTATION_GUIDE.md +include QUICKSTART.md +include PLAN_SUMMARY.md + +# Include requirements +include requirements.txt + +# Include configuration files +recursive-include config *.yaml +recursive-include config *.yml + +# Include database schema +recursive-include src/database *.sql + +# Include tests +recursive-include tests *.py + +# Exclude compiled Python files +global-exclude *.pyc +global-exclude *.pyo +global-exclude __pycache__ +global-exclude *.so +global-exclude .DS_Store + +# Exclude git and IDE files +global-exclude .git* +global-exclude .vscode +global-exclude .idea + +# Exclude build artifacts +prune build +prune dist +prune *.egg-info \ No newline at end of file diff --git a/QUICKSTART.md b/QUICKSTART.md index e211216..0c97a79 100644 --- a/QUICKSTART.md +++ b/QUICKSTART.md @@ -38,7 +38,7 @@ This will install: - OpenCV and Pillow (image processing) - And other dependencies -**Note:** The first run will automatically download the YOLOv8s.pt model (~22MB). +**Note:** The first run will automatically download the YOLOv8s-seg.pt segmentation model (~23MB). ### 4. Verify Installation @@ -84,11 +84,11 @@ In the Settings dialog: ### Single Image Detection 1. Go to the **Detection** tab -2. Select a model from the dropdown (default: Base Model yolov8s.pt) +2. Select a model from the dropdown (default: Base Model yolov8s-seg.pt) 3. Adjust confidence threshold with the slider 4. Click "Detect Single Image" 5. Select an image file -6. View results in the results panel +6. 
View results with segmentation masks overlaid on the image ### Batch Detection @@ -108,9 +108,18 @@ Detection results include: - **Class names**: Types of objects detected (e.g., organelle, membrane_branch) - **Confidence scores**: Detection confidence (0-1) - **Bounding boxes**: Object locations (stored in database) +- **Segmentation masks**: Pixel-accurate polygon coordinates for each detected object All results are stored in the SQLite database at [`data/detections.db`](data/detections.db). +### Segmentation Visualization + +The application automatically displays segmentation masks when available: +- Semi-transparent colored overlay (30% opacity) showing the exact shape of detected objects +- Polygon contours outlining each segmentation +- Color-coded by object class +- A toggle to show or hide the overlay is planned for a future version + ## Database The application uses SQLite to store: @@ -176,7 +185,7 @@ sudo apt-get install libxcb-xinerama0 ### Detection Not Working **No models available** -- The base YOLOv8s model will be downloaded automatically on first use +- The base YOLOv8s-seg segmentation model will be downloaded automatically on first use - Make sure you have internet connection for the first run **Images not found** diff --git a/README.md b/README.md index 61f376f..f12cb79 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Microscopy Object Detection Application -A desktop application for detecting organelles and membrane branching structures in microscopy images using YOLOv8, featuring comprehensive training, validation, and visualization capabilities. +A desktop application for detecting and segmenting organelles and membrane branching structures in microscopy images using YOLOv8-seg, featuring comprehensive training, validation, and visualization capabilities with pixel-accurate segmentation masks. 
![Python](https://img.shields.io/badge/python-3.8+-blue.svg) ![PySide6](https://img.shields.io/badge/PySide6-6.5+-green.svg) @@ -8,8 +8,8 @@ A desktop application for detecting organelles and membrane branching structures ## Features -- **🎯 Object Detection**: Real-time and batch detection of microscopy objects -- **🎓 Model Training**: Fine-tune YOLOv8s on custom microscopy datasets +- **🎯 Object Detection & Segmentation**: Real-time and batch detection with pixel-accurate segmentation masks +- **🎓 Model Training**: Fine-tune YOLOv8s-seg on custom microscopy datasets - **📊 Validation & Metrics**: Comprehensive model validation with visualization - **💾 Database Storage**: SQLite database for detection results and metadata - **📈 Visualization**: Interactive plots and charts using pyqtgraph @@ -34,14 +34,24 @@ A desktop application for detecting organelles and membrane branching structures ## Installation -### 1. Clone the Repository +### Option 1: Install from PyPI (Recommended) + +```bash +pip install microscopy-object-detection +``` + +This will install the package and all its dependencies. + +### Option 2: Install from Source + +#### 1. Clone the Repository ```bash git clone cd object_detection ``` -### 2. Create Virtual Environment +#### 2. Create Virtual Environment ```bash # Linux/Mac @@ -53,25 +63,44 @@ python -m venv venv venv\Scripts\activate ``` -### 3. Install Dependencies +#### 3. Install in Development Mode ```bash -pip install -r requirements.txt +# Install in editable mode with dev dependencies +pip install -e ".[dev]" + +# Or install just the package +pip install . ``` ### 4. 
Download Base Model -The application will automatically download the YOLOv8s.pt model on first use, or you can download it manually: +The application will automatically download the YOLOv8s-seg.pt segmentation model on first use, or you can download it manually: ```bash # The model will be downloaded automatically by ultralytics # Or download manually from: https://github.com/ultralytics/assets/releases ``` +**Note:** YOLOv8s-seg is a segmentation model that provides pixel-accurate masks for detected objects, enabling more precise analysis than standard bounding box detection. + ## Quick Start ### 1. Launch the Application +After installation, you can launch the application in two ways: + +**Using the GUI launcher:** +```bash +microscopy-detect-gui +``` + +**Or using Python directly:** +```bash +python -m microscopy_object_detection +``` + +**If installed from source:** ```bash python main.py ``` @@ -85,11 +114,12 @@ python main.py ### 3. Perform Detection 1. Navigate to the **Detection** tab -2. Select a model (default: yolov8s.pt) +2. Select a model (default: yolov8s-seg.pt) 3. Choose an image or folder 4. Set confidence threshold 5. Click **Detect** -6. View results and save to database +6. View results with segmentation masks overlaid +7. Save results to database ### 4. Train Custom Model @@ -212,8 +242,8 @@ The application uses SQLite with the following main tables: - **models**: Stores trained model information and metrics - **images**: Stores image metadata and paths -- **detections**: Stores detection results with bounding boxes -- **annotations**: Stores manual annotations (future feature) +- **detections**: Stores detection results with bounding boxes and segmentation masks (polygon coordinates) +- **annotations**: Stores manual annotations with optional segmentation masks (future feature) See [`ARCHITECTURE.md`](ARCHITECTURE.md) for detailed schema information. 
@@ -230,7 +260,7 @@ image_repository: allowed_extensions: [".jpg", ".jpeg", ".png", ".tif", ".tiff"] models: - default_base_model: "yolov8s.pt" + default_base_model: "yolov8s-seg.pt" models_directory: "data/models" training: @@ -258,7 +288,7 @@ visualization: from src.model.yolo_wrapper import YOLOWrapper # Initialize wrapper -yolo = YOLOWrapper("yolov8s.pt") +yolo = YOLOWrapper("yolov8s-seg.pt") # Train model results = yolo.train( @@ -393,10 +423,10 @@ make html **Issue**: Model not found error -**Solution**: Ensure YOLOv8s.pt is downloaded. Run: +**Solution**: Ensure YOLOv8s-seg.pt is downloaded. Run: ```python from ultralytics import YOLO -model = YOLO('yolov8s.pt') # Will auto-download +model = YOLO('yolov8s-seg.pt') # Will auto-download ``` diff --git a/config/app_config.yaml b/config/app_config.yaml index e47c74a..be8d2d1 100644 --- a/config/app_config.yaml +++ b/config/app_config.yaml @@ -10,7 +10,7 @@ image_repository: - .tiff - .bmp models: - default_base_model: yolov8s.pt + default_base_model: yolov8s-seg.pt models_directory: data/models training: default_epochs: 100 diff --git a/main.py b/main.py index 77b6a30..652272f 100644 --- a/main.py +++ b/main.py @@ -6,12 +6,13 @@ Main entry point for the application. 
import sys from pathlib import Path -# Add src directory to path +# Add src directory to path for development mode sys.path.insert(0, str(Path(__file__).parent)) from PySide6.QtWidgets import QApplication from PySide6.QtCore import Qt +from src import __version__ from src.gui.main_window import MainWindow from src.utils.logger import setup_logging from src.utils.config_manager import ConfigManager @@ -37,7 +38,7 @@ def main(): app = QApplication(sys.argv) app.setApplicationName("Microscopy Object Detection") app.setOrganizationName("MicroscopyLab") - app.setApplicationVersion("1.0.0") + app.setApplicationVersion(__version__) # Set application style app.setStyle("Fusion") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..d5993d6 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,103 @@ +[build-system] +requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2"] +build-backend = "setuptools.build_meta" + +[project] +name = "microscopy-object-detection" +version = "1.0.0" +description = "Desktop application for detecting and segmenting organelles in microscopy images using YOLOv8-seg" +readme = "README.md" +requires-python = ">=3.8" +license = { text = "MIT" } +authors = [{ name = "Your Name", email = "your.email@example.com" }] +keywords = [ + "microscopy", + "yolov8", + "object-detection", + "segmentation", + "computer-vision", + "deep-learning", +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering :: Image Recognition", + "Topic :: Scientific/Engineering :: Bio-Informatics", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Operating System :: OS Independent", +] + +dependencies = [ + "ultralytics>=8.0.0", + "PySide6>=6.5.0", + 
"pyqtgraph>=0.13.0", + "numpy>=1.24.0", + "opencv-python>=4.8.0", + "Pillow>=10.0.0", + "PyYAML>=6.0", + "pandas>=2.0.0", + "openpyxl>=3.1.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0.0", + "pytest-cov>=4.0.0", + "black>=23.0.0", + "pylint>=2.17.0", + "mypy>=1.0.0", +] + +[project.urls] +Homepage = "https://github.com/yourusername/object_detection" +Documentation = "https://github.com/yourusername/object_detection/blob/main/README.md" +Repository = "https://github.com/yourusername/object_detection" +"Bug Tracker" = "https://github.com/yourusername/object_detection/issues" + +[project.scripts] +microscopy-detect = "src.cli:main" + +[project.gui-scripts] +microscopy-detect-gui = "main:main" + +[tool.setuptools] +package-dir = { "" = "." } +packages = [ + "src", + "src.database", + "src.model", + "src.gui", + "src.gui.tabs", + "src.gui.dialogs", + "src.gui.widgets", + "src.utils", +] + +[tool.setuptools.package-data] +src = ["database/*.sql"] +"" = ["config/*.yaml"] + +[tool.black] +line-length = 88 +target-version = ['py38', 'py39', 'py310', 'py311'] +include = '\.pyi?$' + +[tool.pylint.messages_control] +max-line-length = 88 + +[tool.mypy] +python_version = "3.8" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = false + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_functions = ["test_*"] +addopts = "-v --cov=src --cov-report=term-missing" diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..59a43a1 --- /dev/null +++ b/setup.py @@ -0,0 +1,56 @@ +"""Setup script for Microscopy Object Detection Application.""" + +from setuptools import setup, find_packages +from pathlib import Path + +# Read the contents of README file +this_directory = Path(__file__).parent +long_description = (this_directory / "README.md").read_text(encoding="utf-8") + +# Read requirements +requirements = (this_directory / "requirements.txt").read_text().splitlines() +requirements = [ + 
req.strip() for req in requirements if req.strip() and not req.startswith("#") +] + +setup( + name="microscopy-object-detection", + version="1.0.0", + author="Your Name", + author_email="your.email@example.com", + description="Desktop application for detecting and segmenting organelles in microscopy images using YOLOv8-seg", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/yourusername/object_detection", + packages=find_packages(exclude=["tests", "tests.*", "docs"]), + include_package_data=True, + install_requires=requirements, + python_requires=">=3.8", + classifiers=[ + "Development Status :: 4 - Beta", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering :: Image Recognition", + "Topic :: Scientific/Engineering :: Bio-Informatics", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Operating System :: OS Independent", + ], + entry_points={ + "console_scripts": [ + "microscopy-detect=src.cli:main", + ], + "gui_scripts": [ + "microscopy-detect-gui=main:main", + ], + }, + keywords="microscopy yolov8 object-detection segmentation computer-vision deep-learning", + project_urls={ + "Bug Reports": "https://github.com/yourusername/object_detection/issues", + "Source": "https://github.com/yourusername/object_detection", + "Documentation": "https://github.com/yourusername/object_detection/blob/main/README.md", + }, +) diff --git a/src/__init__.py b/src/__init__.py index e69de29..424aa41 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -0,0 +1,19 @@ +""" +Microscopy Object Detection Application + +A desktop application for detecting and segmenting organelles and membrane +branching structures in microscopy images using YOLOv8-seg. 
+""" + +__version__ = "1.0.0" +__author__ = "Your Name" +__email__ = "your.email@example.com" +__license__ = "MIT" + +# Package metadata +__all__ = [ + "__version__", + "__author__", + "__email__", + "__license__", +] diff --git a/src/cli.py b/src/cli.py new file mode 100644 index 0000000..482ee52 --- /dev/null +++ b/src/cli.py @@ -0,0 +1,61 @@ +""" +Command-line interface for microscopy object detection application. +""" + +import sys +import argparse +from pathlib import Path + +from src import __version__ + + +def main(): + """Main CLI entry point.""" + parser = argparse.ArgumentParser( + description="Microscopy Object Detection Application - CLI Interface", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Launch GUI + microscopy-detect-gui + + # Show version + microscopy-detect --version + + # Get help + microscopy-detect --help + """, + ) + + parser.add_argument( + "--version", + action="version", + version=f"microscopy-object-detection {__version__}", + ) + + parser.add_argument( + "--gui", + action="store_true", + help="Launch the GUI application (same as microscopy-detect-gui)", + ) + + args = parser.parse_args() + + if args.gui: + # Launch GUI + try: + from main import main as gui_main + + gui_main() + except Exception as e: + print(f"Error launching GUI: {e}", file=sys.stderr) + sys.exit(1) + else: + # Show help if no arguments provided + parser.print_help() + print("\nTo launch the GUI, use: microscopy-detect-gui") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/database/db_manager.py b/src/database/db_manager.py index 4329499..5dcf5c2 100644 --- a/src/database/db_manager.py +++ b/src/database/db_manager.py @@ -6,7 +6,7 @@ Handles all database operations including CRUD operations, queries, and exports. 
import sqlite3 import json from datetime import datetime -from typing import List, Dict, Optional, Tuple, Any +from typing import List, Dict, Optional, Tuple, Any, Union from pathlib import Path import csv import hashlib @@ -56,7 +56,7 @@ class DatabaseManager: model_name: str, model_version: str, model_path: str, - base_model: str = "yolov8s.pt", + base_model: str = "yolov8s-seg.pt", training_params: Optional[Dict] = None, metrics: Optional[Dict] = None, ) -> int: @@ -243,6 +243,7 @@ class DatabaseManager: class_name: str, bbox: Tuple[float, float, float, float], # (x_min, y_min, x_max, y_max) confidence: float, + segmentation_mask: Optional[List[List[float]]] = None, metadata: Optional[Dict] = None, ) -> int: """ @@ -254,6 +255,7 @@ class DatabaseManager: class_name: Detected object class bbox: Bounding box coordinates (normalized 0-1) confidence: Detection confidence score + segmentation_mask: Polygon coordinates for segmentation [[x1,y1], [x2,y2], ...] metadata: Additional metadata Returns: @@ -265,8 +267,8 @@ class DatabaseManager: x_min, y_min, x_max, y_max = bbox cursor.execute( """ - INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, metadata) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, segmentation_mask, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( image_id, @@ -277,6 +279,7 @@ class DatabaseManager: x_max, y_max, confidence, + json.dumps(segmentation_mask) if segmentation_mask else None, json.dumps(metadata) if metadata else None, ), ) @@ -302,8 +305,8 @@ class DatabaseManager: bbox = det["bbox"] cursor.execute( """ - INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, metadata) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, segmentation_mask, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( det["image_id"], @@ -314,6 +317,11 @@ class DatabaseManager: bbox[2], bbox[3], det["confidence"], + ( + json.dumps(det.get("segmentation_mask")) + if det.get("segmentation_mask") + else None + ), ( json.dumps(det.get("metadata")) if det.get("metadata") @@ -385,9 +393,11 @@ class DatabaseManager: detections = [] for row in cursor.fetchall(): det = dict(row) - # Parse JSON metadata + # Parse JSON fields if det.get("metadata"): det["metadata"] = json.loads(det["metadata"]) + if det.get("segmentation_mask"): + det["segmentation_mask"] = json.loads(det["segmentation_mask"]) detections.append(det) return detections @@ -538,6 +548,7 @@ class DatabaseManager: "x_max", "y_max", "confidence", + "segmentation_mask", "detected_at", ] writer = csv.DictWriter(csvfile, fieldnames=fieldnames) @@ -545,6 +556,11 @@ class DatabaseManager: for det in detections: row = {k: det[k] for k in fieldnames if k in det} + # Convert segmentation mask list to JSON string for CSV + if row.get("segmentation_mask") and isinstance( + row["segmentation_mask"], list + ): + row["segmentation_mask"] = json.dumps(row["segmentation_mask"]) writer.writerow(row) return True @@ -580,6 +596,7 @@ class DatabaseManager: class_name: str, bbox: Tuple[float, float, float, float], annotator: str, + segmentation_mask: Optional[List[List[float]]] = None, verified: bool = False, ) -> int: """Add manual annotation.""" @@ -589,10 +606,20 @@ class DatabaseManager: x_min, y_min, x_max, y_max = bbox cursor.execute( """ - INSERT INTO annotations (image_id, class_name, x_min, y_min, x_max, y_max, annotator, verified) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO annotations (image_id, class_name, x_min, y_min, x_max, y_max, segmentation_mask, annotator, verified) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) 
""", - (image_id, class_name, x_min, y_min, x_max, y_max, annotator, verified), + ( + image_id, + class_name, + x_min, + y_min, + x_max, + y_max, + json.dumps(segmentation_mask) if segmentation_mask else None, + annotator, + verified, + ), ) conn.commit() return cursor.lastrowid diff --git a/src/database/models.py b/src/database/models.py index cdcd237..6659954 100644 --- a/src/database/models.py +++ b/src/database/models.py @@ -5,7 +5,7 @@ These dataclasses represent the database entities. from dataclasses import dataclass from datetime import datetime -from typing import Optional, Dict, Tuple +from typing import Optional, Dict, Tuple, List @dataclass @@ -46,6 +46,9 @@ class Detection: class_name: str bbox: Tuple[float, float, float, float] # (x_min, y_min, x_max, y_max) confidence: float + segmentation_mask: Optional[ + List[List[float]] + ] # List of polygon coordinates [[x1,y1], [x2,y2], ...] detected_at: datetime metadata: Optional[Dict] @@ -58,6 +61,9 @@ class Annotation: image_id: int class_name: str bbox: Tuple[float, float, float, float] # (x_min, y_min, x_max, y_max) + segmentation_mask: Optional[ + List[List[float]] + ] # List of polygon coordinates [[x1,y1], [x2,y2], ...] annotator: str created_at: datetime verified: bool diff --git a/src/database/schema.sql b/src/database/schema.sql index f6080f7..b09ffee 100644 --- a/src/database/schema.sql +++ b/src/database/schema.sql @@ -37,6 +37,7 @@ CREATE TABLE IF NOT EXISTS detections ( x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1), y_max REAL NOT NULL CHECK(y_max >= 0 AND y_max <= 1), confidence REAL NOT NULL CHECK(confidence >= 0 AND confidence <= 1), + segmentation_mask TEXT, -- JSON string of polygon coordinates [[x1,y1], [x2,y2], ...] 
detected_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, metadata TEXT, -- JSON string for additional metadata FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE, @@ -52,6 +53,7 @@ CREATE TABLE IF NOT EXISTS annotations ( y_min REAL NOT NULL CHECK(y_min >= 0 AND y_min <= 1), x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1), y_max REAL NOT NULL CHECK(y_max >= 0 AND y_max <= 1), + segmentation_mask TEXT, -- JSON string of polygon coordinates [[x1,y1], [x2,y2], ...] annotator TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, verified BOOLEAN DEFAULT 0, diff --git a/src/gui/dialogs/config_dialog.py b/src/gui/dialogs/config_dialog.py index 9abfe1b..27ed4a7 100644 --- a/src/gui/dialogs/config_dialog.py +++ b/src/gui/dialogs/config_dialog.py @@ -121,7 +121,7 @@ class ConfigDialog(QDialog): models_layout.addRow("Models Directory:", self.models_dir_edit) self.base_model_edit = QLineEdit() - self.base_model_edit.setPlaceholderText("yolov8s.pt") + self.base_model_edit.setPlaceholderText("yolov8s-seg.pt") models_layout.addRow("Default Base Model:", self.base_model_edit) models_group.setLayout(models_layout) @@ -232,7 +232,7 @@ class ConfigDialog(QDialog): self.config_manager.get("models.models_directory", "data/models") ) self.base_model_edit.setText( - self.config_manager.get("models.default_base_model", "yolov8s.pt") + self.config_manager.get("models.default_base_model", "yolov8s-seg.pt") ) # Training settings diff --git a/src/gui/tabs/detection_tab.py b/src/gui/tabs/detection_tab.py index 4fe71ce..01a3861 100644 --- a/src/gui/tabs/detection_tab.py +++ b/src/gui/tabs/detection_tab.py @@ -159,7 +159,7 @@ class DetectionTab(QWidget): # Add base model option base_model = self.config_manager.get( - "models.default_base_model", "yolov8s.pt" + "models.default_base_model", "yolov8s-seg.pt" ) self.model_combo.addItem( f"Base Model ({base_model})", {"id": 0, "path": base_model} @@ -256,7 +256,7 @@ class DetectionTab(QWidget): if model_id == 0: # Create database entry for 
base model base_model = self.config_manager.get( - "models.default_base_model", "yolov8s.pt" + "models.default_base_model", "yolov8s-seg.pt" ) model_id = self.db_manager.add_model( model_name="Base Model", diff --git a/src/model/inference.py b/src/model/inference.py index 1fc5ab8..2a3780b 100644 --- a/src/model/inference.py +++ b/src/model/inference.py @@ -87,6 +87,7 @@ class InferenceEngine: "class_name": det["class_name"], "bbox": tuple(bbox_normalized), "confidence": det["confidence"], + "segmentation_mask": det.get("segmentation_mask"), "metadata": {"class_id": det["class_id"]}, } detection_records.append(record) @@ -160,6 +161,7 @@ class InferenceEngine: conf: float = 0.25, bbox_thickness: int = 2, bbox_colors: Optional[Dict[str, str]] = None, + draw_masks: bool = True, ) -> tuple: """ Detect objects and return annotated image. @@ -169,6 +171,7 @@ class InferenceEngine: conf: Confidence threshold bbox_thickness: Thickness of bounding boxes bbox_colors: Dictionary mapping class names to hex colors + draw_masks: Whether to draw segmentation masks (if available) Returns: Tuple of (detections, annotated_image_array) @@ -189,12 +192,8 @@ class InferenceEngine: bbox_colors = {} default_color = self._hex_to_bgr(bbox_colors.get("default", "#00FF00")) - # Draw bounding boxes + # Draw detections for det in detections: - # Get absolute coordinates - bbox_abs = det["bbox_absolute"] - x1, y1, x2, y2 = [int(v) for v in bbox_abs] - # Get color for this class class_name = det["class_name"] color_hex = bbox_colors.get( @@ -202,7 +201,33 @@ class InferenceEngine: ) color = self._hex_to_bgr(color_hex) - # Draw box + # Draw segmentation mask if available and requested + if draw_masks and det.get("segmentation_mask"): + mask_normalized = det["segmentation_mask"] + if mask_normalized and len(mask_normalized) > 0: + # Convert normalized coordinates to absolute pixels + mask_points = np.array( + [ + [int(pt[0] * width), int(pt[1] * height)] + for pt in mask_normalized + ], + 
dtype=np.int32, + ) + + # Create a semi-transparent overlay + overlay = img.copy() + cv2.fillPoly(overlay, [mask_points], color) + # Blend with original image (30% opacity) + cv2.addWeighted(overlay, 0.3, img, 0.7, 0, img) + + # Draw mask contour + cv2.polylines(img, [mask_points], True, color, bbox_thickness) + + # Get absolute coordinates for bounding box + bbox_abs = det["bbox_absolute"] + x1, y1, x2, y2 = [int(v) for v in bbox_abs] + + # Draw bounding box cv2.rectangle(img, (x1, y1), (x2, y2), color, bbox_thickness) # Prepare label diff --git a/src/model/yolo_wrapper.py b/src/model/yolo_wrapper.py index d4a8050..fa1fd8a 100644 --- a/src/model/yolo_wrapper.py +++ b/src/model/yolo_wrapper.py @@ -16,7 +16,7 @@ logger = get_logger(__name__) class YOLOWrapper: """Wrapper for YOLOv8 model operations.""" - def __init__(self, model_path: str = "yolov8s.pt"): + def __init__(self, model_path: str = "yolov8s-seg.pt"): """ Initialize YOLO model. @@ -282,6 +282,10 @@ class YOLOWrapper: boxes = result.boxes image_path = str(result.path) orig_shape = result.orig_shape # (height, width) + height, width = orig_shape + + # Check if this is a segmentation model with masks + has_masks = hasattr(result, "masks") and result.masks is not None for i in range(len(boxes)): # Get normalized coordinates @@ -299,6 +303,33 @@ class YOLOWrapper: float(v) for v in boxes.xyxy[i].cpu().numpy() ], # Absolute pixels } + + # Extract segmentation mask if available + if has_masks: + try: + # Get the mask for this detection + mask_data = result.masks.xy[ + i + ] # Polygon coordinates in absolute pixels + + # Convert to normalized coordinates + if len(mask_data) > 0: + mask_normalized = [] + for point in mask_data: + x_norm = float(point[0]) / width + y_norm = float(point[1]) / height + mask_normalized.append([x_norm, y_norm]) + detection["segmentation_mask"] = mask_normalized + else: + detection["segmentation_mask"] = None + except Exception as mask_error: + logger.warning( + f"Error extracting mask 
for detection {i}: {mask_error}" + ) + detection["segmentation_mask"] = None + else: + detection["segmentation_mask"] = None + detections.append(detection) return detections diff --git a/src/utils/config_manager.py b/src/utils/config_manager.py index 385d2a9..a31516d 100644 --- a/src/utils/config_manager.py +++ b/src/utils/config_manager.py @@ -56,7 +56,7 @@ class ConfigManager: ], }, "models": { - "default_base_model": "yolov8s.pt", + "default_base_model": "yolov8s-seg.pt", "models_directory": "data/models", }, "training": { -- 2.49.1 From 42fb2b782d536b034428e1af754a6de22d8fc1c9 Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Fri, 5 Dec 2025 16:18:37 +0200 Subject: [PATCH 02/13] Bug fix in installing and launching the program --- INSTALL_TEST.md | 236 ++++++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 9 +- setup.py | 2 +- src/cli.py | 2 +- src/gui_launcher.py | 49 +++++++++ 5 files changed, 291 insertions(+), 7 deletions(-) create mode 100644 INSTALL_TEST.md create mode 100644 src/gui_launcher.py diff --git a/INSTALL_TEST.md b/INSTALL_TEST.md new file mode 100644 index 0000000..8740da4 --- /dev/null +++ b/INSTALL_TEST.md @@ -0,0 +1,236 @@ +# Installation Testing Guide + +This guide helps you verify that the package installation works correctly. + +## Clean Installation Test + +### 1. Remove Any Previous Installations + +```bash +# Deactivate any active virtual environment +deactivate + +# Remove old virtual environment (if exists) +rm -rf venv + +# Create fresh virtual environment +python3 -m venv venv +source venv/bin/activate # On Linux/Mac +# or +venv\Scripts\activate # On Windows +``` + +### 2. Install the Package + +#### Option A: Editable/Development Install + +```bash +pip install -e . +``` + +This allows you to modify source code and see changes immediately. + +#### Option B: Regular Install + +```bash +pip install . +``` + +This installs the package as if it were from PyPI. + +### 3. 
Verify Installation + +```bash +# Check package is installed +pip list | grep microscopy + +# Check version +microscopy-detect --version +# Expected output: microscopy-object-detection 1.0.0 + +# Test Python import +python -c "import src; print(src.__version__)" +# Expected output: 1.0.0 +``` + +### 4. Test Entry Points + +```bash +# Test CLI +microscopy-detect --help + +# Test GUI launcher (will open window) +microscopy-detect-gui +``` + +### 5. Verify Package Contents + +```python +# Run this in Python shell +import src +import src.database +import src.model +import src.gui + +# Check schema file is included +from pathlib import Path +import src.database +db_path = Path(src.database.__file__).parent +schema_file = db_path / 'schema.sql' +print(f"Schema file exists: {schema_file.exists()}") +# Expected: Schema file exists: True +``` + +## Troubleshooting + +### Issue: ModuleNotFoundError + +**Error:** +``` +ModuleNotFoundError: No module named 'src' +``` + +**Solution:** +```bash +# Reinstall with verbose output +pip install -e . -v + +# Or try regular install +pip install . --force-reinstall +``` + +### Issue: Entry Points Not Working + +**Error:** +``` +microscopy-detect: command not found +``` + +**Solution:** +```bash +# Check if scripts are in PATH +which microscopy-detect + +# If not found, check pip install location +pip show microscopy-object-detection + +# You might need to add to PATH or use full path +~/.local/bin/microscopy-detect # Linux +``` + +### Issue: Import Errors for PySide6 + +**Error:** +``` +ImportError: cannot import name 'QApplication' from 'PySide6.QtWidgets' +``` + +**Solution:** +```bash +# Install Qt dependencies (Linux only) +sudo apt-get install libxcb-xinerama0 + +# Reinstall PySide6 +pip uninstall PySide6 +pip install PySide6 +``` + +### Issue: Config Files Not Found + +**Error:** +``` +FileNotFoundError: config/app_config.yaml +``` + +**Solution:** +The config file should be created automatically. 
If not: +```bash +# Create a config directory in your home directory +mkdir -p ~/.microscopy-detect +cp config/app_config.yaml ~/.microscopy-detect/ + +# Or run once from the source directory so the config is created +cd /path/to/object_detection +python main.py +``` + +## Manual Testing Checklist + +- [ ] Package installs without errors +- [ ] Version command works (`microscopy-detect --version`) +- [ ] Help command works (`microscopy-detect --help`) +- [ ] GUI launches (`microscopy-detect-gui`) +- [ ] Can import all modules in Python +- [ ] Database schema file is accessible +- [ ] Configuration loads correctly + +## Build and Install from Wheel + +```bash +# Build the package +python -m build + +# Install from wheel +pip install dist/microscopy_object_detection-1.0.0-py3-none-any.whl + +# Test +microscopy-detect --version +``` + +## Uninstall + +```bash +pip uninstall microscopy-object-detection +``` + +## Development Workflow + +### After Code Changes + +If installed with `-e` (editable mode): +- Python code changes are immediately available +- No need to reinstall + +If installed with regular `pip install .`: +- Reinstall after changes: `pip install . --force-reinstall` + +### After Adding New Files + +```bash +# Reinstall to include new files +pip install -e . --force-reinstall +``` + +## Expected Installation Output + +``` +Processing /home/martin/code/object_detection + Installing build dependencies ... done + Getting requirements to build wheel ... done + Preparing metadata (pyproject.toml) ... done +Building wheels for collected packages: microscopy-object-detection + Building wheel for microscopy-object-detection (pyproject.toml) ... done +Successfully built microscopy-object-detection +Installing collected packages: microscopy-object-detection +Successfully installed microscopy-object-detection-1.0.0 +``` + +## Success Criteria + +Installation is successful when: +1. ✅ No error messages during installation +2. ✅ `pip list` shows the package +3. 
✅ `microscopy-detect --version` returns correct version +4. ✅ GUI launches without errors +5. ✅ All Python modules can be imported +6. ✅ Database operations work +7. ✅ Detection functionality works + +## Next Steps After Successful Install + +1. Configure image repository path +2. Run first detection +3. Train a custom model +4. Export results + +For usage instructions, see [QUICKSTART.md](QUICKSTART.md) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index d5993d6..7737939 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2"] +requires = ["setuptools>=45", "wheel"] build-backend = "setuptools.build_meta" [project] @@ -63,10 +63,9 @@ Repository = "https://github.com/yourusername/object_detection" microscopy-detect = "src.cli:main" [project.gui-scripts] -microscopy-detect-gui = "main:main" +microscopy-detect-gui = "src.gui_launcher:main" [tool.setuptools] -package-dir = { "" = "." 
} packages = [ "src", "src.database", @@ -77,10 +76,10 @@ packages = [ "src.gui.widgets", "src.utils", ] +include-package-data = true [tool.setuptools.package-data] -src = ["database/*.sql"] -"" = ["config/*.yaml"] +"src.database" = ["*.sql"] [tool.black] line-length = 88 diff --git a/setup.py b/setup.py index 59a43a1..0366ec6 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,7 @@ setup( "microscopy-detect=src.cli:main", ], "gui_scripts": [ - "microscopy-detect-gui=main:main", + "microscopy-detect-gui=src.gui_launcher:main", ], }, keywords="microscopy yolov8 object-detection segmentation computer-vision deep-learning", diff --git a/src/cli.py b/src/cli.py index 482ee52..5c02da4 100644 --- a/src/cli.py +++ b/src/cli.py @@ -44,7 +44,7 @@ Examples: if args.gui: # Launch GUI try: - from main import main as gui_main + from src.gui_launcher import main as gui_main gui_main() except Exception as e: diff --git a/src/gui_launcher.py b/src/gui_launcher.py new file mode 100644 index 0000000..1b921b0 --- /dev/null +++ b/src/gui_launcher.py @@ -0,0 +1,49 @@ +"""GUI launcher module for microscopy object detection application.""" + +import sys +from pathlib import Path + +from PySide6.QtWidgets import QApplication +from PySide6.QtCore import Qt + +from src import __version__ +from src.gui.main_window import MainWindow +from src.utils.logger import setup_logging +from src.utils.config_manager import ConfigManager + + +def main(): + """Launch the GUI application.""" + # Setup logging + config_manager = ConfigManager() + log_config = config_manager.get_section("logging") + setup_logging( + log_file=log_config.get("file", "logs/app.log"), + level=log_config.get("level", "INFO"), + log_format=log_config.get("format"), + ) + + # Enable High DPI scaling + QApplication.setHighDpiScaleFactorRoundingPolicy( + Qt.HighDpiScaleFactorRoundingPolicy.PassThrough + ) + + # Create Qt application + app = QApplication(sys.argv) + app.setApplicationName("Microscopy Object Detection") + 
app.setOrganizationName("MicroscopyLab") + app.setApplicationVersion(__version__) + + # Set application style + app.setStyle("Fusion") + + # Create and show main window + window = MainWindow() + window.show() + + # Run application + sys.exit(app.exec()) + + +if __name__ == "__main__": + main() -- 2.49.1 From 4b5d2a7c45cf0d8e59d89c6b5bcd0df21f42d912 Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Mon, 8 Dec 2025 16:28:58 +0200 Subject: [PATCH 03/13] Adding image loading --- docs/IMAGE_CLASS_USAGE.md | 220 ++++++++++++++++++++++++++++ examples/image_demo.py | 151 +++++++++++++++++++ src/gui/tabs/annotation_tab.py | 180 +++++++++++++++++++++-- src/utils/__init__.py | 7 + src/utils/image.py | 259 +++++++++++++++++++++++++++++++++ tests/test_image.py | 145 ++++++++++++++++++ 6 files changed, 952 insertions(+), 10 deletions(-) create mode 100644 docs/IMAGE_CLASS_USAGE.md create mode 100644 examples/image_demo.py create mode 100644 src/utils/image.py create mode 100644 tests/test_image.py diff --git a/docs/IMAGE_CLASS_USAGE.md b/docs/IMAGE_CLASS_USAGE.md new file mode 100644 index 0000000..dd4c2a8 --- /dev/null +++ b/docs/IMAGE_CLASS_USAGE.md @@ -0,0 +1,220 @@ +# Image Class Usage Guide + +The `Image` class provides a convenient way to load and work with images in the microscopy object detection application. 
+ +## Supported Formats + +The Image class supports the following image formats: +- `.jpg`, `.jpeg` - JPEG images +- `.png` - PNG images +- `.tif`, `.tiff` - TIFF images (commonly used in microscopy) +- `.bmp` - Bitmap images + +## Basic Usage + +### Loading an Image + +```python +from src.utils import Image, ImageLoadError + +# Load an image from a file path +try: + img = Image("path/to/image.jpg") + print(f"Loaded image: {img.width}x{img.height} pixels") +except ImageLoadError as e: + print(f"Failed to load image: {e}") +``` + +### Accessing Image Properties + +```python +img = Image("microscopy_image.tif") + +# Basic properties +print(f"Width: {img.width} pixels") +print(f"Height: {img.height} pixels") +print(f"Channels: {img.channels}") +print(f"Format: {img.format}") +print(f"Shape: {img.shape}") # (height, width, channels) + +# File information +print(f"File size: {img.size_mb:.2f} MB") +print(f"File size: {img.size_bytes} bytes") + +# Image type checks +print(f"Is color: {img.is_color()}") +print(f"Is grayscale: {img.is_grayscale()}") + +# String representation +print(img) # Shows summary of image properties +``` + +### Working with Image Data + +```python +import numpy as np + +img = Image("sample.png") + +# Get image data as numpy array (OpenCV format, BGR) +bgr_data = img.data +print(f"Data shape: {bgr_data.shape}") +print(f"Data type: {bgr_data.dtype}") + +# Get image as RGB (for display or processing) +rgb_data = img.get_rgb() + +# Get grayscale version +gray_data = img.get_grayscale() + +# Create a copy (for modifications) +img_copy = img.copy() +img_copy[0, 0] = [255, 255, 255] # Modify copy, original unchanged + +# Resize image (returns new array, doesn't modify original) +resized = img.resize(640, 640) +``` + +### Using PIL Image + +```python +img = Image("photo.jpg") + +# Access as PIL Image (RGB format) +pil_img = img.pil_image + +# Use PIL methods +pil_img.show() # Display image +pil_img.save("output.png") # Save with PIL +``` + +## Integration 
with YOLO + +```python +from src.utils import Image +from ultralytics import YOLO + +# Load model and image +model = YOLO("yolov8n.pt") +img = Image("microscopy/cell_01.tif") + +# Run inference (YOLO accepts file paths or numpy arrays) +results = model(img.data) + +# Or use the file path directly +results = model(str(img.path)) +``` + +## Error Handling + +```python +from src.utils import Image, ImageLoadError + +def process_image(image_path): + try: + img = Image(image_path) + # Process the image... + return img + except ImageLoadError as e: + print(f"Cannot load image: {e}") + return None +``` + +## Advanced Usage + +### Batch Processing + +```python +from pathlib import Path +from src.utils import Image, ImageLoadError + +def process_image_directory(directory): + """Process all images in a directory.""" + image_paths = Path(directory).glob("*.tif") + + for path in image_paths: + try: + img = Image(path) + print(f"Processing {img.path.name}: {img.width}x{img.height}") + # Process the image... + except ImageLoadError as e: + print(f"Skipping {path}: {e}") +``` + +### Using with OpenCV Operations + +```python +import cv2 +from src.utils import Image + +img = Image("input.jpg") + +# Apply OpenCV operations on the data +blurred = cv2.GaussianBlur(img.data, (5, 5), 0) +edges = cv2.Canny(img.data, 100, 200) + +# Note: These operations don't modify the original img.data +``` + +### Memory Efficient Processing + +```python +from src.utils import Image + +# The Image class loads data into memory +img = Image("large_image.tif") +print(f"Image size in memory: {img.data.nbytes / (1024**2):.2f} MB") + +# When processing many images, consider loading one at a time +# and releasing memory by deleting the object +del img +``` + +## Best Practices + +1. **Always use try-except** when loading images to handle errors gracefully +2. **Check image properties** before processing to ensure compatibility +3. 
**Use copy()** when you need to modify image data without affecting the original +4. **Path objects work too** - The class accepts both strings and Path objects +5. **Consider memory usage** when working with large images or batches + +## Example: Complete Workflow + +```python +from src.utils import Image, ImageLoadError +from src.utils.file_utils import get_image_files + +def analyze_microscopy_images(directory): + """Analyze all microscopy images in a directory.""" + + # Get all image files + image_files = get_image_files(directory, recursive=True) + + results = [] + for image_path in image_files: + try: + # Load image + img = Image(image_path) + + # Analyze + result = { + 'filename': img.path.name, + 'width': img.width, + 'height': img.height, + 'channels': img.channels, + 'format': img.format, + 'size_mb': img.size_mb, + 'is_color': img.is_color() + } + + results.append(result) + print(f"✓ Analyzed {img.path.name}") + + except ImageLoadError as e: + print(f"✗ Failed to load {image_path}: {e}") + + return results + +# Run analysis +results = analyze_microscopy_images("data/datasets/cells") +print(f"\nProcessed {len(results)} images") \ No newline at end of file diff --git a/examples/image_demo.py b/examples/image_demo.py new file mode 100644 index 0000000..9f5c4eb --- /dev/null +++ b/examples/image_demo.py @@ -0,0 +1,151 @@ +""" +Example script demonstrating the Image class functionality. 
+""" + +import sys +from pathlib import Path + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from src.utils import Image, ImageLoadError + + +def demonstrate_image_loading(): + """Demonstrate basic image loading functionality.""" + print("=" * 60) + print("Image Class Demonstration") + print("=" * 60) + + # Example 1: Try to load an image (replace with your own path) + example_paths = [ + "data/datasets/example.jpg", + "data/datasets/sample.png", + "tests/test_image.jpg", + ] + + loaded_img = None + for image_path in example_paths: + if Path(image_path).exists(): + try: + print(f"\n1. Loading image: {image_path}") + img = Image(image_path) + loaded_img = img + print(f" ✓ Successfully loaded!") + print(f" {img}") + break + except ImageLoadError as e: + print(f" ✗ Failed: {e}") + else: + print(f"\n1. Image not found: {image_path}") + + if loaded_img is None: + print("\nNo example images found. Creating a test image...") + create_test_image() + return + + # Example 2: Access image properties + print(f"\n2. Image Properties:") + print(f" Width: {loaded_img.width} pixels") + print(f" Height: {loaded_img.height} pixels") + print(f" Channels: {loaded_img.channels}") + print(f" Format: {loaded_img.format.upper()}") + print(f" Shape: {loaded_img.shape}") + print(f" File size: {loaded_img.size_mb:.2f} MB") + print(f" Is color: {loaded_img.is_color()}") + print(f" Is grayscale: {loaded_img.is_grayscale()}") + + # Example 3: Get different formats + print(f"\n3. Accessing Image Data:") + print(f" BGR data shape: {loaded_img.data.shape}") + print(f" RGB data shape: {loaded_img.get_rgb().shape}") + print(f" Grayscale shape: {loaded_img.get_grayscale().shape}") + print(f" PIL image mode: {loaded_img.pil_image.mode}") + + # Example 4: Resizing + print(f"\n4. 
Resizing Image:") + resized = loaded_img.resize(320, 320) + print(f" Original size: {loaded_img.width}x{loaded_img.height}") + print(f" Resized to: {resized.shape[1]}x{resized.shape[0]}") + + # Example 5: Working with copies + print(f"\n5. Creating Copies:") + copy = loaded_img.copy() + print(f" Created copy with shape: {copy.shape}") + print(f" Original data unchanged: {(loaded_img.data == copy).all()}") + + print("\n" + "=" * 60) + print("Demonstration Complete!") + print("=" * 60) + + +def create_test_image(): + """Create a test image for demonstration purposes.""" + import cv2 + import numpy as np + + print("\nCreating a test image...") + + # Create a colorful test image + width, height = 400, 300 + test_img = np.zeros((height, width, 3), dtype=np.uint8) + + # Add some colors + test_img[:100, :] = [255, 0, 0] # Blue section + test_img[100:200, :] = [0, 255, 0] # Green section + test_img[200:, :] = [0, 0, 255] # Red section + + # Save the test image + test_path = Path("test_demo_image.png") + cv2.imwrite(str(test_path), test_img) + print(f"Test image created: {test_path}") + + # Now load and demonstrate with it + try: + img = Image(test_path) + print(f"\nLoaded test image: {img}") + print(f"Dimensions: {img.width}x{img.height}") + print(f"Channels: {img.channels}") + print(f"Format: {img.format}") + + # Clean up + test_path.unlink() + print(f"\nTest image cleaned up.") + + except ImageLoadError as e: + print(f"Error loading test image: {e}") + + +def demonstrate_error_handling(): + """Demonstrate error handling.""" + print("\n" + "=" * 60) + print("Error Handling Demonstration") + print("=" * 60) + + # Try to load non-existent file + print("\n1. Loading non-existent file:") + try: + img = Image("nonexistent.jpg") + except ImageLoadError as e: + print(f" ✓ Caught error: {e}") + + # Try unsupported format + print("\n2. 
Loading unsupported format:") + try: + # Create a text file + test_file = Path("test.txt") + test_file.write_text("not an image") + img = Image(test_file) + except ImageLoadError as e: + print(f" ✓ Caught error: {e}") + test_file.unlink() # Clean up + + print("\n" + "=" * 60) + + +if __name__ == "__main__": + print("\n") + demonstrate_image_loading() + print("\n") + demonstrate_error_handling() + print("\n") diff --git a/src/gui/tabs/annotation_tab.py b/src/gui/tabs/annotation_tab.py index 7b2f5fc..065b4e1 100644 --- a/src/gui/tabs/annotation_tab.py +++ b/src/gui/tabs/annotation_tab.py @@ -3,10 +3,27 @@ Annotation tab for the microscopy object detection application. Future feature for manual annotation. """ -from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox +from PySide6.QtWidgets import ( + QWidget, + QVBoxLayout, + QHBoxLayout, + QLabel, + QGroupBox, + QPushButton, + QFileDialog, + QMessageBox, + QScrollArea, +) +from PySide6.QtGui import QPixmap, QImage +from PySide6.QtCore import Qt +from pathlib import Path from src.database.db_manager import DatabaseManager from src.utils.config_manager import ConfigManager +from src.utils.image import Image, ImageLoadError +from src.utils.logger import get_logger + +logger = get_logger(__name__) class AnnotationTab(QWidget): @@ -18,6 +35,8 @@ class AnnotationTab(QWidget): super().__init__(parent) self.db_manager = db_manager self.config_manager = config_manager + self.current_image = None + self.current_image_path = None self._setup_ui() @@ -25,24 +44,165 @@ class AnnotationTab(QWidget): """Setup user interface.""" layout = QVBoxLayout() - group = QGroupBox("Annotation Tool (Future Feature)") - group_layout = QVBoxLayout() - label = QLabel( - "Annotation functionality will be implemented in future version.\n\n" + # Image loading section + load_group = QGroupBox("Image Loading") + load_layout = QVBoxLayout() + + # Load image button + button_layout = QHBoxLayout() + self.load_image_btn = QPushButton("Load 
Image") + self.load_image_btn.clicked.connect(self._load_image) + button_layout.addWidget(self.load_image_btn) + button_layout.addStretch() + + load_layout.addLayout(button_layout) + + # Image info label + self.image_info_label = QLabel("No image loaded") + load_layout.addWidget(self.image_info_label) + + load_group.setLayout(load_layout) + layout.addWidget(load_group) + + # Image display section + display_group = QGroupBox("Image Display") + display_layout = QVBoxLayout() + + # Scroll area for image + scroll_area = QScrollArea() + scroll_area.setWidgetResizable(True) + scroll_area.setMinimumHeight(400) + + self.image_label = QLabel("No image loaded") + self.image_label.setAlignment(Qt.AlignCenter) + self.image_label.setStyleSheet( + "QLabel { background-color: #2b2b2b; color: #888; }" + ) + self.image_label.setScaledContents(False) + + scroll_area.setWidget(self.image_label) + display_layout.addWidget(scroll_area) + + display_group.setLayout(display_layout) + layout.addWidget(display_group) + + # Future features info + info_group = QGroupBox("Annotation Tool (Future Feature)") + info_layout = QVBoxLayout() + info_label = QLabel( + "Full annotation functionality will be implemented in future version.\n\n" "Planned Features:\n" - "- Image browser\n" "- Drawing tools for bounding boxes\n" "- Class label assignment\n" "- Export annotations to YOLO format\n" "- Annotation verification" ) - group_layout.addWidget(label) - group.setLayout(group_layout) + info_layout.addWidget(info_label) + info_group.setLayout(info_layout) - layout.addWidget(group) - layout.addStretch() + layout.addWidget(info_group) self.setLayout(layout) + def _load_image(self): + """Load and display an image file.""" + # Get image repository path or use home directory + repo_path = self.config_manager.get_image_repository_path() + start_dir = repo_path if repo_path else str(Path.home()) + + # Open file dialog + file_path, _ = QFileDialog.getOpenFileName( + self, + "Select Image", + start_dir, + 
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)", + ) + + if not file_path: + return + + try: + # Load image using Image class + self.current_image = Image(file_path) + self.current_image_path = file_path + + # Update info label + info_text = ( + f"File: {Path(file_path).name}\n" + f"Size: {self.current_image.width}x{self.current_image.height} pixels\n" + f"Channels: {self.current_image.channels}\n" + f"Format: {self.current_image.format.upper()}\n" + f"File size: {self.current_image.size_mb:.2f} MB" + ) + self.image_info_label.setText(info_text) + + # Convert to QPixmap and display + self._display_image() + + logger.info(f"Loaded image: {file_path}") + + except ImageLoadError as e: + logger.error(f"Failed to load image: {e}") + QMessageBox.critical( + self, "Error Loading Image", f"Failed to load image:\n{str(e)}" + ) + except Exception as e: + logger.error(f"Unexpected error loading image: {e}") + QMessageBox.critical(self, "Error", f"Unexpected error:\n{str(e)}") + + def _display_image(self): + """Display the current image in the image label.""" + if self.current_image is None: + return + + try: + # Get RGB image data + rgb_data = self.current_image.get_rgb() + + # Convert numpy array to QImage + height, width, channels = rgb_data.shape + bytes_per_line = channels * width + + if channels == 3: + qimage = QImage( + rgb_data.data, + width, + height, + bytes_per_line, + QImage.Format_RGB888, + ) + else: + # Grayscale + qimage = QImage( + rgb_data.data, + width, + height, + bytes_per_line, + QImage.Format_Grayscale8, + ) + + # Convert to pixmap + pixmap = QPixmap.fromImage(qimage) + + # Scale to fit display (max 800px width or height) + max_size = 800 + if pixmap.width() > max_size or pixmap.height() > max_size: + pixmap = pixmap.scaled( + max_size, + max_size, + Qt.KeepAspectRatio, + Qt.SmoothTransformation, + ) + + # Display in label + self.image_label.setPixmap(pixmap) + self.image_label.setScaledContents(False) + + except Exception as e: + logger.error(f"Error 
class Image:
    """
    A class for loading and managing images from file paths.

    Supports multiple image formats: .jpg, .jpeg, .png, .tif, .tiff, .bmp
    Provides access to image data in multiple formats (OpenCV/numpy, PIL).

    Attributes:
        path: Path to the image file
        data: Image data as numpy array (OpenCV format, BGR)
        pil_image: Image data as PIL Image (RGB); may be unavailable when PIL
            cannot represent the pixel data
        width: Image width in pixels
        height: Image height in pixels
        channels: Number of color channels
        format: Image file format
        size_bytes: File size in bytes
    """

    SUPPORTED_EXTENSIONS = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]

    def __init__(self, image_path: Union[str, Path]):
        """
        Initialize an Image object by loading from a file path.

        Args:
            image_path: Path to the image file (string or Path object)

        Raises:
            ImageLoadError: If the image cannot be loaded or is invalid
        """
        self.path = Path(image_path)
        self._data: Optional[np.ndarray] = None
        self._pil_image: Optional[PILImage.Image] = None
        self._width: int = 0
        self._height: int = 0
        self._channels: int = 0
        self._format: str = ""
        self._size_bytes: int = 0

        # Load the image
        self._load()

    def _load(self) -> None:
        """
        Load the image from disk and populate data and metadata.

        Raises:
            ImageLoadError: If the path is invalid, the extension is not
                supported, or the file cannot be decoded
        """
        # Validate path
        if not validate_file_path(str(self.path), must_exist=True):
            raise ImageLoadError(f"Invalid or non-existent file path: {self.path}")

        # Check file extension
        if not is_image_file(str(self.path), self.SUPPORTED_EXTENSIONS):
            ext = self.path.suffix.lower()
            raise ImageLoadError(
                f"Unsupported image format: {ext}. "
                f"Supported formats: {', '.join(self.SUPPORTED_EXTENSIONS)}"
            )

        try:
            # Load with OpenCV (returns BGR format); IMREAD_UNCHANGED keeps
            # the file's own channel count and bit depth.
            self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)

            if self._data is None:
                raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")

            # Extract metadata
            self._height, self._width = self._data.shape[:2]
            self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
            self._format = self.path.suffix.lower().lstrip(".")
            self._size_bytes = self.path.stat().st_size

            self._pil_image = self._build_pil_image()

            logger.info(
                f"Successfully loaded image: {self.path.name} "
                f"({self._width}x{self._height}, {self._channels} channels, "
                f"{self._format.upper()})"
            )

        except ImageLoadError as e:
            # Already carries a specific message; do not wrap it a second time.
            logger.error(f"Error loading image {self.path}: {e}")
            raise
        except Exception as e:
            logger.error(f"Error loading image {self.path}: {e}")
            raise ImageLoadError(f"Failed to load image: {e}") from e

    def _build_pil_image(self) -> Optional[PILImage.Image]:
        """
        Build the PIL counterpart of the loaded OpenCV data.

        The conversion is best-effort: when PIL rejects the array (for example
        a pixel type it has no mode for), the OpenCV data stays usable and
        ``pil_image`` reports the PIL copy as unavailable.

        Returns:
            PIL image in RGB/RGBA/grayscale order, or None if PIL rejected it
        """
        try:
            if self._channels == 3:
                return PILImage.fromarray(cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB))
            if self._channels == 4:
                return PILImage.fromarray(
                    cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
                )
            # Grayscale
            return PILImage.fromarray(self._data)
        except (TypeError, ValueError) as e:
            logger.warning(
                f"PIL representation unavailable for {self.path.name}: {e}"
            )
            return None

    @property
    def data(self) -> np.ndarray:
        """
        Get image data as numpy array (OpenCV format, BGR or grayscale).

        Returns:
            Image data as numpy array

        Raises:
            ImageLoadError: If no image data is held
        """
        if self._data is None:
            raise ImageLoadError("Image data not available")
        return self._data

    @property
    def pil_image(self) -> PILImage.Image:
        """
        Get image data as PIL Image (RGB or grayscale).

        Returns:
            PIL Image object

        Raises:
            ImageLoadError: If no PIL representation is held
        """
        if self._pil_image is None:
            raise ImageLoadError("PIL image not available")
        return self._pil_image

    @property
    def width(self) -> int:
        """Get image width in pixels."""
        return self._width

    @property
    def height(self) -> int:
        """Get image height in pixels."""
        return self._height

    @property
    def shape(self) -> Tuple[int, int, int]:
        """
        Get image shape as (height, width, channels).

        Returns:
            Tuple of (height, width, channels)
        """
        return (self._height, self._width, self._channels)

    @property
    def channels(self) -> int:
        """Get number of color channels."""
        return self._channels

    @property
    def format(self) -> str:
        """Get image file format (e.g., 'jpg', 'png')."""
        return self._format

    @property
    def size_bytes(self) -> int:
        """Get file size in bytes."""
        return self._size_bytes

    @property
    def size_mb(self) -> float:
        """Get file size in megabytes."""
        return self._size_bytes / (1024 * 1024)

    def get_rgb(self) -> np.ndarray:
        """
        Get image data with the channel order converted from BGR to RGB.

        Returns:
            RGB array for 3-channel images, RGBA for 4-channel images, and
            the stored array unchanged otherwise
        """
        if self._channels == 3:
            return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
        elif self._channels == 4:
            return cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
        else:
            return self._data

    def get_grayscale(self) -> np.ndarray:
        """
        Get image as grayscale numpy array.

        Returns:
            The stored array for single-channel images, otherwise a
            BGR-to-gray conversion of it
        """
        if self._channels == 1:
            return self._data
        else:
            return cv2.cvtColor(self._data, cv2.COLOR_BGR2GRAY)

    def copy(self) -> np.ndarray:
        """
        Get a copy of the image data.

        Returns:
            Copy of image data as numpy array
        """
        return self._data.copy()

    def resize(self, width: int, height: int) -> np.ndarray:
        """
        Resize the image to specified dimensions.

        Args:
            width: Target width in pixels
            height: Target height in pixels

        Returns:
            Resized image as numpy array (does not modify original)
        """
        return cv2.resize(self._data, (width, height))

    def is_grayscale(self) -> bool:
        """
        Check if image is grayscale.

        Returns:
            True if image is grayscale (1 channel)
        """
        return self._channels == 1

    def is_color(self) -> bool:
        """
        Check if image is color.

        Returns:
            True if image has 3 or more channels
        """
        return self._channels >= 3

    def __repr__(self) -> str:
        """String representation of the Image object."""
        # Height first, so the text agrees with the ``shape`` property.
        return (
            f"Image(path='{self.path.name}', "
            f"shape=({self._height}x{self._width}x{self._channels}), "
            f"format={self._format}, "
            f"size={self.size_mb:.2f}MB)"
        )

    def __str__(self) -> str:
        """String representation of the Image object."""
        return self.__repr__()
class TestImage:
    """Test cases for the Image class."""

    @staticmethod
    def _write_image(directory, filename, pixels):
        """Write *pixels* (BGR order) to *directory/filename* and return the path."""
        import cv2

        target = directory / filename
        cv2.imwrite(str(target), pixels)
        return target

    def test_load_nonexistent_file(self):
        """A path that does not exist must raise ImageLoadError."""
        with pytest.raises(ImageLoadError):
            Image("nonexistent_file.jpg")

    def test_load_unsupported_format(self, tmp_path):
        """An existing file with an unsupported extension must raise ImageLoadError."""
        not_an_image = tmp_path / "test.txt"
        not_an_image.write_text("not an image")

        with pytest.raises(ImageLoadError):
            Image(not_an_image)

    def test_supported_extensions(self):
        """The supported extension list is exactly the documented one."""
        assert Image.SUPPORTED_EXTENSIONS == [
            ".jpg",
            ".jpeg",
            ".png",
            ".tif",
            ".tiff",
            ".bmp",
        ]

    def test_image_properties(self, tmp_path):
        """Metadata properties reflect the file that was loaded."""
        pixels = np.zeros((100, 200, 3), dtype=np.uint8)
        pixels[:, :] = [255, 0, 0]  # Blue in BGR

        img = Image(self._write_image(tmp_path, "test.jpg", pixels))

        assert img.width == 200
        assert img.height == 100
        assert img.channels == 3
        assert img.format == "jpg"
        assert img.shape == (100, 200, 3)
        assert img.size_bytes > 0
        assert img.is_color()
        assert not img.is_grayscale()

    def test_get_rgb(self, tmp_path):
        """get_rgb() swaps the stored BGR channel order to RGB."""
        pixels = np.zeros((50, 50, 3), dtype=np.uint8)
        pixels[:, :] = [255, 0, 0]  # Pure blue in BGR

        img = Image(self._write_image(tmp_path, "test_rgb.png", pixels))
        red, green, blue = img.get_rgb()[0, 0]

        # Pure blue must end up in the last (B) slot of an RGB pixel
        assert red == 0  # R
        assert green == 0  # G
        assert blue == 255  # B

    def test_get_grayscale(self, tmp_path):
        """get_grayscale() returns a 2-D array of the same height and width."""
        pixels = np.zeros((50, 50, 3), dtype=np.uint8)
        pixels[:, :] = [128, 128, 128]

        img = Image(self._write_image(tmp_path, "test_gray.png", pixels))
        gray = img.get_grayscale()

        assert len(gray.shape) == 2  # Should be 2D
        assert gray.shape == (50, 50)

    def test_copy(self, tmp_path):
        """Writing to the array returned by copy() leaves the original untouched."""
        pixels = np.zeros((50, 50, 3), dtype=np.uint8)

        img = Image(self._write_image(tmp_path, "test_copy.png", pixels))
        duplicate = img.copy()

        # Change one pixel of the duplicate only
        duplicate[0, 0] = [255, 255, 255]

        assert not np.array_equal(img.data[0, 0], duplicate[0, 0])

    def test_resize(self, tmp_path):
        """resize() returns a new array and keeps the loaded image's size."""
        pixels = np.zeros((100, 100, 3), dtype=np.uint8)

        img = Image(self._write_image(tmp_path, "test_resize.png", pixels))
        smaller = img.resize(50, 50)

        assert smaller.shape == (50, 50, 3)
        # Original should be unchanged
        assert img.width == 100
        assert img.height == 100

    def test_str_repr(self, tmp_path):
        """str() and repr() mention the file name, dimensions and format."""
        pixels = np.zeros((100, 200, 3), dtype=np.uint8)

        img = Image(self._write_image(tmp_path, "test_str.jpg", pixels))

        as_str = str(img)
        assert "test_str.jpg" in as_str
        assert "100x200x3" in as_str
        assert "jpg" in as_str

        as_repr = repr(img)
        assert "Image" in as_repr
        assert "test_str.jpg" in as_repr
src/gui/widgets/image_display_widget.py | 282 ++++++++++++++++++++++++ src/utils/image.py | 32 +++ 4 files changed, 366 insertions(+), 79 deletions(-) create mode 100644 src/gui/widgets/image_display_widget.py diff --git a/src/gui/tabs/annotation_tab.py b/src/gui/tabs/annotation_tab.py index 065b4e1..6aa2462 100644 --- a/src/gui/tabs/annotation_tab.py +++ b/src/gui/tabs/annotation_tab.py @@ -12,16 +12,15 @@ from PySide6.QtWidgets import ( QPushButton, QFileDialog, QMessageBox, - QScrollArea, ) -from PySide6.QtGui import QPixmap, QImage -from PySide6.QtCore import Qt +from PySide6.QtCore import Qt, QSettings from pathlib import Path from src.database.db_manager import DatabaseManager from src.utils.config_manager import ConfigManager from src.utils.image import Image, ImageLoadError from src.utils.logger import get_logger +from src.gui.widgets import ImageDisplayWidget logger = get_logger(__name__) @@ -68,20 +67,10 @@ class AnnotationTab(QWidget): display_group = QGroupBox("Image Display") display_layout = QVBoxLayout() - # Scroll area for image - scroll_area = QScrollArea() - scroll_area.setWidgetResizable(True) - scroll_area.setMinimumHeight(400) - - self.image_label = QLabel("No image loaded") - self.image_label.setAlignment(Qt.AlignCenter) - self.image_label.setStyleSheet( - "QLabel { background-color: #2b2b2b; color: #888; }" - ) - self.image_label.setScaledContents(False) - - scroll_area.setWidget(self.image_label) - display_layout.addWidget(scroll_area) + # Use the reusable ImageDisplayWidget + self.image_display_widget = ImageDisplayWidget() + self.image_display_widget.zoom_changed.connect(self._on_zoom_changed) + display_layout.addWidget(self.image_display_widget) display_group.setLayout(display_layout) layout.addWidget(display_group) @@ -101,13 +90,26 @@ class AnnotationTab(QWidget): info_group.setLayout(info_layout) layout.addWidget(info_group) + + # Zoom controls info + zoom_info = QLabel("Zoom: Mouse wheel or +/- keys to zoom in/out") + 
zoom_info.setStyleSheet("QLabel { color: #888; font-style: italic; }") + layout.addWidget(zoom_info) + self.setLayout(layout) def _load_image(self): """Load and display an image file.""" - # Get image repository path or use home directory - repo_path = self.config_manager.get_image_repository_path() - start_dir = repo_path if repo_path else str(Path.home()) + # Get last opened directory from QSettings + settings = QSettings("microscopy_app", "object_detection") + last_dir = settings.value("annotation_tab/last_directory", None) + + # Fallback to image repository path or home directory + if last_dir and Path(last_dir).exists(): + start_dir = last_dir + else: + repo_path = self.config_manager.get_image_repository_path() + start_dir = repo_path if repo_path else str(Path.home()) # Open file dialog file_path, _ = QFileDialog.getOpenFileName( @@ -125,18 +127,16 @@ class AnnotationTab(QWidget): self.current_image = Image(file_path) self.current_image_path = file_path - # Update info label - info_text = ( - f"File: {Path(file_path).name}\n" - f"Size: {self.current_image.width}x{self.current_image.height} pixels\n" - f"Channels: {self.current_image.channels}\n" - f"Format: {self.current_image.format.upper()}\n" - f"File size: {self.current_image.size_mb:.2f} MB" + # Store the directory for next time + settings.setValue( + "annotation_tab/last_directory", str(Path(file_path).parent) ) - self.image_info_label.setText(info_text) - # Convert to QPixmap and display - self._display_image() + # Display image using the ImageDisplayWidget + self.image_display_widget.load_image(self.current_image) + + # Update info label + self._update_image_info() logger.info(f"Loaded image: {file_path}") @@ -149,59 +149,27 @@ class AnnotationTab(QWidget): logger.error(f"Unexpected error loading image: {e}") QMessageBox.critical(self, "Error", f"Unexpected error:\n{str(e)}") - def _display_image(self): - """Display the current image in the image label.""" + def _update_image_info(self): + """Update 
the image info label with current image details.""" if self.current_image is None: + self.image_info_label.setText("No image loaded") return - try: - # Get RGB image data - rgb_data = self.current_image.get_rgb() + zoom_percentage = self.image_display_widget.get_zoom_percentage() + info_text = ( + f"File: {Path(self.current_image_path).name}\n" + f"Size: {self.current_image.width}x{self.current_image.height} pixels\n" + f"Channels: {self.current_image.channels}\n" + f"Data type: {self.current_image.dtype}\n" + f"Format: {self.current_image.format.upper()}\n" + f"File size: {self.current_image.size_mb:.2f} MB\n" + f"Zoom: {zoom_percentage}%" + ) + self.image_info_label.setText(info_text) - # Convert numpy array to QImage - height, width, channels = rgb_data.shape - bytes_per_line = channels * width - - if channels == 3: - qimage = QImage( - rgb_data.data, - width, - height, - bytes_per_line, - QImage.Format_RGB888, - ) - else: - # Grayscale - qimage = QImage( - rgb_data.data, - width, - height, - bytes_per_line, - QImage.Format_Grayscale8, - ) - - # Convert to pixmap - pixmap = QPixmap.fromImage(qimage) - - # Scale to fit display (max 800px width or height) - max_size = 800 - if pixmap.width() > max_size or pixmap.height() > max_size: - pixmap = pixmap.scaled( - max_size, - max_size, - Qt.KeepAspectRatio, - Qt.SmoothTransformation, - ) - - # Display in label - self.image_label.setPixmap(pixmap) - self.image_label.setScaledContents(False) - - except Exception as e: - logger.error(f"Error displaying image: {e}") - QMessageBox.warning( - self, "Display Error", f"Failed to display image:\n{str(e)}" - ) + def _on_zoom_changed(self, zoom_scale: float): + """Handle zoom level changes from the image display widget.""" + self._update_image_info() def refresh(self): """Refresh the tab.""" diff --git a/src/gui/widgets/__init__.py b/src/gui/widgets/__init__.py index e69de29..2946406 100644 --- a/src/gui/widgets/__init__.py +++ b/src/gui/widgets/__init__.py @@ -0,0 +1,5 @@ 
class ImageDisplayWidget(QWidget):
    """
    Reusable widget for displaying images with zoom functionality.

    Features:
    - Display images from Image objects
    - Zoom in/out with mouse wheel
    - Zoom in/out with +/- keyboard keys
    - Reset zoom with Ctrl+0
    - Scroll area for large images

    Signals:
        zoom_changed: Emitted when zoom level changes (float zoom_scale)
    """

    zoom_changed = Signal(float)  # Emitted when zoom level changes

    def __init__(self, parent=None):
        """
        Initialize the image display widget.

        Args:
            parent: Parent widget
        """
        super().__init__(parent)

        self.current_image = None
        self.original_pixmap = None  # Store original pixmap for zoom
        self.zoom_scale = 1.0  # Current zoom scale
        self.zoom_min = 0.1  # Minimum zoom (10%)
        self.zoom_max = 10.0  # Maximum zoom (1000%)
        self.zoom_step = 0.1  # Zoom step for +/- keys
        self.zoom_wheel_step = 0.15  # Zoom step for mouse wheel

        self._setup_ui()

    def _setup_ui(self):
        """Build the scroll area and the label that shows the pixmap."""
        layout = QVBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)

        # Scroll area for image
        self.scroll_area = QScrollArea()
        self.scroll_area.setWidgetResizable(True)
        self.scroll_area.setMinimumHeight(400)

        self.image_label = QLabel("No image loaded")
        self.image_label.setAlignment(Qt.AlignCenter)
        self.image_label.setStyleSheet(
            "QLabel { background-color: #2b2b2b; color: #888; }"
        )
        self.image_label.setScaledContents(False)

        # Enable mouse tracking for wheel events
        self.image_label.setMouseTracking(True)
        self.scroll_area.setWidget(self.image_label)

        # Install event filter to capture wheel events on scroll area
        self.scroll_area.viewport().installEventFilter(self)

        layout.addWidget(self.scroll_area)
        self.setLayout(layout)

        # Set focus policy to receive keyboard events
        self.setFocusPolicy(Qt.StrongFocus)

    def load_image(self, image: Image):
        """
        Load and display an image.

        Args:
            image: Image object to display

        Raises:
            ImageLoadError: If image cannot be displayed
        """
        self.current_image = image

        # Reset zoom when loading new image
        self.zoom_scale = 1.0

        # Convert to QPixmap and display
        self._display_image()

        logger.debug(f"Loaded image into display widget: {image.width}x{image.height}")

    def clear(self):
        """Clear the displayed image."""
        self.current_image = None
        self.original_pixmap = None
        self.zoom_scale = 1.0
        self.image_label.setText("No image loaded")
        self.image_label.setPixmap(QPixmap())
        logger.debug("Cleared image display")

    def _display_image(self):
        """
        Convert the current image to a pixmap and show it at the current zoom.

        Raises:
            ImageLoadError: If the conversion or display fails
        """
        if self.current_image is None:
            return

        try:
            # Colour images (3 or 4 channels) go through get_rgb(), which
            # yields RGB or RGBA; this matches the RGB888 / RGBA8888 formats
            # that qtimage_format reports for those channel counts.
            # Everything else is shown as grayscale.
            # NOTE(review): qtimage_format picks 8-bit formats for colour
            # images regardless of dtype - confirm how 16-bit colour data
            # should be displayed.
            if self.current_image.channels in (3, 4):
                image_data = self.current_image.get_rgb()
            else:
                image_data = self.current_image.get_grayscale()
            height, width = image_data.shape[:2]

            # Ensure data is contiguous for proper QImage display
            image_data = np.ascontiguousarray(image_data)

            # Use actual stride from numpy array for correct display
            bytes_per_line = image_data.strides[0]

            qimage = QImage(
                image_data.data,
                width,
                height,
                bytes_per_line,
                self.current_image.qtimage_format,
            )

            # Store original pixmap for zooming
            self.original_pixmap = QPixmap.fromImage(qimage)

            # Apply zoom and display
            self._apply_zoom()

        except Exception as e:
            logger.error(f"Error displaying image: {e}")
            raise ImageLoadError(f"Failed to display image: {str(e)}") from e

    def _apply_zoom(self):
        """Scale the original pixmap by ``zoom_scale``, show it and emit ``zoom_changed``."""
        if self.original_pixmap is None:
            return

        # Never request a zero-sized pixmap, which would render nothing
        scaled_width = max(1, int(self.original_pixmap.width() * self.zoom_scale))
        scaled_height = max(1, int(self.original_pixmap.height() * self.zoom_scale))

        # Scale pixmap
        scaled_pixmap = self.original_pixmap.scaled(
            scaled_width,
            scaled_height,
            Qt.KeepAspectRatio,
            (
                Qt.SmoothTransformation
                if self.zoom_scale >= 1.0
                else Qt.FastTransformation
            ),
        )

        # Display in label
        self.image_label.setPixmap(scaled_pixmap)
        self.image_label.setScaledContents(False)
        self.image_label.adjustSize()

        # Emit zoom changed signal
        self.zoom_changed.emit(self.zoom_scale)

    def _step_zoom(self, delta: float) -> bool:
        """
        Change the zoom by *delta*, clamped to ``[zoom_min, zoom_max]``.

        The target is rounded so repeated 0.1 / 0.15 steps do not drift just
        past a limit and make it unreachable.

        Args:
            delta: Signed change in zoom scale

        Returns:
            True if the zoom level actually changed
        """
        if self.original_pixmap is None:
            return False

        target = round(self.zoom_scale + delta, 6)
        target = max(self.zoom_min, min(self.zoom_max, target))
        if target == self.zoom_scale:
            return False

        self.zoom_scale = target
        self._apply_zoom()
        return True

    def zoom_in(self):
        """Zoom in on the image by one keyboard step."""
        if self._step_zoom(self.zoom_step):
            logger.debug(f"Zoomed in to {int(self.zoom_scale * 100)}%")

    def zoom_out(self):
        """Zoom out from the image by one keyboard step."""
        if self._step_zoom(-self.zoom_step):
            logger.debug(f"Zoomed out to {int(self.zoom_scale * 100)}%")

    def reset_zoom(self):
        """Reset zoom to 100%."""
        if self.original_pixmap is None:
            return

        self.zoom_scale = 1.0
        self._apply_zoom()
        logger.debug("Reset zoom to 100%")

    def set_zoom(self, scale: float):
        """
        Set zoom to a specific scale.

        Args:
            scale: Zoom scale (1.0 = 100%); values outside the allowed range
                are clamped
        """
        if self.original_pixmap is None:
            return

        # Clamp to min/max
        scale = max(self.zoom_min, min(self.zoom_max, scale))

        self.zoom_scale = scale
        self._apply_zoom()
        logger.debug(f"Set zoom to {int(self.zoom_scale * 100)}%")

    def get_zoom_percentage(self) -> int:
        """
        Get current zoom level as percentage.

        Returns:
            Zoom level as integer percentage (e.g., 100 for 100%)
        """
        return int(self.zoom_scale * 100)

    def keyPressEvent(self, event: QKeyEvent):
        """Handle keyboard events for zooming."""
        if event.key() in (Qt.Key_Plus, Qt.Key_Equal):
            # + or = key (= is the unshifted + on many keyboards)
            self.zoom_in()
            event.accept()
        elif event.key() == Qt.Key_Minus:
            # - key
            self.zoom_out()
            event.accept()
        elif event.key() == Qt.Key_0 and event.modifiers() == Qt.ControlModifier:
            # Ctrl+0 to reset zoom
            self.reset_zoom()
            event.accept()
        else:
            super().keyPressEvent(event)

    def eventFilter(self, obj, event: QEvent) -> bool:
        """Event filter to capture wheel events for zooming."""
        if event.type() == QEvent.Wheel:
            delta = event.angleDelta().y()

            # No vertical component (e.g. a horizontal scroll): not a zoom
            # request, so leave it to the default handling.
            if delta == 0:
                return super().eventFilter(obj, event)

            # Scroll up = zoom in, scroll down = zoom out
            self._step_zoom(
                self.zoom_wheel_step if delta > 0 else -self.zoom_wheel_step
            )
            return True  # Event handled

        return super().eventFilter(obj, event)
self.path.suffix.lower().lstrip(".") self._size_bytes = self.path.stat().st_size + self._dtype = self._data.dtype # Load PIL version for compatibility (convert BGR to RGB) if self._channels == 3: @@ -157,6 +161,7 @@ class Image: Returns: Tuple of (height, width, channels) """ + print("shape", self._height, self._width, self._channels) return (self._height, self._width, self._channels) @property @@ -179,6 +184,33 @@ class Image: """Get file size in megabytes.""" return self._size_bytes / (1024 * 1024) + @property + def dtype(self) -> np.dtype: + """Get the data type of the image array.""" + if self._dtype is None: + raise ImageLoadError("Image dtype not available") + return self._dtype + + @property + def qtimage_format(self) -> QImage.Format: + """ + Get the appropriate QImage format for the image. + + Returns: + QImage.Format enum value + """ + if self._channels == 3: + return QImage.Format_RGB888 + elif self._channels == 4: + return QImage.Format_RGBA8888 + elif self._channels == 1: + if self._dtype == np.uint16: + return QImage.Format_Grayscale16 + else: + return QImage.Format_Grayscale8 + else: + raise ImageLoadError(f"Unsupported number of channels: {self._channels}") + def get_rgb(self) -> np.ndarray: """ Get image data as RGB numpy array. 
-- 2.49.1 From f84dea0bff0327128f36270b973a23b030a2c29d Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Mon, 8 Dec 2025 22:40:07 +0200 Subject: [PATCH 05/13] Adding splitter and saving layout state when closing the app --- src/gui/main_window.py | 31 ++++++++- src/gui/tabs/annotation_tab.py | 117 ++++++++++++++++++++++++++------- 2 files changed, 121 insertions(+), 27 deletions(-) diff --git a/src/gui/main_window.py b/src/gui/main_window.py index 71e322b..99eefa9 100644 --- a/src/gui/main_window.py +++ b/src/gui/main_window.py @@ -13,7 +13,7 @@ from PySide6.QtWidgets import ( QVBoxLayout, QLabel, ) -from PySide6.QtCore import Qt, QTimer +from PySide6.QtCore import Qt, QTimer, QSettings from PySide6.QtGui import QAction, QKeySequence from src.database.db_manager import DatabaseManager @@ -52,8 +52,8 @@ class MainWindow(QMainWindow): self._create_tab_widget() self._create_status_bar() - # Center window on screen - self._center_window() + # Restore window geometry or center window on screen + self._restore_window_state() logger.info("Main window initialized") @@ -156,6 +156,24 @@ class MainWindow(QMainWindow): (screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2 ) + def _restore_window_state(self): + """Restore window geometry from settings or center window.""" + settings = QSettings("microscopy_app", "object_detection") + geometry = settings.value("main_window/geometry") + + if geometry: + self.restoreGeometry(geometry) + logger.debug("Restored window geometry from settings") + else: + self._center_window() + logger.debug("Centered window on screen") + + def _save_window_state(self): + """Save window geometry to settings.""" + settings = QSettings("microscopy_app", "object_detection") + settings.setValue("main_window/geometry", self.saveGeometry()) + logger.debug("Saved window geometry to settings") + def _show_settings(self): """Show settings dialog.""" logger.info("Opening settings dialog") @@ -276,6 +294,13 @@ class 
MainWindow(QMainWindow): ) if reply == QMessageBox.Yes: + # Save window state before closing + self._save_window_state() + + # Save annotation tab state if it exists + if hasattr(self, "annotation_tab"): + self.annotation_tab.save_state() + logger.info("Application closing") event.accept() else: diff --git a/src/gui/tabs/annotation_tab.py b/src/gui/tabs/annotation_tab.py index 6aa2462..a373ad0 100644 --- a/src/gui/tabs/annotation_tab.py +++ b/src/gui/tabs/annotation_tab.py @@ -12,6 +12,7 @@ from PySide6.QtWidgets import ( QPushButton, QFileDialog, QMessageBox, + QSplitter, ) from PySide6.QtCore import Qt, QSettings from pathlib import Path @@ -43,25 +44,13 @@ class AnnotationTab(QWidget): """Setup user interface.""" layout = QVBoxLayout() - # Image loading section - load_group = QGroupBox("Image Loading") - load_layout = QVBoxLayout() + # Main horizontal splitter to divide left (image) and right (controls) + self.main_splitter = QSplitter(Qt.Horizontal) + self.main_splitter.setHandleWidth(10) - # Load image button - button_layout = QHBoxLayout() - self.load_image_btn = QPushButton("Load Image") - self.load_image_btn.clicked.connect(self._load_image) - button_layout.addWidget(self.load_image_btn) - button_layout.addStretch() - - load_layout.addLayout(button_layout) - - # Image info label - self.image_info_label = QLabel("No image loaded") - load_layout.addWidget(self.image_info_label) - - load_group.setLayout(load_layout) - layout.addWidget(load_group) + # { Left splitter for image display and zoom info + self.left_splitter = QSplitter(Qt.Vertical) + self.left_splitter.setHandleWidth(10) # Image display section display_group = QGroupBox("Image Display") @@ -73,7 +62,21 @@ class AnnotationTab(QWidget): display_layout.addWidget(self.image_display_widget) display_group.setLayout(display_layout) - layout.addWidget(display_group) + self.left_splitter.addWidget(display_group) + + # Zoom controls info + zoom_info = QLabel("Zoom: Mouse wheel or +/- keys to zoom in/out") + 
zoom_info.setStyleSheet("QLabel { color: #888; font-style: italic; }") + self.left_splitter.addWidget(zoom_info) + # } + + # { Right splitter for annotation tools and controls + self.right_splitter = QSplitter(Qt.Vertical) + self.right_splitter.setHandleWidth(10) + + # Image loading section + load_group = QGroupBox("Image Loading") + load_layout = QVBoxLayout() # Future features info info_group = QGroupBox("Annotation Tool (Future Feature)") @@ -86,18 +89,41 @@ class AnnotationTab(QWidget): "- Export annotations to YOLO format\n" "- Annotation verification" ) + info_label.setWordWrap(True) info_layout.addWidget(info_label) info_group.setLayout(info_layout) - layout.addWidget(info_group) + self.right_splitter.addWidget(info_group) - # Zoom controls info - zoom_info = QLabel("Zoom: Mouse wheel or +/- keys to zoom in/out") - zoom_info.setStyleSheet("QLabel { color: #888; font-style: italic; }") - layout.addWidget(zoom_info) + # Load image button + button_layout = QHBoxLayout() + self.load_image_btn = QPushButton("Load Image") + self.load_image_btn.clicked.connect(self._load_image) + button_layout.addWidget(self.load_image_btn) + button_layout.addStretch() + load_layout.addLayout(button_layout) + # Image info label + self.image_info_label = QLabel("No image loaded") + load_layout.addWidget(self.image_info_label) + + load_group.setLayout(load_layout) + self.right_splitter.addWidget(load_group) + # } + + # Add both splitters to the main horizontal splitter + self.main_splitter.addWidget(self.left_splitter) + self.main_splitter.addWidget(self.right_splitter) + + # Set initial sizes: 75% for left (image), 25% for right (controls) + self.main_splitter.setSizes([750, 250]) + + layout.addWidget(self.main_splitter) self.setLayout(layout) + # Restore splitter positions from settings + self._restore_state() + def _load_image(self): """Load and display an image file.""" # Get last opened directory from QSettings @@ -171,6 +197,49 @@ class AnnotationTab(QWidget): """Handle zoom 
level changes from the image display widget.""" self._update_image_info() + def _restore_state(self): + """Restore splitter positions from settings.""" + settings = QSettings("microscopy_app", "object_detection") + + # Restore main splitter state + main_state = settings.value("annotation_tab/main_splitter_state") + if main_state: + self.main_splitter.restoreState(main_state) + logger.debug("Restored main splitter state") + + # Restore left splitter state + left_state = settings.value("annotation_tab/left_splitter_state") + if left_state: + self.left_splitter.restoreState(left_state) + logger.debug("Restored left splitter state") + + # Restore right splitter state + right_state = settings.value("annotation_tab/right_splitter_state") + if right_state: + self.right_splitter.restoreState(right_state) + logger.debug("Restored right splitter state") + + def save_state(self): + """Save splitter positions to settings.""" + settings = QSettings("microscopy_app", "object_detection") + + # Save main splitter state + settings.setValue( + "annotation_tab/main_splitter_state", self.main_splitter.saveState() + ) + + # Save left splitter state + settings.setValue( + "annotation_tab/left_splitter_state", self.left_splitter.saveState() + ) + + # Save right splitter state + settings.setValue( + "annotation_tab/right_splitter_state", self.right_splitter.saveState() + ) + + logger.debug("Saved annotation tab splitter states") + def refresh(self): """Refresh the tab.""" pass -- 2.49.1 From fc22479621e5802967dabe24d8d0eabe7c71aa1d Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Mon, 8 Dec 2025 23:15:54 +0200 Subject: [PATCH 06/13] Adding pen tool for annotation --- src/database/db_manager.py | 228 ++++++++++- src/database/schema.sql | 27 +- src/gui/tabs/annotation_tab.py | 116 ++++-- src/gui/widgets/__init__.py | 4 +- src/gui/widgets/annotation_canvas_widget.py | 406 ++++++++++++++++++++ src/gui/widgets/annotation_tools_widget.py | 352 +++++++++++++++++ 6 files changed, 1079 
insertions(+), 54 deletions(-) create mode 100644 src/gui/widgets/annotation_canvas_widget.py create mode 100644 src/gui/widgets/annotation_tools_widget.py diff --git a/src/database/db_manager.py b/src/database/db_manager.py index 5dcf5c2..db2da9e 100644 --- a/src/database/db_manager.py +++ b/src/database/db_manager.py @@ -30,18 +30,48 @@ class DatabaseManager: # Create directory if it doesn't exist Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) - # Read schema file and execute - schema_path = Path(__file__).parent / "schema.sql" - with open(schema_path, "r") as f: - schema_sql = f.read() - conn = self.get_connection() try: + # Check if annotations table needs migration + self._migrate_annotations_table(conn) + + # Read schema file and execute + schema_path = Path(__file__).parent / "schema.sql" + with open(schema_path, "r") as f: + schema_sql = f.read() + conn.executescript(schema_sql) conn.commit() finally: conn.close() + def _migrate_annotations_table(self, conn: sqlite3.Connection) -> None: + """ + Migrate annotations table from old schema (class_name) to new schema (class_id). 
+ """ + cursor = conn.cursor() + + # Check if annotations table exists + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'" + ) + if not cursor.fetchone(): + # Table doesn't exist yet, no migration needed + return + + # Check if table has old schema (class_name column) + cursor.execute("PRAGMA table_info(annotations)") + columns = {row[1]: row for row in cursor.fetchall()} + + if "class_name" in columns and "class_id" not in columns: + # Old schema detected, need to migrate + print("Migrating annotations table to new schema with class_id...") + + # Drop old annotations table (assuming no critical data since this is a new feature) + cursor.execute("DROP TABLE IF EXISTS annotations") + conn.commit() + print("Old annotations table dropped, will be recreated with new schema") + def get_connection(self) -> sqlite3.Connection: """Get database connection with proper settings.""" conn = sqlite3.connect(self.db_path) @@ -593,25 +623,38 @@ class DatabaseManager: def add_annotation( self, image_id: int, - class_name: str, + class_id: int, bbox: Tuple[float, float, float, float], annotator: str, segmentation_mask: Optional[List[List[float]]] = None, verified: bool = False, ) -> int: - """Add manual annotation.""" + """ + Add manual annotation. 
+ + Args: + image_id: ID of the image + class_id: ID of the object class (foreign key to object_classes) + bbox: Bounding box coordinates (normalized 0-1) + annotator: Name of person/tool creating annotation + segmentation_mask: Polygon coordinates for segmentation + verified: Whether annotation has been verified + + Returns: + ID of the inserted annotation + """ conn = self.get_connection() try: cursor = conn.cursor() x_min, y_min, x_max, y_max = bbox cursor.execute( """ - INSERT INTO annotations (image_id, class_name, x_min, y_min, x_max, y_max, segmentation_mask, annotator, verified) + INSERT INTO annotations (image_id, class_id, x_min, y_min, x_max, y_max, segmentation_mask, annotator, verified) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( image_id, - class_name, + class_id, x_min, y_min, x_max, @@ -627,15 +670,178 @@ class DatabaseManager: conn.close() def get_annotations_for_image(self, image_id: int) -> List[Dict]: - """Get all annotations for an image.""" + """ + Get all annotations for an image with class information. + + Args: + image_id: ID of the image + + Returns: + List of annotation dictionaries with joined class information + """ conn = self.get_connection() try: cursor = conn.cursor() - cursor.execute("SELECT * FROM annotations WHERE image_id = ?", (image_id,)) + cursor.execute( + """ + SELECT + a.*, + c.class_name, + c.color as class_color, + c.description as class_description + FROM annotations a + JOIN object_classes c ON a.class_id = c.id + WHERE a.image_id = ? + ORDER BY a.created_at DESC + """, + (image_id,), + ) + annotations = [] + for row in cursor.fetchall(): + ann = dict(row) + if ann.get("segmentation_mask"): + ann["segmentation_mask"] = json.loads(ann["segmentation_mask"]) + annotations.append(ann) + return annotations + finally: + conn.close() + + # ==================== Object Class Operations ==================== + + def get_object_classes(self) -> List[Dict]: + """ + Get all object classes. 
+ + Returns: + List of object class dictionaries + """ + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute("SELECT * FROM object_classes ORDER BY class_name") return [dict(row) for row in cursor.fetchall()] finally: conn.close() + def get_object_class_by_id(self, class_id: int) -> Optional[Dict]: + """Get object class by ID.""" + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute("SELECT * FROM object_classes WHERE id = ?", (class_id,)) + row = cursor.fetchone() + return dict(row) if row else None + finally: + conn.close() + + def get_object_class_by_name(self, class_name: str) -> Optional[Dict]: + """Get object class by name.""" + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM object_classes WHERE class_name = ?", (class_name,) + ) + row = cursor.fetchone() + return dict(row) if row else None + finally: + conn.close() + + def add_object_class( + self, class_name: str, color: str, description: Optional[str] = None + ) -> int: + """ + Add a new object class. + + Args: + class_name: Name of the object class + color: Hex color code (e.g., '#FF0000') + description: Optional description + + Returns: + ID of the inserted object class + """ + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO object_classes (class_name, color, description) + VALUES (?, ?, ?) + """, + (class_name, color, description), + ) + conn.commit() + return cursor.lastrowid + except sqlite3.IntegrityError: + # Class already exists + existing = self.get_object_class_by_name(class_name) + return existing["id"] if existing else None + finally: + conn.close() + + def update_object_class( + self, + class_id: int, + class_name: Optional[str] = None, + color: Optional[str] = None, + description: Optional[str] = None, + ) -> bool: + """ + Update an object class. 
+ + Args: + class_id: ID of the class to update + class_name: New class name (optional) + color: New color (optional) + description: New description (optional) + + Returns: + True if updated, False otherwise + """ + conn = self.get_connection() + try: + updates = {} + if class_name is not None: + updates["class_name"] = class_name + if color is not None: + updates["color"] = color + if description is not None: + updates["description"] = description + + if not updates: + return False + + set_clauses = [f"{key} = ?" for key in updates.keys()] + params = list(updates.values()) + [class_id] + + query = f"UPDATE object_classes SET {', '.join(set_clauses)} WHERE id = ?" + cursor = conn.cursor() + cursor.execute(query, params) + conn.commit() + return cursor.rowcount > 0 + finally: + conn.close() + + def delete_object_class(self, class_id: int) -> bool: + """ + Delete an object class. + + Args: + class_id: ID of the class to delete + + Returns: + True if deleted, False otherwise + """ + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute("DELETE FROM object_classes WHERE id = ?", (class_id,)) + conn.commit() + return cursor.rowcount > 0 + finally: + conn.close() + @staticmethod def calculate_checksum(file_path: str) -> str: """Calculate MD5 checksum of a file.""" diff --git a/src/database/schema.sql b/src/database/schema.sql index b09ffee..64123eb 100644 --- a/src/database/schema.sql +++ b/src/database/schema.sql @@ -44,11 +44,27 @@ CREATE TABLE IF NOT EXISTS detections ( FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE ); --- Annotations table: stores manual annotations (future feature) +-- Object classes table: stores annotation class definitions with colors +CREATE TABLE IF NOT EXISTS object_classes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + class_name TEXT NOT NULL UNIQUE, + color TEXT NOT NULL, -- Hex color code (e.g., '#FF0000') + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + description TEXT +); + +-- Insert default 
object classes +INSERT OR IGNORE INTO object_classes (class_name, color, description) VALUES + ('cell', '#FF0000', 'Cell object'), + ('nucleus', '#00FF00', 'Cell nucleus'), + ('mitochondria', '#0000FF', 'Mitochondria'), + ('vesicle', '#FFFF00', 'Vesicle'); + +-- Annotations table: stores manual annotations CREATE TABLE IF NOT EXISTS annotations ( id INTEGER PRIMARY KEY AUTOINCREMENT, image_id INTEGER NOT NULL, - class_name TEXT NOT NULL, + class_id INTEGER NOT NULL, x_min REAL NOT NULL CHECK(x_min >= 0 AND x_min <= 1), y_min REAL NOT NULL CHECK(y_min >= 0 AND y_min <= 1), x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1), @@ -57,7 +73,8 @@ CREATE TABLE IF NOT EXISTS annotations ( annotator TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, verified BOOLEAN DEFAULT 0, - FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE + FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE, + FOREIGN KEY (class_id) REFERENCES object_classes (id) ON DELETE CASCADE ); -- Create indexes for performance optimization @@ -69,4 +86,6 @@ CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence); CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path); CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at); CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id); -CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at); \ No newline at end of file +CREATE INDEX IF NOT EXISTS idx_annotations_class_id ON annotations(class_id); +CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at); +CREATE INDEX IF NOT EXISTS idx_object_classes_class_name ON object_classes(class_name); \ No newline at end of file diff --git a/src/gui/tabs/annotation_tab.py b/src/gui/tabs/annotation_tab.py index a373ad0..7933caf 100644 --- a/src/gui/tabs/annotation_tab.py +++ b/src/gui/tabs/annotation_tab.py @@ -1,6 +1,6 @@ """ Annotation tab for the microscopy object detection application. 
-Future feature for manual annotation. +Manual annotation with pen tool and object class management. """ from PySide6.QtWidgets import ( @@ -21,13 +21,13 @@ from src.database.db_manager import DatabaseManager from src.utils.config_manager import ConfigManager from src.utils.image import Image, ImageLoadError from src.utils.logger import get_logger -from src.gui.widgets import ImageDisplayWidget +from src.gui.widgets import AnnotationCanvasWidget, AnnotationToolsWidget logger = get_logger(__name__) class AnnotationTab(QWidget): - """Annotation tab placeholder (future feature).""" + """Annotation tab for manual image annotation.""" def __init__( self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None @@ -37,6 +37,7 @@ class AnnotationTab(QWidget): self.config_manager = config_manager self.current_image = None self.current_image_path = None + self.current_image_id = None self._setup_ui() @@ -52,49 +53,52 @@ class AnnotationTab(QWidget): self.left_splitter = QSplitter(Qt.Vertical) self.left_splitter.setHandleWidth(10) - # Image display section - display_group = QGroupBox("Image Display") - display_layout = QVBoxLayout() + # Annotation canvas section + canvas_group = QGroupBox("Annotation Canvas") + canvas_layout = QVBoxLayout() - # Use the reusable ImageDisplayWidget - self.image_display_widget = ImageDisplayWidget() - self.image_display_widget.zoom_changed.connect(self._on_zoom_changed) - display_layout.addWidget(self.image_display_widget) + # Use the AnnotationCanvasWidget + self.annotation_canvas = AnnotationCanvasWidget() + self.annotation_canvas.zoom_changed.connect(self._on_zoom_changed) + self.annotation_canvas.annotation_drawn.connect(self._on_annotation_drawn) + canvas_layout.addWidget(self.annotation_canvas) - display_group.setLayout(display_layout) - self.left_splitter.addWidget(display_group) + canvas_group.setLayout(canvas_layout) + self.left_splitter.addWidget(canvas_group) - # Zoom controls info - zoom_info = QLabel("Zoom: Mouse 
wheel or +/- keys to zoom in/out") - zoom_info.setStyleSheet("QLabel { color: #888; font-style: italic; }") - self.left_splitter.addWidget(zoom_info) + # Controls info + controls_info = QLabel( + "Zoom: Mouse wheel or +/- keys | Drawing: Enable pen and drag mouse" + ) + controls_info.setStyleSheet("QLabel { color: #888; font-style: italic; }") + self.left_splitter.addWidget(controls_info) # } # { Right splitter for annotation tools and controls self.right_splitter = QSplitter(Qt.Vertical) self.right_splitter.setHandleWidth(10) + # Annotation tools section + self.annotation_tools = AnnotationToolsWidget(self.db_manager) + self.annotation_tools.pen_enabled_changed.connect( + self.annotation_canvas.set_pen_enabled + ) + self.annotation_tools.pen_color_changed.connect( + self.annotation_canvas.set_pen_color + ) + self.annotation_tools.pen_width_changed.connect( + self.annotation_canvas.set_pen_width + ) + self.annotation_tools.class_selected.connect(self._on_class_selected) + self.annotation_tools.clear_annotations_requested.connect( + self._on_clear_annotations + ) + self.right_splitter.addWidget(self.annotation_tools) + # Image loading section load_group = QGroupBox("Image Loading") load_layout = QVBoxLayout() - # Future features info - info_group = QGroupBox("Annotation Tool (Future Feature)") - info_layout = QVBoxLayout() - info_label = QLabel( - "Full annotation functionality will be implemented in future version.\n\n" - "Planned Features:\n" - "- Drawing tools for bounding boxes\n" - "- Class label assignment\n" - "- Export annotations to YOLO format\n" - "- Annotation verification" - ) - info_label.setWordWrap(True) - info_layout.addWidget(info_label) - info_group.setLayout(info_layout) - - self.right_splitter.addWidget(info_group) - # Load image button button_layout = QHBoxLayout() self.load_image_btn = QPushButton("Load Image") @@ -158,13 +162,22 @@ class AnnotationTab(QWidget): "annotation_tab/last_directory", str(Path(file_path).parent) ) - # Display image 
using the ImageDisplayWidget - self.image_display_widget.load_image(self.current_image) + # Get or create image in database + relative_path = str(Path(file_path).name) # Simplified for now + self.current_image_id = self.db_manager.get_or_create_image( + relative_path, + Path(file_path).name, + self.current_image.width, + self.current_image.height, + ) + + # Display image using the AnnotationCanvasWidget + self.annotation_canvas.load_image(self.current_image) # Update info label self._update_image_info() - logger.info(f"Loaded image: {file_path}") + logger.info(f"Loaded image: {file_path} (DB ID: {self.current_image_id})") except ImageLoadError as e: logger.error(f"Failed to load image: {e}") @@ -181,7 +194,7 @@ class AnnotationTab(QWidget): self.image_info_label.setText("No image loaded") return - zoom_percentage = self.image_display_widget.get_zoom_percentage() + zoom_percentage = self.annotation_canvas.get_zoom_percentage() info_text = ( f"File: {Path(self.current_image_path).name}\n" f"Size: {self.current_image.width}x{self.current_image.height} pixels\n" @@ -194,9 +207,36 @@ class AnnotationTab(QWidget): self.image_info_label.setText(info_text) def _on_zoom_changed(self, zoom_scale: float): - """Handle zoom level changes from the image display widget.""" + """Handle zoom level changes from the annotation canvas.""" self._update_image_info() + def _on_annotation_drawn(self, points: list): + """Handle when an annotation stroke is drawn.""" + current_class = self.annotation_tools.get_current_class() + + if not current_class: + logger.warning("Annotation drawn but no object class selected") + QMessageBox.warning( + self, + "No Class Selected", + "Please select an object class before drawing annotations.", + ) + return + + logger.info( + f"Annotation drawn with {len(points)} points for class: {current_class['class_name']}" + ) + # Future: Save annotation to database or export + + def _on_class_selected(self, class_data: dict): + """Handle when an object class is 
selected.""" + logger.debug(f"Object class selected: {class_data['class_name']}") + + def _on_clear_annotations(self): + """Handle clearing all annotations.""" + self.annotation_canvas.clear_annotations() + logger.info("Cleared all annotations") + def _restore_state(self): """Restore splitter positions from settings.""" settings = QSettings("microscopy_app", "object_detection") diff --git a/src/gui/widgets/__init__.py b/src/gui/widgets/__init__.py index 2946406..df8fad7 100644 --- a/src/gui/widgets/__init__.py +++ b/src/gui/widgets/__init__.py @@ -1,5 +1,7 @@ """GUI widgets for the microscopy object detection application.""" from src.gui.widgets.image_display_widget import ImageDisplayWidget +from src.gui.widgets.annotation_canvas_widget import AnnotationCanvasWidget +from src.gui.widgets.annotation_tools_widget import AnnotationToolsWidget -__all__ = ["ImageDisplayWidget"] +__all__ = ["ImageDisplayWidget", "AnnotationCanvasWidget", "AnnotationToolsWidget"] diff --git a/src/gui/widgets/annotation_canvas_widget.py b/src/gui/widgets/annotation_canvas_widget.py new file mode 100644 index 0000000..7ff2bc5 --- /dev/null +++ b/src/gui/widgets/annotation_canvas_widget.py @@ -0,0 +1,406 @@ +""" +Annotation canvas widget for drawing annotations on images. +Supports pen tool with color selection for manual annotation. +""" + +from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea +from PySide6.QtGui import ( + QPixmap, + QImage, + QPainter, + QPen, + QColor, + QKeyEvent, + QMouseEvent, + QPaintEvent, +) +from PySide6.QtCore import Qt, QEvent, Signal, QPoint +from typing import List, Optional, Tuple +import numpy as np + +from src.utils.image import Image, ImageLoadError +from src.utils.logger import get_logger + +logger = get_logger(__name__) + + +class AnnotationCanvasWidget(QWidget): + """ + Widget for displaying images and drawing annotations with pen tool. 
+ + Features: + - Display images with zoom functionality + - Pen tool for drawing annotations + - Configurable pen color and width + - Mouse-based drawing interface + - Zoom in/out with mouse wheel and keyboard + + Signals: + zoom_changed: Emitted when zoom level changes (float zoom_scale) + annotation_drawn: Emitted when a new stroke is completed (list of points) + """ + + zoom_changed = Signal(float) + annotation_drawn = Signal(list) # List of (x, y) points in normalized coordinates + + def __init__(self, parent=None): + """Initialize the annotation canvas widget.""" + super().__init__(parent) + + self.current_image = None + self.original_pixmap = None + self.annotation_pixmap = None # Overlay for annotations + self.zoom_scale = 1.0 + self.zoom_min = 0.1 + self.zoom_max = 10.0 + self.zoom_step = 0.1 + self.zoom_wheel_step = 0.15 + + # Drawing state + self.is_drawing = False + self.pen_enabled = False + self.pen_color = QColor(255, 0, 0, 128) # Default red with 50% alpha + self.pen_width = 3 + self.current_stroke = [] # Points in current stroke + self.all_strokes = [] # All completed strokes + + self._setup_ui() + + def _setup_ui(self): + """Setup user interface.""" + layout = QVBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + + # Scroll area for canvas + self.scroll_area = QScrollArea() + self.scroll_area.setWidgetResizable(True) + self.scroll_area.setMinimumHeight(400) + + self.canvas_label = QLabel("No image loaded") + self.canvas_label.setAlignment(Qt.AlignCenter) + self.canvas_label.setStyleSheet( + "QLabel { background-color: #2b2b2b; color: #888; }" + ) + self.canvas_label.setScaledContents(False) + self.canvas_label.setMouseTracking(True) + + self.scroll_area.setWidget(self.canvas_label) + self.scroll_area.viewport().installEventFilter(self) + + layout.addWidget(self.scroll_area) + self.setLayout(layout) + + self.setFocusPolicy(Qt.StrongFocus) + + def load_image(self, image: Image): + """ + Load and display an image. 
+ + Args: + image: Image object to display + """ + self.current_image = image + self.zoom_scale = 1.0 + self.clear_annotations() + self._display_image() + logger.debug( + f"Loaded image into annotation canvas: {image.width}x{image.height}" + ) + + def clear(self): + """Clear the displayed image and all annotations.""" + self.current_image = None + self.original_pixmap = None + self.annotation_pixmap = None + self.zoom_scale = 1.0 + self.clear_annotations() + self.canvas_label.setText("No image loaded") + self.canvas_label.setPixmap(QPixmap()) + + def clear_annotations(self): + """Clear all drawn annotations.""" + self.all_strokes = [] + self.current_stroke = [] + self.is_drawing = False + if self.annotation_pixmap: + self.annotation_pixmap.fill(Qt.transparent) + self._update_display() + + def _display_image(self): + """Display the current image in the canvas.""" + if self.current_image is None: + return + + try: + # Get RGB image data + if self.current_image.channels == 3: + image_data = self.current_image.get_rgb() + height, width, channels = image_data.shape + else: + image_data = self.current_image.get_grayscale() + height, width = image_data.shape + + image_data = np.ascontiguousarray(image_data) + bytes_per_line = image_data.strides[0] + + qimage = QImage( + image_data.data, + width, + height, + bytes_per_line, + self.current_image.qtimage_format, + ) + + self.original_pixmap = QPixmap.fromImage(qimage) + + # Create transparent overlay for annotations + self.annotation_pixmap = QPixmap(self.original_pixmap.size()) + self.annotation_pixmap.fill(Qt.transparent) + + self._apply_zoom() + + except Exception as e: + logger.error(f"Error displaying image: {e}") + raise ImageLoadError(f"Failed to display image: {str(e)}") + + def _apply_zoom(self): + """Apply current zoom level to the displayed image.""" + if self.original_pixmap is None: + return + + scaled_width = int(self.original_pixmap.width() * self.zoom_scale) + scaled_height = int(self.original_pixmap.height() 
* self.zoom_scale) + + # Scale both image and annotations + scaled_image = self.original_pixmap.scaled( + scaled_width, + scaled_height, + Qt.KeepAspectRatio, + ( + Qt.SmoothTransformation + if self.zoom_scale >= 1.0 + else Qt.FastTransformation + ), + ) + + scaled_annotations = self.annotation_pixmap.scaled( + scaled_width, + scaled_height, + Qt.KeepAspectRatio, + ( + Qt.SmoothTransformation + if self.zoom_scale >= 1.0 + else Qt.FastTransformation + ), + ) + + # Composite image and annotations + combined = QPixmap(scaled_image.size()) + painter = QPainter(combined) + painter.drawPixmap(0, 0, scaled_image) + painter.drawPixmap(0, 0, scaled_annotations) + painter.end() + + self.canvas_label.setPixmap(combined) + self.canvas_label.setScaledContents(False) + self.canvas_label.adjustSize() + + self.zoom_changed.emit(self.zoom_scale) + + def _update_display(self): + """Update display after drawing.""" + self._apply_zoom() + + def set_pen_enabled(self, enabled: bool): + """Enable or disable pen tool.""" + self.pen_enabled = enabled + if enabled: + self.canvas_label.setCursor(Qt.CrossCursor) + else: + self.canvas_label.setCursor(Qt.ArrowCursor) + + def set_pen_color(self, color: QColor): + """Set pen color.""" + self.pen_color = color + + def set_pen_width(self, width: int): + """Set pen width.""" + self.pen_width = max(1, width) + + def get_zoom_percentage(self) -> int: + """Get current zoom level as percentage.""" + return int(self.zoom_scale * 100) + + def zoom_in(self): + """Zoom in on the image.""" + if self.original_pixmap is None: + return + new_scale = self.zoom_scale + self.zoom_step + if new_scale <= self.zoom_max: + self.zoom_scale = new_scale + self._apply_zoom() + + def zoom_out(self): + """Zoom out from the image.""" + if self.original_pixmap is None: + return + new_scale = self.zoom_scale - self.zoom_step + if new_scale >= self.zoom_min: + self.zoom_scale = new_scale + self._apply_zoom() + + def reset_zoom(self): + """Reset zoom to 100%.""" + if 
self.original_pixmap is None: + return + self.zoom_scale = 1.0 + self._apply_zoom() + + def _canvas_to_image_coords(self, pos: QPoint) -> Optional[Tuple[int, int]]: + """Convert canvas coordinates to image coordinates, accounting for zoom and centering.""" + if self.original_pixmap is None or self.canvas_label.pixmap() is None: + return None + + # Get the displayed pixmap size (after zoom) + displayed_pixmap = self.canvas_label.pixmap() + displayed_width = displayed_pixmap.width() + displayed_height = displayed_pixmap.height() + + # Calculate offset due to label centering (label might be larger than pixmap) + label_width = self.canvas_label.width() + label_height = self.canvas_label.height() + offset_x = max(0, (label_width - displayed_width) // 2) + offset_y = max(0, (label_height - displayed_height) // 2) + + # Adjust position for offset and convert to image coordinates + x = (pos.x() - offset_x) / self.zoom_scale + y = (pos.y() - offset_y) / self.zoom_scale + + # Check bounds + if ( + 0 <= x < self.original_pixmap.width() + and 0 <= y < self.original_pixmap.height() + ): + return (int(x), int(y)) + return None + + def _image_to_normalized_coords(self, x: int, y: int) -> Tuple[float, float]: + """Convert image coordinates to normalized coordinates (0-1).""" + if self.original_pixmap is None: + return (0.0, 0.0) + + norm_x = x / self.original_pixmap.width() + norm_y = y / self.original_pixmap.height() + return (norm_x, norm_y) + + def mousePressEvent(self, event: QMouseEvent): + """Handle mouse press events for drawing.""" + if not self.pen_enabled or self.annotation_pixmap is None: + super().mousePressEvent(event) + return + + if event.button() == Qt.LeftButton: + # Get accurate position using global coordinates + label_pos = self.canvas_label.mapFromGlobal(event.globalPos()) + img_coords = self._canvas_to_image_coords(label_pos) + + if img_coords: + self.is_drawing = True + self.current_stroke = [img_coords] + + def mouseMoveEvent(self, event: QMouseEvent): + 
"""Handle mouse move events for drawing.""" + if ( + not self.is_drawing + or not self.pen_enabled + or self.annotation_pixmap is None + ): + super().mouseMoveEvent(event) + return + + # Get accurate position using global coordinates + label_pos = self.canvas_label.mapFromGlobal(event.globalPos()) + img_coords = self._canvas_to_image_coords(label_pos) + + if img_coords and len(self.current_stroke) > 0: + # Draw line from last point to current point + painter = QPainter(self.annotation_pixmap) + pen = QPen( + self.pen_color, self.pen_width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin + ) + painter.setPen(pen) + + last_point = self.current_stroke[-1] + painter.drawLine(last_point[0], last_point[1], img_coords[0], img_coords[1]) + painter.end() + + self.current_stroke.append(img_coords) + self._update_display() + + def mouseReleaseEvent(self, event: QMouseEvent): + """Handle mouse release events to complete a stroke.""" + if not self.is_drawing or event.button() != Qt.LeftButton: + super().mouseReleaseEvent(event) + return + + self.is_drawing = False + + if len(self.current_stroke) > 1: + # Convert to normalized coordinates and save stroke + normalized_stroke = [ + self._image_to_normalized_coords(x, y) for x, y in self.current_stroke + ] + self.all_strokes.append( + { + "points": normalized_stroke, + "color": self.pen_color.name(), + "alpha": self.pen_color.alpha(), + "width": self.pen_width, + } + ) + + # Emit signal with normalized coordinates + self.annotation_drawn.emit(normalized_stroke) + logger.debug(f"Completed stroke with {len(normalized_stroke)} points") + + self.current_stroke = [] + + def get_all_strokes(self) -> List[dict]: + """Get all drawn strokes with metadata.""" + return self.all_strokes + + def keyPressEvent(self, event: QKeyEvent): + """Handle keyboard events for zooming.""" + if event.key() in (Qt.Key_Plus, Qt.Key_Equal): + self.zoom_in() + event.accept() + elif event.key() == Qt.Key_Minus: + self.zoom_out() + event.accept() + elif event.key() == 
Qt.Key_0 and event.modifiers() == Qt.ControlModifier: + self.reset_zoom() + event.accept() + else: + super().keyPressEvent(event) + + def eventFilter(self, obj, event: QEvent) -> bool: + """Event filter to capture wheel events for zooming.""" + if event.type() == QEvent.Wheel: + wheel_event = event + if self.original_pixmap is not None: + delta = wheel_event.angleDelta().y() + + if delta > 0: + new_scale = self.zoom_scale + self.zoom_wheel_step + if new_scale <= self.zoom_max: + self.zoom_scale = new_scale + self._apply_zoom() + else: + new_scale = self.zoom_scale - self.zoom_wheel_step + if new_scale >= self.zoom_min: + self.zoom_scale = new_scale + self._apply_zoom() + + return True + + return super().eventFilter(obj, event) diff --git a/src/gui/widgets/annotation_tools_widget.py b/src/gui/widgets/annotation_tools_widget.py new file mode 100644 index 0000000..e59e1ff --- /dev/null +++ b/src/gui/widgets/annotation_tools_widget.py @@ -0,0 +1,352 @@ +""" +Annotation tools widget for controlling annotation parameters. +Includes pen tool, color picker, class selection, and annotation management. +""" + +from PySide6.QtWidgets import ( + QWidget, + QVBoxLayout, + QHBoxLayout, + QLabel, + QGroupBox, + QPushButton, + QComboBox, + QSpinBox, + QColorDialog, + QInputDialog, + QMessageBox, +) +from PySide6.QtGui import QColor, QIcon, QPixmap, QPainter +from PySide6.QtCore import Qt, Signal +from typing import Optional, Dict + +from src.database.db_manager import DatabaseManager +from src.utils.logger import get_logger + +logger = get_logger(__name__) + + +class AnnotationToolsWidget(QWidget): + """ + Widget for annotation tool controls. 
+ + Features: + - Enable/disable pen tool + - Color selection for pen + - Object class selection + - Add new object classes + - Pen width control + - Clear annotations + + Signals: + pen_enabled_changed: Emitted when pen tool is enabled/disabled (bool) + pen_color_changed: Emitted when pen color changes (QColor) + pen_width_changed: Emitted when pen width changes (int) + class_selected: Emitted when object class is selected (dict) + clear_annotations_requested: Emitted when clear button is pressed + """ + + pen_enabled_changed = Signal(bool) + pen_color_changed = Signal(QColor) + pen_width_changed = Signal(int) + class_selected = Signal(dict) + clear_annotations_requested = Signal() + + def __init__(self, db_manager: DatabaseManager, parent=None): + """ + Initialize annotation tools widget. + + Args: + db_manager: Database manager instance + parent: Parent widget + """ + super().__init__(parent) + self.db_manager = db_manager + self.pen_enabled = False + self.current_color = QColor(255, 0, 0, 128) # Red with 50% alpha + self.current_class = None + + self._setup_ui() + self._load_object_classes() + + def _setup_ui(self): + """Setup user interface.""" + layout = QVBoxLayout() + + # Pen Tool Group + pen_group = QGroupBox("Pen Tool") + pen_layout = QVBoxLayout() + + # Enable/Disable pen + button_layout = QHBoxLayout() + self.pen_toggle_btn = QPushButton("Enable Pen") + self.pen_toggle_btn.setCheckable(True) + self.pen_toggle_btn.clicked.connect(self._on_pen_toggle) + button_layout.addWidget(self.pen_toggle_btn) + pen_layout.addLayout(button_layout) + + # Pen width control + width_layout = QHBoxLayout() + width_layout.addWidget(QLabel("Pen Width:")) + self.pen_width_spin = QSpinBox() + self.pen_width_spin.setMinimum(1) + self.pen_width_spin.setMaximum(20) + self.pen_width_spin.setValue(3) + self.pen_width_spin.valueChanged.connect(self._on_pen_width_changed) + width_layout.addWidget(self.pen_width_spin) + width_layout.addStretch() + pen_layout.addLayout(width_layout) + 
+ # Color selection + color_layout = QHBoxLayout() + color_layout.addWidget(QLabel("Color:")) + self.color_btn = QPushButton() + self.color_btn.setFixedSize(40, 30) + self.color_btn.clicked.connect(self._on_color_picker) + self._update_color_button() + color_layout.addWidget(self.color_btn) + color_layout.addStretch() + pen_layout.addLayout(color_layout) + + pen_group.setLayout(pen_layout) + layout.addWidget(pen_group) + + # Object Class Group + class_group = QGroupBox("Object Class") + class_layout = QVBoxLayout() + + # Class selection dropdown + self.class_combo = QComboBox() + self.class_combo.currentIndexChanged.connect(self._on_class_selected) + class_layout.addWidget(self.class_combo) + + # Add class button + class_button_layout = QHBoxLayout() + self.add_class_btn = QPushButton("Add New Class") + self.add_class_btn.clicked.connect(self._on_add_class) + class_button_layout.addWidget(self.add_class_btn) + + self.refresh_classes_btn = QPushButton("Refresh") + self.refresh_classes_btn.clicked.connect(self._load_object_classes) + class_button_layout.addWidget(self.refresh_classes_btn) + class_layout.addLayout(class_button_layout) + + # Selected class info + self.class_info_label = QLabel("No class selected") + self.class_info_label.setWordWrap(True) + self.class_info_label.setStyleSheet( + "QLabel { color: #888; font-style: italic; }" + ) + class_layout.addWidget(self.class_info_label) + + class_group.setLayout(class_layout) + layout.addWidget(class_group) + + # Actions Group + actions_group = QGroupBox("Actions") + actions_layout = QVBoxLayout() + + self.clear_btn = QPushButton("Clear All Annotations") + self.clear_btn.clicked.connect(self._on_clear_annotations) + actions_layout.addWidget(self.clear_btn) + + actions_group.setLayout(actions_layout) + layout.addWidget(actions_group) + + layout.addStretch() + self.setLayout(layout) + + def _update_color_button(self): + """Update the color button appearance with current color.""" + pixmap = QPixmap(40, 30) + 
pixmap.fill(self.current_color) + + # Add border + painter = QPainter(pixmap) + painter.setPen(Qt.black) + painter.drawRect(0, 0, pixmap.width() - 1, pixmap.height() - 1) + painter.end() + + self.color_btn.setIcon(QIcon(pixmap)) + self.color_btn.setStyleSheet(f"background-color: {self.current_color.name()};") + + def _load_object_classes(self): + """Load object classes from database and populate combo box.""" + try: + classes = self.db_manager.get_object_classes() + + # Clear and repopulate combo box + self.class_combo.clear() + self.class_combo.addItem("-- Select Class --", None) + + for cls in classes: + self.class_combo.addItem(cls["class_name"], cls) + + logger.debug(f"Loaded {len(classes)} object classes") + + except Exception as e: + logger.error(f"Error loading object classes: {e}") + QMessageBox.warning( + self, "Error", f"Failed to load object classes:\n{str(e)}" + ) + + def _on_pen_toggle(self, checked: bool): + """Handle pen tool enable/disable.""" + self.pen_enabled = checked + + if checked: + self.pen_toggle_btn.setText("Disable Pen") + self.pen_toggle_btn.setStyleSheet( + "QPushButton { background-color: #4CAF50; }" + ) + else: + self.pen_toggle_btn.setText("Enable Pen") + self.pen_toggle_btn.setStyleSheet("") + + self.pen_enabled_changed.emit(self.pen_enabled) + logger.debug(f"Pen tool {'enabled' if checked else 'disabled'}") + + def _on_pen_width_changed(self, width: int): + """Handle pen width changes.""" + self.pen_width_changed.emit(width) + logger.debug(f"Pen width changed to {width}") + + def _on_color_picker(self): + """Open color picker dialog with alpha support.""" + color = QColorDialog.getColor( + self.current_color, + self, + "Select Pen Color", + QColorDialog.ShowAlphaChannel, # Enable alpha channel selection + ) + + if color.isValid(): + self.current_color = color + self._update_color_button() + self.pen_color_changed.emit(color) + logger.debug( + f"Pen color changed to {color.name()} with alpha {color.alpha()}" + ) + + def 
_on_class_selected(self, index: int): + """Handle object class selection.""" + class_data = self.class_combo.currentData() + + if class_data: + self.current_class = class_data + + # Update info label + info_text = ( + f"Class: {class_data['class_name']}\n" f"Color: {class_data['color']}" + ) + if class_data.get("description"): + info_text += f"\nDescription: {class_data['description']}" + + self.class_info_label.setText(info_text) + + # Update pen color to match class color with semi-transparency + class_color = QColor(class_data["color"]) + if class_color.isValid(): + # Add 50% alpha for semi-transparency + class_color.setAlpha(128) + self.current_color = class_color + self._update_color_button() + self.pen_color_changed.emit(class_color) + + self.class_selected.emit(class_data) + logger.debug(f"Selected class: {class_data['class_name']}") + else: + self.current_class = None + self.class_info_label.setText("No class selected") + + def _on_add_class(self): + """Handle adding a new object class.""" + # Get class name + class_name, ok = QInputDialog.getText( + self, "Add Object Class", "Enter class name:" + ) + + if not ok or not class_name.strip(): + return + + class_name = class_name.strip() + + # Check if class already exists + existing = self.db_manager.get_object_class_by_name(class_name) + if existing: + QMessageBox.warning( + self, "Class Exists", f"A class named '{class_name}' already exists." 
+ ) + return + + # Get color + color = QColorDialog.getColor(self.current_color, self, "Select Class Color") + + if not color.isValid(): + return + + # Get optional description + description, ok = QInputDialog.getText( + self, "Class Description", "Enter class description (optional):" + ) + + if not ok: + description = None + + # Add to database + try: + class_id = self.db_manager.add_object_class( + class_name, color.name(), description.strip() if description else None + ) + + logger.info(f"Added new object class: {class_name} (ID: {class_id})") + + # Reload classes and select the new one + self._load_object_classes() + + # Find and select the newly added class + for i in range(self.class_combo.count()): + class_data = self.class_combo.itemData(i) + if class_data and class_data.get("id") == class_id: + self.class_combo.setCurrentIndex(i) + break + + QMessageBox.information( + self, "Success", f"Class '{class_name}' added successfully!" + ) + + except Exception as e: + logger.error(f"Error adding object class: {e}") + QMessageBox.critical( + self, "Error", f"Failed to add object class:\n{str(e)}" + ) + + def _on_clear_annotations(self): + """Handle clear annotations button.""" + reply = QMessageBox.question( + self, + "Clear Annotations", + "Are you sure you want to clear all annotations?", + QMessageBox.Yes | QMessageBox.No, + QMessageBox.No, + ) + + if reply == QMessageBox.Yes: + self.clear_annotations_requested.emit() + logger.debug("Clear annotations requested") + + def get_current_class(self) -> Optional[Dict]: + """Get currently selected object class.""" + return self.current_class + + def get_pen_color(self) -> QColor: + """Get current pen color.""" + return self.current_color + + def get_pen_width(self) -> int: + """Get current pen width.""" + return self.pen_width_spin.value() + + def is_pen_enabled(self) -> bool: + """Check if pen tool is enabled.""" + return self.pen_enabled -- 2.49.1 From 710b6844562455819092f74949175ee65a40bac1 Mon Sep 17 00:00:00 2001 
From: Martin Laasmaa Date: Mon, 8 Dec 2025 23:59:44 +0200 Subject: [PATCH 07/13] Updating annotations --- src/gui/tabs/annotation_tab.py | 127 ++++++++++++++++++++ src/gui/widgets/annotation_canvas_widget.py | 96 +++++++++++++++ src/gui/widgets/annotation_tools_widget.py | 34 ++++++ 3 files changed, 257 insertions(+) diff --git a/src/gui/tabs/annotation_tab.py b/src/gui/tabs/annotation_tab.py index 7933caf..6ce6d57 100644 --- a/src/gui/tabs/annotation_tab.py +++ b/src/gui/tabs/annotation_tab.py @@ -93,6 +93,12 @@ class AnnotationTab(QWidget): self.annotation_tools.clear_annotations_requested.connect( self._on_clear_annotations ) + self.annotation_tools.process_annotations_requested.connect( + self._on_process_annotations + ) + self.annotation_tools.show_annotations_requested.connect( + self._on_show_annotations + ) self.right_splitter.addWidget(self.annotation_tools) # Image loading section @@ -237,6 +243,127 @@ class AnnotationTab(QWidget): self.annotation_canvas.clear_annotations() logger.info("Cleared all annotations") + def _on_process_annotations(self): + """Process annotations and save to database.""" + # Check if we have an image loaded + if not self.current_image or not self.current_image_id: + QMessageBox.warning( + self, "No Image", "Please load an image before processing annotations." 
+ ) + return + + # Get current class + current_class = self.annotation_tools.get_current_class() + if not current_class: + QMessageBox.warning( + self, + "No Class Selected", + "Please select an object class before processing annotations.", + ) + return + + # Compute bounding box and polyline from annotations + bounds = self.annotation_canvas.compute_annotation_bounds() + if not bounds: + QMessageBox.warning( + self, + "No Annotations", + "Please draw some annotations before processing.", + ) + return + + polyline = self.annotation_canvas.get_annotation_polyline() + + try: + # Save annotation to database + annotation_id = self.db_manager.add_annotation( + image_id=self.current_image_id, + class_id=current_class["id"], + bbox=bounds, + annotator="manual", + segmentation_mask=polyline, + verified=False, + ) + + logger.info( + f"Saved annotation (ID: {annotation_id}) for class '{current_class['class_name']}' " + f"with {len(polyline)} polyline points" + ) + + QMessageBox.information( + self, + "Success", + f"Annotation saved successfully!\n\n" + f"Class: {current_class['class_name']}\n" + f"Bounding box: ({bounds[0]:.3f}, {bounds[1]:.3f}) to ({bounds[2]:.3f}, {bounds[3]:.3f})\n" + f"Polyline points: {len(polyline)}", + ) + + # Optionally clear annotations after saving + reply = QMessageBox.question( + self, + "Clear Annotations", + "Do you want to clear the annotations to start a new one?", + QMessageBox.Yes | QMessageBox.No, + QMessageBox.Yes, + ) + + if reply == QMessageBox.Yes: + self.annotation_canvas.clear_annotations() + + except Exception as e: + logger.error(f"Failed to save annotation: {e}") + QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}") + + def _on_show_annotations(self): + """Load and display saved annotations from database.""" + # Check if we have an image loaded + if not self.current_image or not self.current_image_id: + QMessageBox.warning( + self, "No Image", "Please load an image to view its annotations." 
+ ) + return + + try: + # Clear current annotations + self.annotation_canvas.clear_annotations() + + # Retrieve annotations from database + annotations = self.db_manager.get_annotations_for_image( + self.current_image_id + ) + + if not annotations: + QMessageBox.information( + self, "No Annotations", "No saved annotations found for this image." + ) + return + + # Draw each annotation's polyline + drawn_count = 0 + for ann in annotations: + if ann.get("segmentation_mask"): + polyline = ann["segmentation_mask"] + color = ann.get("class_color", "#FF0000") + + # Draw the polyline + self.annotation_canvas.draw_saved_polyline(polyline, color, width=3) + drawn_count += 1 + + logger.info(f"Displayed {drawn_count} saved annotations from database") + + QMessageBox.information( + self, + "Annotations Loaded", + f"Successfully loaded and displayed {drawn_count} annotation(s).", + ) + + except Exception as e: + logger.error(f"Failed to load annotations: {e}") + QMessageBox.critical( + self, "Error", f"Failed to load annotations:\n{str(e)}" + ) + def _restore_state(self): """Restore splitter positions from settings.""" settings = QSettings("microscopy_app", "object_detection") diff --git a/src/gui/widgets/annotation_canvas_widget.py b/src/gui/widgets/annotation_canvas_widget.py index 7ff2bc5..a2a57a6 100644 --- a/src/gui/widgets/annotation_canvas_widget.py +++ b/src/gui/widgets/annotation_canvas_widget.py @@ -369,6 +369,102 @@ class AnnotationCanvasWidget(QWidget): """Get all drawn strokes with metadata.""" return self.all_strokes + def compute_annotation_bounds(self) -> Optional[Tuple[float, float, float, float]]: + """ + Compute bounding box that encompasses all annotation strokes. + + Returns: + Tuple of (x_min, y_min, x_max, y_max) in normalized coordinates (0-1), + or None if no annotations exist. 
+ """ + if not self.all_strokes: + return None + + # Find min/max across all strokes + all_x = [] + all_y = [] + + for stroke in self.all_strokes: + for x, y in stroke["points"]: + all_x.append(x) + all_y.append(y) + + if not all_x: + return None + + x_min = min(all_x) + y_min = min(all_y) + x_max = max(all_x) + y_max = max(all_y) + + return (x_min, y_min, x_max, y_max) + + def get_annotation_polyline(self) -> List[List[float]]: + """ + Get polyline coordinates representing all annotation strokes. + + Returns: + List of [x, y] coordinate pairs in normalized coordinates (0-1). + """ + polyline = [] + + for stroke in self.all_strokes: + polyline.extend(stroke["points"]) + + return polyline + + def draw_saved_polyline( + self, polyline: List[List[float]], color: str, width: int = 3 + ): + """ + Draw a polyline from database coordinates onto the annotation canvas. + + Args: + polyline: List of [x, y] coordinate pairs in normalized coordinates (0-1) + color: Color hex string (e.g., '#FF0000') + width: Line width in pixels + """ + if not self.annotation_pixmap or not self.original_pixmap: + logger.warning("Cannot draw polyline: no image loaded") + return + + if len(polyline) < 2: + logger.warning("Polyline has less than 2 points, cannot draw") + return + + # Convert normalized coordinates to image coordinates + img_coords = [] + for x_norm, y_norm in polyline: + x = int(x_norm * self.original_pixmap.width()) + y = int(y_norm * self.original_pixmap.height()) + img_coords.append((x, y)) + + # Draw polyline on annotation pixmap + painter = QPainter(self.annotation_pixmap) + pen_color = QColor(color) + pen_color.setAlpha(128) # Add semi-transparency + pen = QPen(pen_color, width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin) + painter.setPen(pen) + + # Draw lines between consecutive points + for i in range(len(img_coords) - 1): + x1, y1 = img_coords[i] + x2, y2 = img_coords[i + 1] + painter.drawLine(x1, y1, x2, y2) + + painter.end() + + # Store in all_strokes for consistency + 
self.all_strokes.append( + {"points": polyline, "color": color, "alpha": 128, "width": width} + ) + + # Update display + self._update_display() + logger.debug( + f"Drew saved polyline with {len(polyline)} points in color {color}" + ) + def keyPressEvent(self, event: QKeyEvent): """Handle keyboard events for zooming.""" if event.key() in (Qt.Key_Plus, Qt.Key_Equal): diff --git a/src/gui/widgets/annotation_tools_widget.py b/src/gui/widgets/annotation_tools_widget.py index e59e1ff..4153dc1 100644 --- a/src/gui/widgets/annotation_tools_widget.py +++ b/src/gui/widgets/annotation_tools_widget.py @@ -51,6 +51,8 @@ class AnnotationToolsWidget(QWidget): pen_width_changed = Signal(int) class_selected = Signal(dict) clear_annotations_requested = Signal() + process_annotations_requested = Signal() + show_annotations_requested = Signal() def __init__(self, db_manager: DatabaseManager, parent=None): """ @@ -146,6 +148,20 @@ class AnnotationToolsWidget(QWidget): actions_group = QGroupBox("Actions") actions_layout = QVBoxLayout() + self.process_btn = QPushButton("Process Annotations") + self.process_btn.clicked.connect(self._on_process_annotations) + self.process_btn.setStyleSheet( + "QPushButton { background-color: #2196F3; color: white; font-weight: bold; }" + ) + actions_layout.addWidget(self.process_btn) + + self.show_btn = QPushButton("Show Saved Annotations") + self.show_btn.clicked.connect(self._on_show_annotations) + self.show_btn.setStyleSheet( + "QPushButton { background-color: #4CAF50; color: white; }" + ) + actions_layout.addWidget(self.show_btn) + self.clear_btn = QPushButton("Clear All Annotations") self.clear_btn.clicked.connect(self._on_clear_annotations) actions_layout.addWidget(self.clear_btn) @@ -335,6 +351,24 @@ class AnnotationToolsWidget(QWidget): self.clear_annotations_requested.emit() logger.debug("Clear annotations requested") + def _on_process_annotations(self): + """Handle process annotations button.""" + if not self.current_class: + QMessageBox.warning( 
+ self, + "No Class Selected", + "Please select an object class before processing annotations.", + ) + return + + self.process_annotations_requested.emit() + logger.debug("Process annotations requested") + + def _on_show_annotations(self): + """Handle show annotations button.""" + self.show_annotations_requested.emit() + logger.debug("Show annotations requested") + def get_current_class(self) -> Optional[Dict]: """Get currently selected object class.""" return self.current_class -- 2.49.1 From 12f2bf94d57cafdcdd74082ae11a6d598a578bc2 Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Tue, 9 Dec 2025 15:42:42 +0200 Subject: [PATCH 08/13] Updating polyline saving and drawing --- src/gui/tabs/annotation_tab.py | 91 ++++--- src/gui/widgets/annotation_canvas_widget.py | 251 +++++++++++++++++--- 2 files changed, 265 insertions(+), 77 deletions(-) diff --git a/src/gui/tabs/annotation_tab.py b/src/gui/tabs/annotation_tab.py index 6ce6d57..ae16446 100644 --- a/src/gui/tabs/annotation_tab.py +++ b/src/gui/tabs/annotation_tab.py @@ -262,9 +262,9 @@ class AnnotationTab(QWidget): ) return - # Compute bounding box and polyline from annotations - bounds = self.annotation_canvas.compute_annotation_bounds() - if not bounds: + # Compute annotation parameters asbounding boxes and polylines from annotations + parameters = self.annotation_canvas.get_annotation_parameters() + if not parameters: QMessageBox.warning( self, "No Annotations", @@ -272,48 +272,56 @@ class AnnotationTab(QWidget): ) return - polyline = self.annotation_canvas.get_annotation_polyline() + # polyline = self.annotation_canvas.get_annotation_polyline() - try: - # Save annotation to database - annotation_id = self.db_manager.add_annotation( - image_id=self.current_image_id, - class_id=current_class["id"], - bbox=bounds, - annotator="manual", - segmentation_mask=polyline, - verified=False, - ) + for param in parameters: + bounds = param["bbox"] + polyline = param["polyline"] - logger.info( - f"Saved annotation (ID: 
{annotation_id}) for class '{current_class['class_name']}' " - f"with {len(polyline)} polyline points" - ) + try: + # Save annotation to database + annotation_id = self.db_manager.add_annotation( + image_id=self.current_image_id, + class_id=current_class["id"], + bbox=bounds, + annotator="manual", + segmentation_mask=polyline, + verified=False, + ) - QMessageBox.information( - self, - "Success", - f"Annotation saved successfully!\n\n" - f"Class: {current_class['class_name']}\n" - f"Bounding box: ({bounds[0]:.3f}, {bounds[1]:.3f}) to ({bounds[2]:.3f}, {bounds[3]:.3f})\n" - f"Polyline points: {len(polyline)}", - ) + logger.info( + f"Saved annotation (ID: {annotation_id}) for class '{current_class['class_name']}' " + f"Bounding box: ({bounds[0]:.3f}, {bounds[1]:.3f}) to ({bounds[2]:.3f}, {bounds[3]:.3f})\n" + f"with {len(polyline)} polyline points" + ) - # Optionally clear annotations after saving - reply = QMessageBox.question( - self, - "Clear Annotations", - "Do you want to clear the annotations to start a new one?", - QMessageBox.Yes | QMessageBox.No, - QMessageBox.Yes, - ) + # QMessageBox.information( + # self, + # "Success", + # f"Annotation saved successfully!\n\n" + # f"Class: {current_class['class_name']}\n" + # f"Bounding box: ({bounds[0]:.3f}, {bounds[1]:.3f}) to ({bounds[2]:.3f}, {bounds[3]:.3f})\n" + # f"Polyline points: {len(polyline)}", + # ) - if reply == QMessageBox.Yes: - self.annotation_canvas.clear_annotations() + except Exception as e: + logger.error(f"Failed to save annotation: {e}") + QMessageBox.critical( + self, "Error", f"Failed to save annotation:\n{str(e)}" + ) - except Exception as e: - logger.error(f"Failed to save annotation: {e}") - QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}") + # Optionally clear annotations after saving + reply = QMessageBox.question( + self, + "Clear Annotations", + "Do you want to clear the annotations to start a new one?", + QMessageBox.Yes | QMessageBox.No, + QMessageBox.Yes, + ) + 
+ if reply == QMessageBox.Yes: + self.annotation_canvas.clear_annotations() + logger.info("Cleared annotations after saving") def _on_show_annotations(self): """Load and display saved annotations from database.""" @@ -348,6 +356,11 @@ class AnnotationTab(QWidget): # Draw the polyline self.annotation_canvas.draw_saved_polyline(polyline, color, width=3) + self.annotation_canvas.draw_saved_bbox( + [ann["x_min"], ann["y_min"], ann["x_max"], ann["y_max"]], + color, + width=3, + ) drawn_count += 1 logger.info(f"Displayed {drawn_count} saved annotations from database") diff --git a/src/gui/widgets/annotation_canvas_widget.py b/src/gui/widgets/annotation_canvas_widget.py index a2a57a6..9851e97 100644 --- a/src/gui/widgets/annotation_canvas_widget.py +++ b/src/gui/widgets/annotation_canvas_widget.py @@ -3,6 +3,8 @@ Annotation canvas widget for drawing annotations on images. Supports pen tool with color selection for manual annotation. """ +import numpy as np + from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea from PySide6.QtGui import ( QPixmap, @@ -15,12 +17,17 @@ from PySide6.QtGui import ( QPaintEvent, ) from PySide6.QtCore import Qt, QEvent, Signal, QPoint -from typing import List, Optional, Tuple -import numpy as np +from typing import Any, Dict, List, Optional, Tuple + +from scipy.ndimage import binary_dilation, label, binary_fill_holes, find_objects +from skimage.measure import find_contours from src.utils.image import Image, ImageLoadError from src.utils.logger import get_logger +# For debugging visualization +import pylab as plt + logger = get_logger(__name__) @@ -369,49 +376,149 @@ class AnnotationCanvasWidget(QWidget): """Get all drawn strokes with metadata.""" return self.all_strokes - def compute_annotation_bounds(self) -> Optional[Tuple[float, float, float, float]]: + # def get_annotation_bounds(self) -> Optional[Tuple[float, float, float, float]]: + # """ + # Compute bounding box that encompasses all annotation strokes. 
+ + # Returns: + # Tuple of (x_min, y_min, x_max, y_max) in normalized coordinates (0-1), + # or None if no annotations exist. + # """ + # if not self.all_strokes: + # return None + + # # Find min/max across all strokes + # all_x = [] + # all_y = [] + + # for stroke in self.all_strokes: + # for x, y in stroke["points"]: + # all_x.append(x) + # all_y.append(y) + + # if not all_x: + # return None + + # x_min = min(all_x) + # y_min = min(all_y) + # x_max = max(all_x) + # y_max = max(all_y) + + # return (x_min, y_min, x_max, y_max) + + # def get_annotation_polyline(self) -> List[List[float]]: + # """ + # Get polyline coordinates representing all annotation strokes. + + # Returns: + # List of [x, y] coordinate pairs in normalized coordinates (0-1). + # """ + # polyline = [] + + # fig = plt.figure() + # ax1 = fig.add_subplot(411) + # ax2 = fig.add_subplot(412) + # ax3 = fig.add_subplot(413) + # ax4 = fig.add_subplot(414) + + # # Get np.arrays from annotation_pixmap accoriding to the color of the stroke + # qimage = self.annotation_pixmap.toImage() + # arr = np.ndarray( + # (qimage.height(), qimage.width(), 4), + # buffer=qimage.constBits(), + # strides=[qimage.bytesPerLine(), 4, 1], + # dtype=np.uint8, + # ) + # print(arr.shape, arr.dtype, arr.min(), arr.max()) + # arr = np.sum(arr, axis=2) + # ax1.imshow(arr) + + # arr_bin = arr > 0 + # ax2.imshow(arr_bin) + + # arr_bin = binary_fill_holes(arr_bin) + # ax3.imshow(arr_bin) + + # labels, _number_of_features = label( + # arr_bin, + # ) + + # ax4.imshow(labels) + + # objects = find_objects(labels) + # bounding_boxes = np.array( + # [[obj[0].start, obj[0].stop, obj[1].start, obj[1].stop] for obj in objects] + # ) / np.array([arr.shape[0], arr.shape[1]]) + + # print(objects) + # print(bounding_boxes) + # print(np.array([arr.shape[0], arr.shape[1]])) + + # polylines = find_contours(arr_bin, 0.5) + # for pl in polylines: + # ax1.plot(pl[:, 1], pl[:, 0], "k") + + # print(arr.shape, arr.dtype, arr.min(), arr.max()) + + # 
plt.show() + + # return polyline + + def get_annotation_parameters(self) -> Dict[str, Any]: """ - Compute bounding box that encompasses all annotation strokes. + Get all annotation parameters including bounding box and polyline. Returns: - Tuple of (x_min, y_min, x_max, y_max) in normalized coordinates (0-1), - or None if no annotations exist. + Dictionary containing: + - 'bbox': Bounding box coordinates (x_min, y_min, x_max, y_max) + - 'polyline': List of [x, y] coordinate pairs """ - if not self.all_strokes: + + # Get np.arrays from annotation_pixmap accoriding to the color of the stroke + qimage = self.annotation_pixmap.toImage() + arr = np.ndarray( + (qimage.height(), qimage.width(), 4), + buffer=qimage.constBits(), + strides=[qimage.bytesPerLine(), 4, 1], + dtype=np.uint8, + ) + arr = np.sum(arr, axis=2) + arr_bin = arr > 0 + arr_bin = binary_fill_holes(arr_bin) + + labels, _number_of_features = label( + arr_bin, + ) + if _number_of_features == 0: return None - # Find min/max across all strokes - all_x = [] - all_y = [] + objects = find_objects(labels) + w, h = arr.shape + bounding_boxes = [ + [obj[0].start / w, obj[1].start / h, obj[0].stop / w, obj[1].stop / h] + for obj in objects + ] - for stroke in self.all_strokes: - for x, y in stroke["points"]: - all_x.append(x) - all_y.append(y) + polylines = find_contours(arr_bin, 0.5) + params = [] + for i, pl in enumerate(polylines): + # pl is in [row, col] format from find_contours + # We need to normalize: row/height, col/width + # w = height (rows), h = width (cols) from line 510 + normalized_polyline = (pl[::-1] / np.array([w, h])).tolist() - if not all_x: - return None + logger.debug(f"Polyline {i}: {len(pl)} points") + logger.debug(f" w={w} (height), h={h} (width)") + logger.debug(f" First 3 normalized points: {normalized_polyline[:3]}") - x_min = min(all_x) - y_min = min(all_y) - x_max = max(all_x) - y_max = max(all_y) + params.append( + { + "bbox": bounding_boxes[i], + "polyline": normalized_polyline, + } + 
) - return (x_min, y_min, x_max, y_max) - - def get_annotation_polyline(self) -> List[List[float]]: - """ - Get polyline coordinates representing all annotation strokes. - - Returns: - List of [x, y] coordinate pairs in normalized coordinates (0-1). - """ - polyline = [] - - for stroke in self.all_strokes: - polyline.extend(stroke["points"]) - - return polyline + return params def draw_saved_polyline( self, polyline: List[List[float]], color: str, width: int = 3 @@ -433,12 +540,22 @@ class AnnotationCanvasWidget(QWidget): return # Convert normalized coordinates to image coordinates + # Polyline is stored as [[y_norm, x_norm], ...] (row_norm, col_norm format) + img_width = self.original_pixmap.width() + img_height = self.original_pixmap.height() + + logger.debug(f"Loading polyline with {len(polyline)} points") + logger.debug(f" Image size: {img_width}x{img_height}") + logger.debug(f" First 3 normalized points from DB: {polyline[:3]}") + img_coords = [] - for x_norm, y_norm in polyline: - x = int(x_norm * self.original_pixmap.width()) - y = int(y_norm * self.original_pixmap.height()) + for y_norm, x_norm in polyline: + x = int(x_norm * img_width) + y = int(y_norm * img_height) img_coords.append((x, y)) + logger.debug(f" First 3 pixel coords: {img_coords[:3]}") + # Draw polyline on annotation pixmap painter = QPainter(self.annotation_pixmap) pen_color = QColor(color) @@ -465,6 +582,64 @@ class AnnotationCanvasWidget(QWidget): f"Drew saved polyline with {len(polyline)} points in color {color}" ) + def draw_saved_bbox(self, bbox: List[float], color: str, width: int = 3): + """ + Draw a bounding box from database coordinates onto the annotation canvas. 
+ + Args: + bbox: Bounding box as [y_min_norm, x_min_norm, y_max_norm, x_max_norm] + in normalized coordinates (0-1) + color: Color hex string (e.g., '#FF0000') + width: Line width in pixels + """ + if not self.annotation_pixmap or not self.original_pixmap: + logger.warning("Cannot draw bounding box: no image loaded") + return + + if len(bbox) != 4: + logger.warning( + f"Invalid bounding box format: expected 4 values, got {len(bbox)}" + ) + return + + # Convert normalized coordinates to image coordinates + # bbox format: [y_min_norm, x_min_norm, y_max_norm, x_max_norm] + img_width = self.original_pixmap.width() + img_height = self.original_pixmap.height() + + y_min_norm, x_min_norm, y_max_norm, x_max_norm = bbox + x_min = int(x_min_norm * img_width) + y_min = int(y_min_norm * img_height) + x_max = int(x_max_norm * img_width) + y_max = int(y_max_norm * img_height) + + logger.debug(f"Drawing bounding box: {bbox}") + logger.debug(f" Image size: {img_width}x{img_height}") + logger.debug(f" Pixel coords: ({x_min}, {y_min}) to ({x_max}, {y_max})") + + # Draw bounding box on annotation pixmap + painter = QPainter(self.annotation_pixmap) + pen_color = QColor(color) + pen_color.setAlpha(128) # Add semi-transparency + pen = QPen(pen_color, width, Qt.SolidLine, Qt.SquareCap, Qt.MiterJoin) + painter.setPen(pen) + + # Draw rectangle + rect_width = x_max - x_min + rect_height = y_max - y_min + painter.drawRect(x_min, y_min, rect_width, rect_height) + + painter.end() + + # Store in all_strokes for consistency + self.all_strokes.append( + {"bbox": bbox, "color": color, "alpha": 128, "width": width} + ) + + # Update display + self._update_display() + logger.debug(f"Drew saved bounding box in color {color}") + def keyPressEvent(self, event: QKeyEvent): """Handle keyboard events for zooming.""" if event.key() in (Qt.Key_Plus, Qt.Key_Equal): -- 2.49.1 From 73cb69848823a9a17944a64b034364f93e5797ff Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Tue, 9 Dec 2025 22:00:56 +0200 
Subject: [PATCH 09/13] Saving state before replacing annotation tool --- src/gui/widgets/annotation_canvas_widget.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gui/widgets/annotation_canvas_widget.py b/src/gui/widgets/annotation_canvas_widget.py index 9851e97..3f57c2a 100644 --- a/src/gui/widgets/annotation_canvas_widget.py +++ b/src/gui/widgets/annotation_canvas_widget.py @@ -495,7 +495,7 @@ class AnnotationCanvasWidget(QWidget): objects = find_objects(labels) w, h = arr.shape bounding_boxes = [ - [obj[0].start / w, obj[1].start / h, obj[0].stop / w, obj[1].stop / h] + [obj[1].start / h, obj[0].start / w, obj[1].stop / h, obj[0].stop / w] for obj in objects ] @@ -607,7 +607,7 @@ class AnnotationCanvasWidget(QWidget): img_width = self.original_pixmap.width() img_height = self.original_pixmap.height() - y_min_norm, x_min_norm, y_max_norm, x_max_norm = bbox + x_min_norm, y_min_norm, x_max_norm, y_max_norm = bbox x_min = int(x_min_norm * img_width) y_min = int(y_min_norm * img_height) x_max = int(x_max_norm * img_width) -- 2.49.1 From dad5c2bf746f6d753be203a447cc696e36f74bcc Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Tue, 9 Dec 2025 22:44:23 +0200 Subject: [PATCH 10/13] Updating --- src/gui/tabs/annotation_tab.py | 267 ++++++++-------- src/gui/widgets/annotation_canvas_widget.py | 322 ++++++++++++++------ src/gui/widgets/annotation_tools_widget.py | 46 ++- 3 files changed, 411 insertions(+), 224 deletions(-) diff --git a/src/gui/tabs/annotation_tab.py b/src/gui/tabs/annotation_tab.py index ae16446..cec0d28 100644 --- a/src/gui/tabs/annotation_tab.py +++ b/src/gui/tabs/annotation_tab.py @@ -38,6 +38,7 @@ class AnnotationTab(QWidget): self.current_image = None self.current_image_path = None self.current_image_id = None + self.current_annotations = [] self._setup_ui() @@ -89,6 +90,13 @@ class AnnotationTab(QWidget): self.annotation_tools.pen_width_changed.connect( self.annotation_canvas.set_pen_width ) + # RDP simplification 
controls + self.annotation_tools.simplify_on_finish_changed.connect( + self._on_simplify_on_finish_changed + ) + self.annotation_tools.simplify_epsilon_changed.connect( + self._on_simplify_epsilon_changed + ) self.annotation_tools.class_selected.connect(self._on_class_selected) self.annotation_tools.clear_annotations_requested.connect( self._on_clear_annotations @@ -96,9 +104,6 @@ class AnnotationTab(QWidget): self.annotation_tools.process_annotations_requested.connect( self._on_process_annotations ) - self.annotation_tools.show_annotations_requested.connect( - self._on_show_annotations - ) self.right_splitter.addWidget(self.annotation_tools) # Image loading section @@ -180,6 +185,9 @@ class AnnotationTab(QWidget): # Display image using the AnnotationCanvasWidget self.annotation_canvas.load_image(self.current_image) + # Load and display any existing annotations for this image + self._load_annotations_for_current_image() + # Update info label self._update_image_info() @@ -217,7 +225,22 @@ class AnnotationTab(QWidget): self._update_image_info() def _on_annotation_drawn(self, points: list): - """Handle when an annotation stroke is drawn.""" + """ + Handle when an annotation stroke is drawn. + + Saves the new annotation directly to the database and refreshes the + on-canvas display of annotations for the current image. 
+ """ + # Ensure we have an image loaded and in the DB + if not self.current_image or not self.current_image_id: + logger.warning("Annotation drawn but no image loaded") + QMessageBox.warning( + self, + "No Image", + "Please load an image before drawing annotations.", + ) + return + current_class = self.annotation_tools.get_current_class() if not current_class: @@ -229,14 +252,58 @@ class AnnotationTab(QWidget): ) return - logger.info( - f"Annotation drawn with {len(points)} points for class: {current_class['class_name']}" - ) - # Future: Save annotation to database or export + if not points: + logger.warning("Annotation drawn with no points, ignoring") + return + + # points are [(x_norm, y_norm), ...] + xs = [p[0] for p in points] + ys = [p[1] for p in points] + x_min, x_max = min(xs), max(xs) + y_min, y_max = min(ys), max(ys) + + # Store segmentation mask in [y_norm, x_norm] format to match DB + db_polyline = [[float(y), float(x)] for (x, y) in points] + + try: + annotation_id = self.db_manager.add_annotation( + image_id=self.current_image_id, + class_id=current_class["id"], + bbox=(x_min, y_min, x_max, y_max), + annotator="manual", + segmentation_mask=db_polyline, + verified=False, + ) + + logger.info( + f"Saved annotation (ID: {annotation_id}) for class " + f"'{current_class['class_name']}' " + f"Bounding box: ({x_min:.3f}, {y_min:.3f}) to ({x_max:.3f}, {y_max:.3f})\n" + f"with {len(points)} polyline points" + ) + + # Reload annotations from DB and redraw (respecting current class filter) + self._load_annotations_for_current_image() + + except Exception as e: + logger.error(f"Failed to save annotation: {e}") + QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}") + + def _on_simplify_on_finish_changed(self, enabled: bool): + """Update canvas simplify-on-finish flag from tools widget.""" + self.annotation_canvas.simplify_on_finish = enabled + logger.debug(f"Annotation simplification on finish set to {enabled}") + + def 
_on_simplify_epsilon_changed(self, epsilon: float): + """Update canvas RDP epsilon from tools widget.""" + self.annotation_canvas.simplify_epsilon = float(epsilon) + logger.debug(f"Annotation simplification epsilon set to {epsilon}") def _on_class_selected(self, class_data: dict): """Handle when an object class is selected.""" logger.debug(f"Object class selected: {class_data['class_name']}") + # When a class is selected, update which annotations are visible + self._redraw_annotations_for_current_filter() def _on_clear_annotations(self): """Handle clearing all annotations.""" @@ -244,138 +311,92 @@ class AnnotationTab(QWidget): logger.info("Cleared all annotations") def _on_process_annotations(self): - """Process annotations and save to database.""" - # Check if we have an image loaded + """ + Legacy hook kept for UI compatibility. + + Annotations are now saved automatically when a stroke is completed, + so this handler does not perform any additional database writes. + """ if not self.current_image or not self.current_image_id: - QMessageBox.warning( - self, "No Image", "Please load an image before processing annotations." 
- ) - return - - # Get current class - current_class = self.annotation_tools.get_current_class() - if not current_class: QMessageBox.warning( self, - "No Class Selected", - "Please select an object class before processing annotations.", + "No Image", + "Please load an image before working with annotations.", ) return - # Compute annotation parameters asbounding boxes and polylines from annotations - parameters = self.annotation_canvas.get_annotation_parameters() - if not parameters: - QMessageBox.warning( - self, - "No Annotations", - "Please draw some annotations before processing.", - ) - return - - # polyline = self.annotation_canvas.get_annotation_polyline() - - for param in parameters: - bounds = param["bbox"] - polyline = param["polyline"] - - try: - # Save annotation to database - annotation_id = self.db_manager.add_annotation( - image_id=self.current_image_id, - class_id=current_class["id"], - bbox=bounds, - annotator="manual", - segmentation_mask=polyline, - verified=False, - ) - - logger.info( - f"Saved annotation (ID: {annotation_id}) for class '{current_class['class_name']}' " - f"Bounding box: ({bounds[0]:.3f}, {bounds[1]:.3f}) to ({bounds[2]:.3f}, {bounds[3]:.3f})\n" - f"with {len(polyline)} polyline points" - ) - - # QMessageBox.information( - # self, - # "Success", - # f"Annotation saved successfully!\n\n" - # f"Class: {current_class['class_name']}\n" - # f"Bounding box: ({bounds[0]:.3f}, {bounds[1]:.3f}) to ({bounds[2]:.3f}, {bounds[3]:.3f})\n" - # f"Polyline points: {len(polyline)}", - # ) - - except Exception as e: - logger.error(f"Failed to save annotation: {e}") - QMessageBox.critical( - self, "Error", f"Failed to save annotation:\n{str(e)}" - ) - - # Optionally clear annotations after saving - reply = QMessageBox.question( + QMessageBox.information( self, - "Clear Annotations", - "Do you want to clear the annotations to start a new one?", - QMessageBox.Yes | QMessageBox.No, - QMessageBox.Yes, + "Annotations Already Saved", + "Annotations are 
saved automatically as you draw. " + "There is no separate processing step required.", ) - if reply == QMessageBox.Yes: + def _load_annotations_for_current_image(self): + """ + Load all annotations for the current image from the database and + redraw them on the canvas, honoring the currently selected class + filter (if any). + """ + if not self.current_image_id: + self.current_annotations = [] self.annotation_canvas.clear_annotations() - logger.info("Cleared annotations after saving") - - def _on_show_annotations(self): - """Load and display saved annotations from database.""" - # Check if we have an image loaded - if not self.current_image or not self.current_image_id: - QMessageBox.warning( - self, "No Image", "Please load an image to view its annotations." - ) return try: - # Clear current annotations - self.annotation_canvas.clear_annotations() - - # Retrieve annotations from database - annotations = self.db_manager.get_annotations_for_image( + self.current_annotations = self.db_manager.get_annotations_for_image( self.current_image_id ) - - if not annotations: - QMessageBox.information( - self, "No Annotations", "No saved annotations found for this image." 
- ) - return - - # Draw each annotation's polyline - drawn_count = 0 - for ann in annotations: - if ann.get("segmentation_mask"): - polyline = ann["segmentation_mask"] - color = ann.get("class_color", "#FF0000") - - # Draw the polyline - self.annotation_canvas.draw_saved_polyline(polyline, color, width=3) - self.annotation_canvas.draw_saved_bbox( - [ann["x_min"], ann["y_min"], ann["x_max"], ann["y_max"]], - color, - width=3, - ) - drawn_count += 1 - - logger.info(f"Displayed {drawn_count} saved annotations from database") - - QMessageBox.information( - self, - "Annotations Loaded", - f"Successfully loaded and displayed {drawn_count} annotation(s).", - ) - + self._redraw_annotations_for_current_filter() except Exception as e: - logger.error(f"Failed to load annotations: {e}") - QMessageBox.critical( - self, "Error", f"Failed to load annotations:\n{str(e)}" + logger.error( + f"Failed to load annotations for image {self.current_image_id}: {e}" ) + QMessageBox.critical( + self, + "Error", + f"Failed to load annotations for this image:\n{str(e)}", + ) + + def _redraw_annotations_for_current_filter(self): + """ + Redraw annotations for the current image, optionally filtered by the + currently selected object class. 
+ """ + # Clear current on-canvas annotations but keep the image + self.annotation_canvas.clear_annotations() + + if not self.current_annotations: + return + + current_class = self.annotation_tools.get_current_class() + selected_class_id = current_class["id"] if current_class else None + + drawn_count = 0 + for ann in self.current_annotations: + # Filter by class if one is selected + if ( + selected_class_id is not None + and ann.get("class_id") != selected_class_id + ): + continue + + if ann.get("segmentation_mask"): + polyline = ann["segmentation_mask"] + color = ann.get("class_color", "#FF0000") + + self.annotation_canvas.draw_saved_polyline(polyline, color, width=3) + self.annotation_canvas.draw_saved_bbox( + [ann["x_min"], ann["y_min"], ann["x_max"], ann["y_max"]], + color, + width=3, + ) + drawn_count += 1 + + logger.info( + f"Displayed {drawn_count} annotation(s) for current image with " + f"{'no class filter' if selected_class_id is None else f'class_id={selected_class_id}'}" + ) def _restore_state(self): """Restore splitter positions from settings.""" diff --git a/src/gui/widgets/annotation_canvas_widget.py b/src/gui/widgets/annotation_canvas_widget.py index 3f57c2a..8ea5b56 100644 --- a/src/gui/widgets/annotation_canvas_widget.py +++ b/src/gui/widgets/annotation_canvas_widget.py @@ -4,6 +4,7 @@ Supports pen tool with color selection for manual annotation. 
""" import numpy as np +import math from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea from PySide6.QtGui import ( @@ -19,18 +20,95 @@ from PySide6.QtGui import ( from PySide6.QtCore import Qt, QEvent, Signal, QPoint from typing import Any, Dict, List, Optional, Tuple -from scipy.ndimage import binary_dilation, label, binary_fill_holes, find_objects -from skimage.measure import find_contours - from src.utils.image import Image, ImageLoadError from src.utils.logger import get_logger -# For debugging visualization -import pylab as plt - logger = get_logger(__name__) +def perpendicular_distance( + point: Tuple[float, float], + start: Tuple[float, float], + end: Tuple[float, float], +) -> float: + """Perpendicular distance from `point` to the line defined by `start`->`end`.""" + (x, y), (x1, y1), (x2, y2) = point, start, end + dx = x2 - x1 + dy = y2 - y1 + if dx == 0.0 and dy == 0.0: + return math.hypot(x - x1, y - y1) + num = abs(dy * x - dx * y + x2 * y1 - y2 * x1) + den = math.hypot(dx, dy) + return num / den + + +def rdp(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float, float]]: + """ + Recursive Ramer-Douglas-Peucker (RDP) polyline simplification. + + Args: + points: List of (x, y) points. + epsilon: Maximum allowed perpendicular distance in pixels. + + Returns: + Simplified list of (x, y) points including first and last. 
+ """ + if len(points) <= 2: + return list(points) + + start = points[0] + end = points[-1] + max_dist = -1.0 + index = -1 + + for i in range(1, len(points) - 1): + d = perpendicular_distance(points[i], start, end) + if d > max_dist: + max_dist = d + index = i + + if max_dist > epsilon: + # Recursive split + left = rdp(points[: index + 1], epsilon) + right = rdp(points[index:], epsilon) + # Concatenate but avoid duplicate at split point + return left[:-1] + right + + # Keep only start and end + return [start, end] + + +def simplify_polyline( + points: List[Tuple[float, float]], epsilon: float +) -> List[Tuple[float, float]]: + """ + Simplify a polyline with RDP while preserving closure semantics. + + If the polyline is closed (first == last), the duplicate last point is removed + before simplification and then re-added after simplification. + """ + if not points: + return [] + + pts = [(float(x), float(y)) for x, y in points] + closed = False + + if len(pts) >= 2 and pts[0] == pts[-1]: + closed = True + pts = pts[:-1] # remove duplicate last for simplification + + if len(pts) <= 2: + simplified = list(pts) + else: + simplified = rdp(pts, epsilon) + + if closed and simplified: + if simplified[0] != simplified[-1]: + simplified.append(simplified[0]) + + return simplified + + class AnnotationCanvasWidget(QWidget): """ Widget for displaying images and drawing annotations with pen tool. 
@@ -68,8 +146,19 @@ class AnnotationCanvasWidget(QWidget): self.pen_enabled = False self.pen_color = QColor(255, 0, 0, 128) # Default red with 50% alpha self.pen_width = 3 - self.current_stroke = [] # Points in current stroke - self.all_strokes = [] # All completed strokes + + # Current stroke and stored polylines (in image coordinates, pixel units) + self.current_stroke: List[Tuple[float, float]] = [] + self.polylines: List[List[Tuple[float, float]]] = [] + self.stroke_meta: List[Dict[str, Any]] = [] # per-polyline style (color, width) + + # Legacy collection of strokes in normalized coordinates (kept for API compatibility) + self.all_strokes: List[dict] = [] + + # RDP simplification parameters (in pixels) + self.simplify_on_finish: bool = True + self.simplify_epsilon: float = 2.0 + self.sample_threshold: float = 2.0 # minimum movement to sample a new point self._setup_ui() @@ -128,6 +217,8 @@ class AnnotationCanvasWidget(QWidget): """Clear all drawn annotations.""" self.all_strokes = [] self.current_stroke = [] + self.polylines = [] + self.stroke_meta = [] self.is_drawing = False if self.annotation_pixmap: self.annotation_pixmap.fill(Qt.transparent) @@ -300,6 +391,46 @@ class AnnotationCanvasWidget(QWidget): norm_y = y / self.original_pixmap.height() return (norm_x, norm_y) + def _add_polyline( + self, img_points: List[Tuple[float, float]], color: QColor, width: int + ): + """Store a polyline in image coordinates and redraw annotations.""" + if not img_points or len(img_points) < 2: + return + + # Ensure all points are tuples of floats + normalized_points = [(float(x), float(y)) for x, y in img_points] + self.polylines.append(normalized_points) + self.stroke_meta.append({"color": QColor(color), "width": int(width)}) + + self._redraw_annotations() + + def _redraw_annotations(self): + """Redraw all stored polylines onto the annotation pixmap.""" + if self.annotation_pixmap is None: + return + + # Clear existing overlay + self.annotation_pixmap.fill(Qt.transparent) 
+ + painter = QPainter(self.annotation_pixmap) + for polyline, meta in zip(self.polylines, self.stroke_meta): + pen_color: QColor = meta.get("color", self.pen_color) + width: int = meta.get("width", self.pen_width) + pen = QPen( + pen_color, + width, + Qt.SolidLine, + Qt.RoundCap, + Qt.RoundJoin, + ) + painter.setPen(pen) + for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]): + painter.drawLine(int(x1), int(y1), int(x2), int(y2)) + painter.end() + + self._update_display() + def mousePressEvent(self, event: QMouseEvent): """Handle mouse press events for drawing.""" if not self.pen_enabled or self.annotation_pixmap is None: @@ -313,7 +444,7 @@ class AnnotationCanvasWidget(QWidget): if img_coords: self.is_drawing = True - self.current_stroke = [img_coords] + self.current_stroke = [(float(img_coords[0]), float(img_coords[1]))] def mouseMoveEvent(self, event: QMouseEvent): """Handle mouse move events for drawing.""" @@ -330,18 +461,33 @@ class AnnotationCanvasWidget(QWidget): img_coords = self._canvas_to_image_coords(label_pos) if img_coords and len(self.current_stroke) > 0: - # Draw line from last point to current point + last_point = self.current_stroke[-1] + dx = img_coords[0] - last_point[0] + dy = img_coords[1] - last_point[1] + + # Only sample a new point if we moved enough pixels + if math.hypot(dx, dy) < self.sample_threshold: + return + + # Draw line from last point to current point for interactive feedback painter = QPainter(self.annotation_pixmap) pen = QPen( - self.pen_color, self.pen_width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin + self.pen_color, + self.pen_width, + Qt.SolidLine, + Qt.RoundCap, + Qt.RoundJoin, ) painter.setPen(pen) - - last_point = self.current_stroke[-1] - painter.drawLine(last_point[0], last_point[1], img_coords[0], img_coords[1]) + painter.drawLine( + int(last_point[0]), + int(last_point[1]), + int(img_coords[0]), + int(img_coords[1]), + ) painter.end() - self.current_stroke.append(img_coords) + 
self.current_stroke.append((float(img_coords[0]), float(img_coords[1]))) self._update_display() def mouseReleaseEvent(self, event: QMouseEvent): @@ -352,23 +498,42 @@ class AnnotationCanvasWidget(QWidget): self.is_drawing = False - if len(self.current_stroke) > 1: - # Convert to normalized coordinates and save stroke - normalized_stroke = [ - self._image_to_normalized_coords(x, y) for x, y in self.current_stroke - ] - self.all_strokes.append( - { - "points": normalized_stroke, - "color": self.pen_color.name(), - "alpha": self.pen_color.alpha(), - "width": self.pen_width, - } - ) + if len(self.current_stroke) > 1 and self.original_pixmap is not None: + # Ensure the stroke is closed by connecting end -> start + raw_points = list(self.current_stroke) + if raw_points[0] != raw_points[-1]: + raw_points.append(raw_points[0]) - # Emit signal with normalized coordinates - self.annotation_drawn.emit(normalized_stroke) - logger.debug(f"Completed stroke with {len(normalized_stroke)} points") + # Optional RDP simplification (in image pixel space) + if self.simplify_on_finish: + simplified = simplify_polyline(raw_points, self.simplify_epsilon) + else: + simplified = raw_points + + if len(simplified) >= 2: + # Store polyline and redraw all annotations + self._add_polyline(simplified, self.pen_color, self.pen_width) + + # Convert to normalized coordinates for metadata + signal + normalized_stroke = [ + self._image_to_normalized_coords(int(x), int(y)) + for (x, y) in simplified + ] + self.all_strokes.append( + { + "points": normalized_stroke, + "color": self.pen_color.name(), + "alpha": self.pen_color.alpha(), + "width": self.pen_width, + } + ) + + # Emit signal with normalized coordinates + self.annotation_drawn.emit(normalized_stroke) + logger.debug( + f"Completed stroke with {len(simplified)} points " + f"(normalized len={len(normalized_stroke)})" + ) self.current_stroke = [] @@ -464,61 +629,54 @@ class AnnotationCanvasWidget(QWidget): # return polyline - def 
get_annotation_parameters(self) -> Dict[str, Any]: + def get_annotation_parameters(self) -> Optional[List[Dict[str, Any]]]: """ Get all annotation parameters including bounding box and polyline. Returns: - Dictionary containing: - - 'bbox': Bounding box coordinates (x_min, y_min, x_max, y_max) - - 'polyline': List of [x, y] coordinate pairs + List of dictionaries, each containing: + - 'bbox': [x_min, y_min, x_max, y_max] in normalized image coordinates + - 'polyline': List of [y_norm, x_norm] points describing the polygon """ - - # Get np.arrays from annotation_pixmap accoriding to the color of the stroke - qimage = self.annotation_pixmap.toImage() - arr = np.ndarray( - (qimage.height(), qimage.width(), 4), - buffer=qimage.constBits(), - strides=[qimage.bytesPerLine(), 4, 1], - dtype=np.uint8, - ) - arr = np.sum(arr, axis=2) - arr_bin = arr > 0 - arr_bin = binary_fill_holes(arr_bin) - - labels, _number_of_features = label( - arr_bin, - ) - if _number_of_features == 0: + if self.original_pixmap is None or not self.polylines: return None - objects = find_objects(labels) - w, h = arr.shape - bounding_boxes = [ - [obj[1].start / h, obj[0].start / w, obj[1].stop / h, obj[0].stop / w] - for obj in objects - ] + img_width = float(self.original_pixmap.width()) + img_height = float(self.original_pixmap.height()) - polylines = find_contours(arr_bin, 0.5) - params = [] - for i, pl in enumerate(polylines): - # pl is in [row, col] format from find_contours - # We need to normalize: row/height, col/width - # w = height (rows), h = width (cols) from line 510 - normalized_polyline = (pl[::-1] / np.array([w, h])).tolist() + params: List[Dict[str, Any]] = [] - logger.debug(f"Polyline {i}: {len(pl)} points") - logger.debug(f" w={w} (height), h={h} (width)") - logger.debug(f" First 3 normalized points: {normalized_polyline[:3]}") + for idx, polyline in enumerate(self.polylines): + if len(polyline) < 2: + continue + + xs = [p[0] for p in polyline] + ys = [p[1] for p in polyline] + + 
x_min_norm = min(xs) / img_width + x_max_norm = max(xs) / img_width + y_min_norm = min(ys) / img_height + y_max_norm = max(ys) / img_height + + # Store polyline as [y_norm, x_norm] to match DB convention and + # the expectations of draw_saved_polyline(). + normalized_polyline = [ + [y / img_height, x / img_width] for (x, y) in polyline + ] + + logger.debug( + f"Polyline {idx}: {len(polyline)} points, " + f"bbox=({x_min_norm:.3f}, {y_min_norm:.3f})-({x_max_norm:.3f}, {y_max_norm:.3f})" + ) params.append( { - "bbox": bounding_boxes[i], + "bbox": [x_min_norm, y_min_norm, x_max_norm, y_max_norm], "polyline": normalized_polyline, } ) - return params + return params or None def draw_saved_polyline( self, polyline: List[List[float]], color: str, width: int = 3 @@ -548,36 +706,24 @@ class AnnotationCanvasWidget(QWidget): logger.debug(f" Image size: {img_width}x{img_height}") logger.debug(f" First 3 normalized points from DB: {polyline[:3]}") - img_coords = [] + img_coords: List[Tuple[float, float]] = [] for y_norm, x_norm in polyline: - x = int(x_norm * img_width) - y = int(y_norm * img_height) + x = float(x_norm * img_width) + y = float(y_norm * img_height) img_coords.append((x, y)) logger.debug(f" First 3 pixel coords: {img_coords[:3]}") - # Draw polyline on annotation pixmap - painter = QPainter(self.annotation_pixmap) + # Store and redraw using common pipeline pen_color = QColor(color) pen_color.setAlpha(128) # Add semi-transparency - pen = QPen(pen_color, width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin) - painter.setPen(pen) + self._add_polyline(img_coords, pen_color, width) - # Draw lines between consecutive points - for i in range(len(img_coords) - 1): - x1, y1 = img_coords[i] - x2, y2 = img_coords[i + 1] - painter.drawLine(x1, y1, x2, y2) - - painter.end() - - # Store in all_strokes for consistency + # Store in all_strokes for consistency (uses normalized coordinates) self.all_strokes.append( {"points": polyline, "color": color, "alpha": 128, "width": width} ) - # 
Update display - self._update_display() logger.debug( f"Drew saved polyline with {len(polyline)} points in color {color}" ) diff --git a/src/gui/widgets/annotation_tools_widget.py b/src/gui/widgets/annotation_tools_widget.py index 4153dc1..89e4340 100644 --- a/src/gui/widgets/annotation_tools_widget.py +++ b/src/gui/widgets/annotation_tools_widget.py @@ -12,6 +12,8 @@ from PySide6.QtWidgets import ( QPushButton, QComboBox, QSpinBox, + QDoubleSpinBox, + QCheckBox, QColorDialog, QInputDialog, QMessageBox, @@ -49,10 +51,11 @@ class AnnotationToolsWidget(QWidget): pen_enabled_changed = Signal(bool) pen_color_changed = Signal(QColor) pen_width_changed = Signal(int) + simplify_on_finish_changed = Signal(bool) + simplify_epsilon_changed = Signal(float) class_selected = Signal(dict) clear_annotations_requested = Signal() process_annotations_requested = Signal() - show_annotations_requested = Signal() def __init__(self, db_manager: DatabaseManager, parent=None): """ @@ -110,6 +113,23 @@ class AnnotationToolsWidget(QWidget): color_layout.addStretch() pen_layout.addLayout(color_layout) + # Simplification controls (RDP) + simplify_layout = QHBoxLayout() + self.simplify_checkbox = QCheckBox("Simplify on finish") + self.simplify_checkbox.setChecked(True) + self.simplify_checkbox.stateChanged.connect(self._on_simplify_toggle) + simplify_layout.addWidget(self.simplify_checkbox) + + simplify_layout.addWidget(QLabel("epsilon (px):")) + self.eps_spin = QDoubleSpinBox() + self.eps_spin.setRange(0.0, 1000.0) + self.eps_spin.setSingleStep(0.5) + self.eps_spin.setValue(2.0) + self.eps_spin.valueChanged.connect(self._on_eps_change) + simplify_layout.addWidget(self.eps_spin) + simplify_layout.addStretch() + pen_layout.addLayout(simplify_layout) + pen_group.setLayout(pen_layout) layout.addWidget(pen_group) @@ -155,13 +175,6 @@ class AnnotationToolsWidget(QWidget): ) actions_layout.addWidget(self.process_btn) - self.show_btn = QPushButton("Show Saved Annotations") - 
self.show_btn.clicked.connect(self._on_show_annotations) - self.show_btn.setStyleSheet( - "QPushButton { background-color: #4CAF50; color: white; }" - ) - actions_layout.addWidget(self.show_btn) - self.clear_btn = QPushButton("Clear All Annotations") self.clear_btn.clicked.connect(self._on_clear_annotations) actions_layout.addWidget(self.clear_btn) @@ -227,6 +240,18 @@ class AnnotationToolsWidget(QWidget): self.pen_width_changed.emit(width) logger.debug(f"Pen width changed to {width}") + def _on_simplify_toggle(self, state: int): + """Handle simplify-on-finish checkbox toggle.""" + enabled = bool(state) + self.simplify_on_finish_changed.emit(enabled) + logger.debug(f"Simplify on finish set to {enabled}") + + def _on_eps_change(self, val: float): + """Handle epsilon (RDP tolerance) value changes.""" + epsilon = float(val) + self.simplify_epsilon_changed.emit(epsilon) + logger.debug(f"Simplification epsilon changed to {epsilon}") + def _on_color_picker(self): """Open color picker dialog with alpha support.""" color = QColorDialog.getColor( @@ -364,11 +389,6 @@ class AnnotationToolsWidget(QWidget): self.process_annotations_requested.emit() logger.debug("Process annotations requested") - def _on_show_annotations(self): - """Handle show annotations button.""" - self.show_annotations_requested.emit() - logger.debug("Show annotations requested") - def get_current_class(self) -> Optional[Dict]: """Get currently selected object class.""" return self.current_class -- 2.49.1 From c3d44ac9458eb3717a5f26093699393782285448 Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Tue, 9 Dec 2025 23:38:23 +0200 Subject: [PATCH 11/13] Renaming Pen tool to polyline tool --- src/gui/tabs/annotation_tab.py | 74 ++++--- src/gui/widgets/annotation_canvas_widget.py | 140 +++--------- src/gui/widgets/annotation_tools_widget.py | 232 +++++++++++--------- 3 files changed, 200 insertions(+), 246 deletions(-) diff --git a/src/gui/tabs/annotation_tab.py b/src/gui/tabs/annotation_tab.py index 
cec0d28..62b3d22 100644 --- a/src/gui/tabs/annotation_tab.py +++ b/src/gui/tabs/annotation_tab.py @@ -81,14 +81,14 @@ class AnnotationTab(QWidget): # Annotation tools section self.annotation_tools = AnnotationToolsWidget(self.db_manager) - self.annotation_tools.pen_enabled_changed.connect( - self.annotation_canvas.set_pen_enabled + self.annotation_tools.polyline_enabled_changed.connect( + self.annotation_canvas.set_polyline_enabled ) - self.annotation_tools.pen_color_changed.connect( - self.annotation_canvas.set_pen_color + self.annotation_tools.polyline_pen_color_changed.connect( + self.annotation_canvas.set_polyline_pen_color ) - self.annotation_tools.pen_width_changed.connect( - self.annotation_canvas.set_pen_width + self.annotation_tools.polyline_pen_width_changed.connect( + self.annotation_canvas.set_polyline_pen_width ) # RDP simplification controls self.annotation_tools.simplify_on_finish_changed.connect( @@ -97,13 +97,12 @@ class AnnotationTab(QWidget): self.annotation_tools.simplify_epsilon_changed.connect( self._on_simplify_epsilon_changed ) + # Class selection and class-color changes self.annotation_tools.class_selected.connect(self._on_class_selected) + self.annotation_tools.class_color_changed.connect(self._on_class_color_changed) self.annotation_tools.clear_annotations_requested.connect( self._on_clear_annotations ) - self.annotation_tools.process_annotations_requested.connect( - self._on_process_annotations - ) self.right_splitter.addWidget(self.annotation_tools) # Image loading section @@ -299,10 +298,37 @@ class AnnotationTab(QWidget): self.annotation_canvas.simplify_epsilon = float(epsilon) logger.debug(f"Annotation simplification epsilon set to {epsilon}") - def _on_class_selected(self, class_data: dict): - """Handle when an object class is selected.""" - logger.debug(f"Object class selected: {class_data['class_name']}") - # When a class is selected, update which annotations are visible + def _on_class_color_changed(self): + """ + Handle changes 
to the selected object's class color. + + When the user updates a class color in the tools widget, reload the + annotations for the current image so that all polylines are redrawn + using the updated per-class colors. + """ + if not self.current_image_id: + return + + logger.debug( + f"Class color changed; reloading annotations for image ID {self.current_image_id}" + ) + self._load_annotations_for_current_image() + + def _on_class_selected(self, class_data): + """ + Handle when an object class is selected or cleared. + + When a specific class is selected, only annotations of that class are drawn. + When the selection is cleared (\"-- Select Class --\"), all annotations are shown. + """ + if class_data: + logger.debug(f"Object class selected: {class_data['class_name']}") + else: + logger.debug( + 'No class selected ("-- Select Class --"), showing all annotations' + ) + + # Whenever the selection changes, update which annotations are visible self._redraw_annotations_for_current_filter() def _on_clear_annotations(self): @@ -310,28 +336,6 @@ class AnnotationTab(QWidget): self.annotation_canvas.clear_annotations() logger.info("Cleared all annotations") - def _on_process_annotations(self): - """ - Legacy hook kept for UI compatibility. - - Annotations are now saved automatically when a stroke is completed, - so this handler does not perform any additional database writes. - """ - if not self.current_image or not self.current_image_id: - QMessageBox.warning( - self, - "No Image", - "Please load an image before working with annotations.", - ) - return - - QMessageBox.information( - self, - "Annotations Already Saved", - "Annotations are saved automatically as you draw. 
" - "There is no separate processing step required.", - ) - def _load_annotations_for_current_image(self): """ Load all annotations for the current image from the database and diff --git a/src/gui/widgets/annotation_canvas_widget.py b/src/gui/widgets/annotation_canvas_widget.py index 8ea5b56..4a8be4e 100644 --- a/src/gui/widgets/annotation_canvas_widget.py +++ b/src/gui/widgets/annotation_canvas_widget.py @@ -1,6 +1,6 @@ """ Annotation canvas widget for drawing annotations on images. -Supports pen tool with color selection for manual annotation. +Currently supports polyline drawing tool with color selection for manual annotation. """ import numpy as np @@ -111,11 +111,11 @@ def simplify_polyline( class AnnotationCanvasWidget(QWidget): """ - Widget for displaying images and drawing annotations with pen tool. + Widget for displaying images and drawing annotations with zoom and drawing tools. Features: - Display images with zoom functionality - - Pen tool for drawing annotations + - Polyline tool for drawing annotations - Configurable pen color and width - Mouse-based drawing interface - Zoom in/out with mouse wheel and keyboard @@ -143,9 +143,9 @@ class AnnotationCanvasWidget(QWidget): # Drawing state self.is_drawing = False - self.pen_enabled = False - self.pen_color = QColor(255, 0, 0, 128) # Default red with 50% alpha - self.pen_width = 3 + self.polyline_enabled = False + self.polyline_pen_color = QColor(255, 0, 0, 128) # Default red with 50% alpha + self.polyline_pen_width = 3 # Current stroke and stored polylines (in image coordinates, pixel units) self.current_stroke: List[Tuple[float, float]] = [] @@ -309,21 +309,21 @@ class AnnotationCanvasWidget(QWidget): """Update display after drawing.""" self._apply_zoom() - def set_pen_enabled(self, enabled: bool): - """Enable or disable pen tool.""" - self.pen_enabled = enabled + def set_polyline_enabled(self, enabled: bool): + """Enable or disable polyline tool.""" + self.polyline_enabled = enabled if enabled: 
self.canvas_label.setCursor(Qt.CrossCursor) else: self.canvas_label.setCursor(Qt.ArrowCursor) - def set_pen_color(self, color: QColor): - """Set pen color.""" - self.pen_color = color + def set_polyline_pen_color(self, color: QColor): + """Set polyline pen color.""" + self.polyline_pen_color = color - def set_pen_width(self, width: int): - """Set pen width.""" - self.pen_width = max(1, width) + def set_polyline_pen_width(self, width: int): + """Set polyline pen width.""" + self.polyline_pen_width = max(1, width) def get_zoom_percentage(self) -> int: """Get current zoom level as percentage.""" @@ -415,8 +415,8 @@ class AnnotationCanvasWidget(QWidget): painter = QPainter(self.annotation_pixmap) for polyline, meta in zip(self.polylines, self.stroke_meta): - pen_color: QColor = meta.get("color", self.pen_color) - width: int = meta.get("width", self.pen_width) + pen_color: QColor = meta.get("color", self.polyline_pen_color) + width: int = meta.get("width", self.polyline_pen_width) pen = QPen( pen_color, width, @@ -433,7 +433,7 @@ class AnnotationCanvasWidget(QWidget): def mousePressEvent(self, event: QMouseEvent): """Handle mouse press events for drawing.""" - if not self.pen_enabled or self.annotation_pixmap is None: + if not self.polyline_enabled or self.annotation_pixmap is None: super().mousePressEvent(event) return @@ -450,7 +450,7 @@ class AnnotationCanvasWidget(QWidget): """Handle mouse move events for drawing.""" if ( not self.is_drawing - or not self.pen_enabled + or not self.polyline_enabled or self.annotation_pixmap is None ): super().mouseMoveEvent(event) @@ -472,8 +472,8 @@ class AnnotationCanvasWidget(QWidget): # Draw line from last point to current point for interactive feedback painter = QPainter(self.annotation_pixmap) pen = QPen( - self.pen_color, - self.pen_width, + self.polyline_pen_color, + self.polyline_pen_width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin, @@ -512,7 +512,9 @@ class AnnotationCanvasWidget(QWidget): if len(simplified) >= 2: # Store 
polyline and redraw all annotations - self._add_polyline(simplified, self.pen_color, self.pen_width) + self._add_polyline( + simplified, self.polyline_pen_color, self.polyline_pen_width + ) # Convert to normalized coordinates for metadata + signal normalized_stroke = [ @@ -522,9 +524,9 @@ class AnnotationCanvasWidget(QWidget): self.all_strokes.append( { "points": normalized_stroke, - "color": self.pen_color.name(), - "alpha": self.pen_color.alpha(), - "width": self.pen_width, + "color": self.polyline_pen_color.name(), + "alpha": self.polyline_pen_color.alpha(), + "width": self.polyline_pen_width, } ) @@ -541,94 +543,6 @@ class AnnotationCanvasWidget(QWidget): """Get all drawn strokes with metadata.""" return self.all_strokes - # def get_annotation_bounds(self) -> Optional[Tuple[float, float, float, float]]: - # """ - # Compute bounding box that encompasses all annotation strokes. - - # Returns: - # Tuple of (x_min, y_min, x_max, y_max) in normalized coordinates (0-1), - # or None if no annotations exist. - # """ - # if not self.all_strokes: - # return None - - # # Find min/max across all strokes - # all_x = [] - # all_y = [] - - # for stroke in self.all_strokes: - # for x, y in stroke["points"]: - # all_x.append(x) - # all_y.append(y) - - # if not all_x: - # return None - - # x_min = min(all_x) - # y_min = min(all_y) - # x_max = max(all_x) - # y_max = max(all_y) - - # return (x_min, y_min, x_max, y_max) - - # def get_annotation_polyline(self) -> List[List[float]]: - # """ - # Get polyline coordinates representing all annotation strokes. - - # Returns: - # List of [x, y] coordinate pairs in normalized coordinates (0-1). 
- # """ - # polyline = [] - - # fig = plt.figure() - # ax1 = fig.add_subplot(411) - # ax2 = fig.add_subplot(412) - # ax3 = fig.add_subplot(413) - # ax4 = fig.add_subplot(414) - - # # Get np.arrays from annotation_pixmap accoriding to the color of the stroke - # qimage = self.annotation_pixmap.toImage() - # arr = np.ndarray( - # (qimage.height(), qimage.width(), 4), - # buffer=qimage.constBits(), - # strides=[qimage.bytesPerLine(), 4, 1], - # dtype=np.uint8, - # ) - # print(arr.shape, arr.dtype, arr.min(), arr.max()) - # arr = np.sum(arr, axis=2) - # ax1.imshow(arr) - - # arr_bin = arr > 0 - # ax2.imshow(arr_bin) - - # arr_bin = binary_fill_holes(arr_bin) - # ax3.imshow(arr_bin) - - # labels, _number_of_features = label( - # arr_bin, - # ) - - # ax4.imshow(labels) - - # objects = find_objects(labels) - # bounding_boxes = np.array( - # [[obj[0].start, obj[0].stop, obj[1].start, obj[1].stop] for obj in objects] - # ) / np.array([arr.shape[0], arr.shape[1]]) - - # print(objects) - # print(bounding_boxes) - # print(np.array([arr.shape[0], arr.shape[1]])) - - # polylines = find_contours(arr_bin, 0.5) - # for pl in polylines: - # ax1.plot(pl[:, 1], pl[:, 0], "k") - - # print(arr.shape, arr.dtype, arr.min(), arr.max()) - - # plt.show() - - # return polyline - def get_annotation_parameters(self) -> Optional[List[Dict[str, Any]]]: """ Get all annotation parameters including bounding box and polyline. diff --git a/src/gui/widgets/annotation_tools_widget.py b/src/gui/widgets/annotation_tools_widget.py index 89e4340..ddbf690 100644 --- a/src/gui/widgets/annotation_tools_widget.py +++ b/src/gui/widgets/annotation_tools_widget.py @@ -1,6 +1,6 @@ """ Annotation tools widget for controlling annotation parameters. -Includes pen tool, color picker, class selection, and annotation management. +Includes polyline tool, color picker, class selection, and annotation management. 
""" from PySide6.QtWidgets import ( @@ -33,29 +33,29 @@ class AnnotationToolsWidget(QWidget): Widget for annotation tool controls. Features: - - Enable/disable pen tool - - Color selection for pen + - Enable/disable polyline tool + - Color selection for polyline pen - Object class selection - Add new object classes - Pen width control - Clear annotations Signals: - pen_enabled_changed: Emitted when pen tool is enabled/disabled (bool) - pen_color_changed: Emitted when pen color changes (QColor) - pen_width_changed: Emitted when pen width changes (int) + polyline_enabled_changed: Emitted when polyline tool is enabled/disabled (bool) + polyline_pen_color_changed: Emitted when polyline pen color changes (QColor) + polyline_pen_width_changed: Emitted when polyline pen width changes (int) class_selected: Emitted when object class is selected (dict) clear_annotations_requested: Emitted when clear button is pressed """ - pen_enabled_changed = Signal(bool) - pen_color_changed = Signal(QColor) - pen_width_changed = Signal(int) + polyline_enabled_changed = Signal(bool) + polyline_pen_color_changed = Signal(QColor) + polyline_pen_width_changed = Signal(int) simplify_on_finish_changed = Signal(bool) simplify_epsilon_changed = Signal(float) class_selected = Signal(dict) + class_color_changed = Signal() clear_annotations_requested = Signal() - process_annotations_requested = Signal() def __init__(self, db_manager: DatabaseManager, parent=None): """ @@ -67,7 +67,7 @@ class AnnotationToolsWidget(QWidget): """ super().__init__(parent) self.db_manager = db_manager - self.pen_enabled = False + self.polyline_enabled = False self.current_color = QColor(255, 0, 0, 128) # Red with 50% alpha self.current_class = None @@ -78,40 +78,31 @@ class AnnotationToolsWidget(QWidget): """Setup user interface.""" layout = QVBoxLayout() - # Pen Tool Group - pen_group = QGroupBox("Pen Tool") - pen_layout = QVBoxLayout() + # Polyline Tool Group + polyline_group = QGroupBox("Polyline Tool") + 
polyline_layout = QVBoxLayout() - # Enable/Disable pen + # Enable/Disable polyline tool button_layout = QHBoxLayout() - self.pen_toggle_btn = QPushButton("Enable Pen") - self.pen_toggle_btn.setCheckable(True) - self.pen_toggle_btn.clicked.connect(self._on_pen_toggle) - button_layout.addWidget(self.pen_toggle_btn) - pen_layout.addLayout(button_layout) + self.polyline_toggle_btn = QPushButton("Start Drawing Polyline") + self.polyline_toggle_btn.setCheckable(True) + self.polyline_toggle_btn.clicked.connect(self._on_polyline_toggle) + button_layout.addWidget(self.polyline_toggle_btn) + polyline_layout.addLayout(button_layout) - # Pen width control + # Polyline pen width control width_layout = QHBoxLayout() width_layout.addWidget(QLabel("Pen Width:")) - self.pen_width_spin = QSpinBox() - self.pen_width_spin.setMinimum(1) - self.pen_width_spin.setMaximum(20) - self.pen_width_spin.setValue(3) - self.pen_width_spin.valueChanged.connect(self._on_pen_width_changed) - width_layout.addWidget(self.pen_width_spin) + self.polyline_pen_width_spin = QSpinBox() + self.polyline_pen_width_spin.setMinimum(1) + self.polyline_pen_width_spin.setMaximum(20) + self.polyline_pen_width_spin.setValue(3) + self.polyline_pen_width_spin.valueChanged.connect( + self._on_polyline_pen_width_changed + ) + width_layout.addWidget(self.polyline_pen_width_spin) width_layout.addStretch() - pen_layout.addLayout(width_layout) - - # Color selection - color_layout = QHBoxLayout() - color_layout.addWidget(QLabel("Color:")) - self.color_btn = QPushButton() - self.color_btn.setFixedSize(40, 30) - self.color_btn.clicked.connect(self._on_color_picker) - self._update_color_button() - color_layout.addWidget(self.color_btn) - color_layout.addStretch() - pen_layout.addLayout(color_layout) + polyline_layout.addLayout(width_layout) # Simplification controls (RDP) simplify_layout = QHBoxLayout() @@ -128,10 +119,10 @@ class AnnotationToolsWidget(QWidget): self.eps_spin.valueChanged.connect(self._on_eps_change) 
simplify_layout.addWidget(self.eps_spin) simplify_layout.addStretch() - pen_layout.addLayout(simplify_layout) + polyline_layout.addLayout(simplify_layout) - pen_group.setLayout(pen_layout) - layout.addWidget(pen_group) + polyline_group.setLayout(polyline_layout) + layout.addWidget(polyline_group) # Object Class Group class_group = QGroupBox("Object Class") @@ -142,7 +133,7 @@ class AnnotationToolsWidget(QWidget): self.class_combo.currentIndexChanged.connect(self._on_class_selected) class_layout.addWidget(self.class_combo) - # Add class button + # Add / manage classes class_button_layout = QHBoxLayout() self.add_class_btn = QPushButton("Add New Class") self.add_class_btn.clicked.connect(self._on_add_class) @@ -153,6 +144,17 @@ class AnnotationToolsWidget(QWidget): class_button_layout.addWidget(self.refresh_classes_btn) class_layout.addLayout(class_button_layout) + # Class color (associated with selected object class) + color_layout = QHBoxLayout() + color_layout.addWidget(QLabel("Class Color:")) + self.color_btn = QPushButton() + self.color_btn.setFixedSize(40, 30) + self.color_btn.clicked.connect(self._on_color_picker) + self._update_color_button() + color_layout.addWidget(self.color_btn) + color_layout.addStretch() + class_layout.addLayout(color_layout) + # Selected class info self.class_info_label = QLabel("No class selected") self.class_info_label.setWordWrap(True) @@ -168,13 +170,6 @@ class AnnotationToolsWidget(QWidget): actions_group = QGroupBox("Actions") actions_layout = QVBoxLayout() - self.process_btn = QPushButton("Process Annotations") - self.process_btn.clicked.connect(self._on_process_annotations) - self.process_btn.setStyleSheet( - "QPushButton { background-color: #2196F3; color: white; font-weight: bold; }" - ) - actions_layout.addWidget(self.process_btn) - self.clear_btn = QPushButton("Clear All Annotations") self.clear_btn.clicked.connect(self._on_clear_annotations) actions_layout.addWidget(self.clear_btn) @@ -206,7 +201,7 @@ class 
AnnotationToolsWidget(QWidget): # Clear and repopulate combo box self.class_combo.clear() - self.class_combo.addItem("-- Select Class --", None) + self.class_combo.addItem("-- Select Class / Show All --", None) for cls in classes: self.class_combo.addItem(cls["class_name"], cls) @@ -219,26 +214,26 @@ class AnnotationToolsWidget(QWidget): self, "Error", f"Failed to load object classes:\n{str(e)}" ) - def _on_pen_toggle(self, checked: bool): - """Handle pen tool enable/disable.""" - self.pen_enabled = checked + def _on_polyline_toggle(self, checked: bool): + """Handle polyline tool enable/disable.""" + self.polyline_enabled = checked if checked: - self.pen_toggle_btn.setText("Disable Pen") - self.pen_toggle_btn.setStyleSheet( + self.polyline_toggle_btn.setText("Start Drawing Polyline") + self.polyline_toggle_btn.setStyleSheet( "QPushButton { background-color: #4CAF50; }" ) else: - self.pen_toggle_btn.setText("Enable Pen") - self.pen_toggle_btn.setStyleSheet("") + self.polyline_toggle_btn.setText("Stop drawing Polyline") + self.polyline_toggle_btn.setStyleSheet("") - self.pen_enabled_changed.emit(self.pen_enabled) - logger.debug(f"Pen tool {'enabled' if checked else 'disabled'}") + self.polyline_enabled_changed.emit(self.polyline_enabled) + logger.debug(f"Polyline tool {'enabled' if checked else 'disabled'}") - def _on_pen_width_changed(self, width: int): - """Handle pen width changes.""" - self.pen_width_changed.emit(width) - logger.debug(f"Pen width changed to {width}") + def _on_polyline_pen_width_changed(self, width: int): + """Handle polyline pen width changes.""" + self.polyline_pen_width_changed.emit(width) + logger.debug(f"Polyline pen width changed to {width}") def _on_simplify_toggle(self, state: int): """Handle simplify-on-finish checkbox toggle.""" @@ -253,24 +248,75 @@ class AnnotationToolsWidget(QWidget): logger.debug(f"Simplification epsilon changed to {epsilon}") def _on_color_picker(self): - """Open color picker dialog with alpha support.""" + """Open 
color picker dialog and update the selected object's class color.""" + if not self.current_class: + QMessageBox.warning( + self, + "No Class Selected", + "Please select an object class before changing its color.", + ) + return + + # Use current class color (without alpha) as the base + base_color = QColor(self.current_class.get("color", self.current_color.name())) color = QColorDialog.getColor( - self.current_color, + base_color, self, - "Select Pen Color", - QColorDialog.ShowAlphaChannel, # Enable alpha channel selection + "Select Class Color", + QColorDialog.ShowAlphaChannel, # Allow alpha in UI, but store RGB in DB ) - if color.isValid(): - self.current_color = color - self._update_color_button() - self.pen_color_changed.emit(color) - logger.debug( - f"Pen color changed to {color.name()} with alpha {color.alpha()}" + if not color.isValid(): + return + + # Normalize to opaque RGB for storage + new_color = QColor(color) + new_color.setAlpha(255) + hex_color = new_color.name() + + try: + # Update in database + self.db_manager.update_object_class( + class_id=self.current_class["id"], color=hex_color ) + except Exception as e: + logger.error(f"Failed to update class color in database: {e}") + QMessageBox.critical( + self, + "Error", + f"Failed to update class color in database:\n{str(e)}", + ) + return + + # Update local class data and combo box item data + self.current_class["color"] = hex_color + current_index = self.class_combo.currentIndex() + if current_index >= 0: + self.class_combo.setItemData(current_index, dict(self.current_class)) + + # Update info label text + info_text = f"Class: {self.current_class['class_name']}\nColor: {hex_color}" + if self.current_class.get("description"): + info_text += f"\nDescription: {self.current_class['description']}" + self.class_info_label.setText(info_text) + + # Use semi-transparent version for polyline pen / button preview + class_color = QColor(hex_color) + class_color.setAlpha(128) + self.current_color = class_color + 
self._update_color_button() + self.polyline_pen_color_changed.emit(class_color) + + logger.debug( + f"Updated class '{self.current_class['class_name']}' color to " + f"{hex_color} (polyline pen alpha={class_color.alpha()})" + ) + + # Notify listeners (e.g., AnnotationTab) so they can reload/redraw + self.class_color_changed.emit() def _on_class_selected(self, index: int): - """Handle object class selection.""" + """Handle object class selection (including '-- Select Class --').""" class_data = self.class_combo.currentData() if class_data: @@ -285,20 +331,23 @@ class AnnotationToolsWidget(QWidget): self.class_info_label.setText(info_text) - # Update pen color to match class color with semi-transparency + # Update polyline pen color to match class color with semi-transparency class_color = QColor(class_data["color"]) if class_color.isValid(): # Add 50% alpha for semi-transparency class_color.setAlpha(128) self.current_color = class_color self._update_color_button() - self.pen_color_changed.emit(class_color) + self.polyline_pen_color_changed.emit(class_color) self.class_selected.emit(class_data) logger.debug(f"Selected class: {class_data['class_name']}") else: + # "-- Select Class --" chosen: clear current class and show all annotations self.current_class = None self.class_info_label.setText("No class selected") + self.class_selected.emit(None) + logger.debug("Class selection cleared: showing annotations for all classes") def _on_add_class(self): """Handle adding a new object class.""" @@ -376,31 +425,18 @@ class AnnotationToolsWidget(QWidget): self.clear_annotations_requested.emit() logger.debug("Clear annotations requested") - def _on_process_annotations(self): - """Handle process annotations button.""" - if not self.current_class: - QMessageBox.warning( - self, - "No Class Selected", - "Please select an object class before processing annotations.", - ) - return - - self.process_annotations_requested.emit() - logger.debug("Process annotations requested") - def 
get_current_class(self) -> Optional[Dict]: """Get currently selected object class.""" return self.current_class - def get_pen_color(self) -> QColor: - """Get current pen color.""" + def get_polyline_pen_color(self) -> QColor: + """Get current polyline pen color.""" return self.current_color - def get_pen_width(self) -> int: - """Get current pen width.""" - return self.pen_width_spin.value() + def get_polyline_pen_width(self) -> int: + """Get current polyline pen width.""" + return self.polyline_pen_width_spin.value() - def is_pen_enabled(self) -> bool: - """Check if pen tool is enabled.""" - return self.pen_enabled + def is_polyline_enabled(self) -> bool: + """Check if polyline tool is enabled.""" + return self.polyline_enabled -- 2.49.1 From 35e2398e95c9844c39d368d59fecf41814cf52c5 Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Tue, 9 Dec 2025 23:56:29 +0200 Subject: [PATCH 12/13] Fixing bounding box drawing --- src/gui/tabs/annotation_tab.py | 4 ++ src/gui/widgets/annotation_canvas_widget.py | 79 ++++++++++++++++----- src/gui/widgets/annotation_tools_widget.py | 18 ++++- 3 files changed, 82 insertions(+), 19 deletions(-) diff --git a/src/gui/tabs/annotation_tab.py b/src/gui/tabs/annotation_tab.py index 62b3d22..d517f56 100644 --- a/src/gui/tabs/annotation_tab.py +++ b/src/gui/tabs/annotation_tab.py @@ -90,6 +90,10 @@ class AnnotationTab(QWidget): self.annotation_tools.polyline_pen_width_changed.connect( self.annotation_canvas.set_polyline_pen_width ) + # Show / hide bounding boxes + self.annotation_tools.show_bboxes_changed.connect( + self.annotation_canvas.set_show_bboxes + ) # RDP simplification controls self.annotation_tools.simplify_on_finish_changed.connect( self._on_simplify_on_finish_changed diff --git a/src/gui/widgets/annotation_canvas_widget.py b/src/gui/widgets/annotation_canvas_widget.py index 4a8be4e..acacc72 100644 --- a/src/gui/widgets/annotation_canvas_widget.py +++ b/src/gui/widgets/annotation_canvas_widget.py @@ -146,12 +146,17 @@ class 
AnnotationCanvasWidget(QWidget): self.polyline_enabled = False self.polyline_pen_color = QColor(255, 0, 0, 128) # Default red with 50% alpha self.polyline_pen_width = 3 + self.show_bboxes: bool = True # Control visibility of bounding boxes # Current stroke and stored polylines (in image coordinates, pixel units) self.current_stroke: List[Tuple[float, float]] = [] self.polylines: List[List[Tuple[float, float]]] = [] self.stroke_meta: List[Dict[str, Any]] = [] # per-polyline style (color, width) + # Stored bounding boxes in normalized coordinates (x_min, y_min, x_max, y_max) + self.bboxes: List[List[float]] = [] + self.bbox_meta: List[Dict[str, Any]] = [] # per-bbox style (color, width) + # Legacy collection of strokes in normalized coordinates (kept for API compatibility) self.all_strokes: List[dict] = [] @@ -219,6 +224,8 @@ class AnnotationCanvasWidget(QWidget): self.current_stroke = [] self.polylines = [] self.stroke_meta = [] + self.bboxes = [] + self.bbox_meta = [] self.is_drawing = False if self.annotation_pixmap: self.annotation_pixmap.fill(Qt.transparent) @@ -406,7 +413,7 @@ class AnnotationCanvasWidget(QWidget): self._redraw_annotations() def _redraw_annotations(self): - """Redraw all stored polylines onto the annotation pixmap.""" + """Redraw all stored polylines and (optionally) bounding boxes onto the annotation pixmap.""" if self.annotation_pixmap is None: return @@ -414,6 +421,8 @@ class AnnotationCanvasWidget(QWidget): self.annotation_pixmap.fill(Qt.transparent) painter = QPainter(self.annotation_pixmap) + + # Draw polylines for polyline, meta in zip(self.polylines, self.stroke_meta): pen_color: QColor = meta.get("color", self.polyline_pen_color) width: int = meta.get("width", self.polyline_pen_width) @@ -427,6 +436,37 @@ class AnnotationCanvasWidget(QWidget): painter.setPen(pen) for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]): painter.drawLine(int(x1), int(y1), int(x2), int(y2)) + + # Draw bounding boxes (dashed) if enabled + if 
self.show_bboxes and self.original_pixmap is not None and self.bboxes: + img_width = float(self.original_pixmap.width()) + img_height = float(self.original_pixmap.height()) + + for bbox, meta in zip(self.bboxes, self.bbox_meta): + if len(bbox) != 4: + continue + + x_min_norm, y_min_norm, x_max_norm, y_max_norm = bbox + x_min = int(x_min_norm * img_width) + y_min = int(y_min_norm * img_height) + x_max = int(x_max_norm * img_width) + y_max = int(y_max_norm * img_height) + + rect_width = x_max - x_min + rect_height = y_max - y_min + + pen_color: QColor = meta.get("color", QColor(255, 0, 0, 128)) + width: int = meta.get("width", self.polyline_pen_width) + pen = QPen( + pen_color, + width, + Qt.DashLine, + Qt.SquareCap, + Qt.MiterJoin, + ) + painter.setPen(pen) + painter.drawRect(x_min, y_min, rect_width, rect_height) + painter.end() self._update_display() @@ -647,7 +687,7 @@ class AnnotationCanvasWidget(QWidget): Draw a bounding box from database coordinates onto the annotation canvas. Args: - bbox: Bounding box as [y_min_norm, x_min_norm, y_max_norm, x_max_norm] + bbox: Bounding box as [x_min_norm, y_min_norm, x_max_norm, y_max_norm] in normalized coordinates (0-1) color: Color hex string (e.g., '#FF0000') width: Line width in pixels @@ -662,8 +702,7 @@ class AnnotationCanvasWidget(QWidget): ) return - # Convert normalized coordinates to image coordinates - # bbox format: [y_min_norm, x_min_norm, y_max_norm, x_max_norm] + # Convert normalized coordinates to image coordinates (for logging/debug) img_width = self.original_pixmap.width() img_height = self.original_pixmap.height() @@ -677,29 +716,35 @@ class AnnotationCanvasWidget(QWidget): logger.debug(f" Image size: {img_width}x{img_height}") logger.debug(f" Pixel coords: ({x_min}, {y_min}) to ({x_max}, {y_max})") - # Draw bounding box on annotation pixmap - painter = QPainter(self.annotation_pixmap) + # Store bounding box (normalized) and its style; actual drawing happens + # in _redraw_annotations() together with all 
polylines. pen_color = QColor(color) pen_color.setAlpha(128) # Add semi-transparency - pen = QPen(pen_color, width, Qt.SolidLine, Qt.SquareCap, Qt.MiterJoin) - painter.setPen(pen) - - # Draw rectangle - rect_width = x_max - x_min - rect_height = y_max - y_min - painter.drawRect(x_min, y_min, rect_width, rect_height) - - painter.end() + self.bboxes.append( + [float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)] + ) + self.bbox_meta.append({"color": pen_color, "width": int(width)}) # Store in all_strokes for consistency self.all_strokes.append( {"bbox": bbox, "color": color, "alpha": 128, "width": width} ) - # Update display - self._update_display() + # Redraw overlay (polylines + all bounding boxes) + self._redraw_annotations() logger.debug(f"Drew saved bounding box in color {color}") + def set_show_bboxes(self, show: bool): + """ + Enable or disable drawing of bounding boxes. + + Args: + show: If True, draw bounding boxes; if False, hide them. + """ + self.show_bboxes = bool(show) + logger.debug(f"Set show_bboxes to {self.show_bboxes}") + self._redraw_annotations() + def keyPressEvent(self, event: QKeyEvent): """Handle keyboard events for zooming.""" if event.key() in (Qt.Key_Plus, Qt.Key_Equal): diff --git a/src/gui/widgets/annotation_tools_widget.py b/src/gui/widgets/annotation_tools_widget.py index ddbf690..70312bc 100644 --- a/src/gui/widgets/annotation_tools_widget.py +++ b/src/gui/widgets/annotation_tools_widget.py @@ -53,6 +53,8 @@ class AnnotationToolsWidget(QWidget): polyline_pen_width_changed = Signal(int) simplify_on_finish_changed = Signal(bool) simplify_epsilon_changed = Signal(float) + # Toggle visibility of bounding boxes on the canvas + show_bboxes_changed = Signal(bool) class_selected = Signal(dict) class_color_changed = Signal() clear_annotations_requested = Signal() @@ -170,6 +172,12 @@ class AnnotationToolsWidget(QWidget): actions_group = QGroupBox("Actions") actions_layout = QVBoxLayout() + # Show / hide bounding boxes + 
self.show_bboxes_checkbox = QCheckBox("Show bounding boxes") + self.show_bboxes_checkbox.setChecked(True) + self.show_bboxes_checkbox.stateChanged.connect(self._on_show_bboxes_toggle) + actions_layout.addWidget(self.show_bboxes_checkbox) + self.clear_btn = QPushButton("Clear All Annotations") self.clear_btn.clicked.connect(self._on_clear_annotations) actions_layout.addWidget(self.clear_btn) @@ -219,12 +227,12 @@ class AnnotationToolsWidget(QWidget): self.polyline_enabled = checked if checked: - self.polyline_toggle_btn.setText("Start Drawing Polyline") + self.polyline_toggle_btn.setText("Stop Drawing Polyline") self.polyline_toggle_btn.setStyleSheet( "QPushButton { background-color: #4CAF50; }" ) else: - self.polyline_toggle_btn.setText("Stop drawing Polyline") + self.polyline_toggle_btn.setText("Start Drawing Polyline") self.polyline_toggle_btn.setStyleSheet("") self.polyline_enabled_changed.emit(self.polyline_enabled) @@ -247,6 +255,12 @@ class AnnotationToolsWidget(QWidget): self.simplify_epsilon_changed.emit(epsilon) logger.debug(f"Simplification epsilon changed to {epsilon}") + def _on_show_bboxes_toggle(self, state: int): + """Handle 'Show bounding boxes' checkbox toggle.""" + show = bool(state) + self.show_bboxes_changed.emit(show) + logger.debug(f"Show bounding boxes set to {show}") + def _on_color_picker(self): """Open color picker dialog and update the selected object's class color.""" if not self.current_class: -- 2.49.1 From e6a5e74fa15bb5d11e2c14359d78d23733183e37 Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Wed, 10 Dec 2025 00:19:59 +0200 Subject: [PATCH 13/13] Adding feature to remove annotations --- src/database/db_manager.py | 19 +++ src/gui/tabs/annotation_tab.py | 116 +++++++++++++++- src/gui/widgets/annotation_canvas_widget.py | 141 +++++++++++++++++--- src/gui/widgets/annotation_tools_widget.py | 22 +++ 4 files changed, 278 insertions(+), 20 deletions(-) diff --git a/src/database/db_manager.py b/src/database/db_manager.py index 
db2da9e..53d5695 100644 --- a/src/database/db_manager.py +++ b/src/database/db_manager.py @@ -706,6 +706,25 @@ class DatabaseManager: finally: conn.close() + def delete_annotation(self, annotation_id: int) -> bool: + """ + Delete a manual annotation by ID. + + Args: + annotation_id: ID of the annotation to delete + + Returns: + True if an annotation was deleted, False otherwise. + """ + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute("DELETE FROM annotations WHERE id = ?", (annotation_id,)) + conn.commit() + return cursor.rowcount > 0 + finally: + conn.close() + # ==================== Object Class Operations ==================== def get_object_classes(self) -> List[Dict]: diff --git a/src/gui/tabs/annotation_tab.py b/src/gui/tabs/annotation_tab.py index d517f56..6927aba 100644 --- a/src/gui/tabs/annotation_tab.py +++ b/src/gui/tabs/annotation_tab.py @@ -39,6 +39,8 @@ class AnnotationTab(QWidget): self.current_image_path = None self.current_image_id = None self.current_annotations = [] + # IDs of annotations currently selected on the canvas (multi-select) + self.selected_annotation_ids = [] self._setup_ui() @@ -62,6 +64,8 @@ class AnnotationTab(QWidget): self.annotation_canvas = AnnotationCanvasWidget() self.annotation_canvas.zoom_changed.connect(self._on_zoom_changed) self.annotation_canvas.annotation_drawn.connect(self._on_annotation_drawn) + # Selection of existing polylines (when tool is not in drawing mode) + self.annotation_canvas.annotation_selected.connect(self._on_annotation_selected) canvas_layout.addWidget(self.annotation_canvas) canvas_group.setLayout(canvas_layout) @@ -107,6 +111,10 @@ class AnnotationTab(QWidget): self.annotation_tools.clear_annotations_requested.connect( self._on_clear_annotations ) + # Delete selected annotation on canvas + self.annotation_tools.delete_selected_annotation_requested.connect( + self._on_delete_selected_annotation + ) self.right_splitter.addWidget(self.annotation_tools) # Image loading 
section @@ -292,6 +300,25 @@ class AnnotationTab(QWidget): logger.error(f"Failed to save annotation: {e}") QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}") + def _on_annotation_selected(self, annotation_ids): + """ + Handle selection of existing annotations on the canvas. + + Args: + annotation_ids: List of selected annotation IDs, or None/empty if cleared. + """ + if not annotation_ids: + self.selected_annotation_ids = [] + self.annotation_tools.set_has_selected_annotation(False) + logger.debug("Annotation selection cleared on canvas") + return + + # Normalize to a unique, sorted list of integer IDs + ids = sorted({int(aid) for aid in annotation_ids if isinstance(aid, int)}) + self.selected_annotation_ids = ids + self.annotation_tools.set_has_selected_annotation(bool(ids)) + logger.debug(f"Annotations selected on canvas: IDs={ids}") + def _on_simplify_on_finish_changed(self, enabled: bool): """Update canvas simplify-on-finish flag from tools widget.""" self.annotation_canvas.simplify_on_finish = enabled @@ -323,7 +350,7 @@ class AnnotationTab(QWidget): Handle when an object class is selected or cleared. When a specific class is selected, only annotations of that class are drawn. - When the selection is cleared (\"-- Select Class --\"), all annotations are shown. + When the selection is cleared ("-- Select Class --"), all annotations are shown. 
""" if class_data: logger.debug(f"Object class selected: {class_data['class_name']}") @@ -332,14 +359,89 @@ class AnnotationTab(QWidget): 'No class selected ("-- Select Class --"), showing all annotations' ) + # Changing the class filter invalidates any previous selection + self.selected_annotation_ids = [] + self.annotation_tools.set_has_selected_annotation(False) + # Whenever the selection changes, update which annotations are visible self._redraw_annotations_for_current_filter() def _on_clear_annotations(self): """Handle clearing all annotations.""" self.annotation_canvas.clear_annotations() + # Clear in-memory state and selection, but keep DB entries unchanged + self.current_annotations = [] + self.selected_annotation_ids = [] + self.annotation_tools.set_has_selected_annotation(False) logger.info("Cleared all annotations") + def _on_delete_selected_annotation(self): + """Handle deleting the currently selected annotation(s) (if any).""" + if not self.selected_annotation_ids: + QMessageBox.information( + self, + "No Selection", + "No annotation is currently selected.", + ) + return + + count = len(self.selected_annotation_ids) + if count == 1: + question = "Are you sure you want to delete the selected annotation?" + title = "Delete Annotation" + else: + question = ( + f"Are you sure you want to delete the {count} selected annotations?" 
+ ) + title = "Delete Annotations" + + reply = QMessageBox.question( + self, + title, + question, + QMessageBox.Yes | QMessageBox.No, + QMessageBox.No, + ) + if reply != QMessageBox.Yes: + return + + failed_ids = [] + try: + for ann_id in self.selected_annotation_ids: + try: + deleted = self.db_manager.delete_annotation(ann_id) + if not deleted: + failed_ids.append(ann_id) + except Exception as e: + logger.error(f"Failed to delete annotation ID {ann_id}: {e}") + failed_ids.append(ann_id) + + if failed_ids: + QMessageBox.warning( + self, + "Partial Failure", + "Some annotations could not be deleted:\n" + + ", ".join(str(a) for a in failed_ids), + ) + else: + logger.info( + f"Deleted {count} annotation(s): " + + ", ".join(str(a) for a in self.selected_annotation_ids) + ) + + # Clear selection and reload annotations for the current image from DB + self.selected_annotation_ids = [] + self.annotation_tools.set_has_selected_annotation(False) + self._load_annotations_for_current_image() + + except Exception as e: + logger.error(f"Failed to delete annotations: {e}") + QMessageBox.critical( + self, + "Error", + f"Failed to delete annotations:\n{str(e)}", + ) + def _load_annotations_for_current_image(self): """ Load all annotations for the current image from the database and @@ -349,12 +451,17 @@ class AnnotationTab(QWidget): if not self.current_image_id: self.current_annotations = [] self.annotation_canvas.clear_annotations() + self.selected_annotation_ids = [] + self.annotation_tools.set_has_selected_annotation(False) return try: self.current_annotations = self.db_manager.get_annotations_for_image( self.current_image_id ) + # New annotations loaded; reset any selection + self.selected_annotation_ids = [] + self.annotation_tools.set_has_selected_annotation(False) self._redraw_annotations_for_current_filter() except Exception as e: logger.error( @@ -393,7 +500,12 @@ class AnnotationTab(QWidget): polyline = ann["segmentation_mask"] color = ann.get("class_color", "#FF0000") - 
self.annotation_canvas.draw_saved_polyline(polyline, color, width=3) + self.annotation_canvas.draw_saved_polyline( + polyline, + color, + width=3, + annotation_id=ann["id"], + ) self.annotation_canvas.draw_saved_bbox( [ann["x_min"], ann["y_min"], ann["x_max"], ann["y_max"]], color, diff --git a/src/gui/widgets/annotation_canvas_widget.py b/src/gui/widgets/annotation_canvas_widget.py index acacc72..37523e9 100644 --- a/src/gui/widgets/annotation_canvas_widget.py +++ b/src/gui/widgets/annotation_canvas_widget.py @@ -127,6 +127,9 @@ class AnnotationCanvasWidget(QWidget): zoom_changed = Signal(float) annotation_drawn = Signal(list) # List of (x, y) points in normalized coordinates + # Emitted when the user selects existing polyline(s) on the canvas. + # Carries a list of the selected annotation IDs (List[int]), or None if the selection is cleared + annotation_selected = Signal(object) def __init__(self, parent=None): """Initialize the annotation canvas widget.""" @@ -141,7 +144,7 @@ class AnnotationCanvasWidget(QWidget): self.zoom_step = 0.1 self.zoom_wheel_step = 0.15 - # Drawing state + # Drawing / interaction state self.is_drawing = False self.polyline_enabled = False self.polyline_pen_color = QColor(255, 0, 0, 128) # Default red with 50% alpha @@ -152,6 +155,10 @@ class AnnotationCanvasWidget(QWidget): self.current_stroke: List[Tuple[float, float]] = [] self.polylines: List[List[Tuple[float, float]]] = [] self.stroke_meta: List[Dict[str, Any]] = [] # per-polyline style (color, width) + # Optional DB annotation_id for each stored polyline (None for temporary / unsaved) + self.polyline_annotation_ids: List[Optional[int]] = [] + # Indices in self.polylines of the currently selected polylines (multi-select) + self.selected_polyline_indices: List[int] = [] # Stored bounding boxes in normalized coordinates (x_min, y_min, x_max, y_max) self.bboxes: List[List[float]] = [] @@ -224,6 +231,8 @@ class AnnotationCanvasWidget(QWidget): self.current_stroke = [] self.polylines = [] self.stroke_meta = 
[] + self.polyline_annotation_ids = [] + self.selected_polyline_indices = [] self.bboxes = [] self.bbox_meta = [] self.is_drawing = False @@ -389,6 +398,41 @@ class AnnotationCanvasWidget(QWidget): return (int(x), int(y)) return None + def _find_polyline_at( + self, img_x: float, img_y: float, threshold_px: float = 5.0 + ) -> Optional[int]: + """ + Find index of polyline whose geometry is within threshold_px of (img_x, img_y). + Returns the index in self.polylines, or None if none is close enough. + """ + best_index: Optional[int] = None + best_dist: float = float("inf") + + for idx, polyline in enumerate(self.polylines): + if len(polyline) < 2: + continue + + # Quick bounding-box check to skip obviously distant polylines + xs = [p[0] for p in polyline] + ys = [p[1] for p in polyline] + if img_x < min(xs) - threshold_px or img_x > max(xs) + threshold_px: + continue + if img_y < min(ys) - threshold_px or img_y > max(ys) + threshold_px: + continue + + # Precise distance to all segments + for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]): + d = perpendicular_distance( + (img_x, img_y), (float(x1), float(y1)), (float(x2), float(y2)) + ) + if d < best_dist: + best_dist = d + best_index = idx + + if best_index is not None and best_dist <= threshold_px: + return best_index + return None + def _image_to_normalized_coords(self, x: int, y: int) -> Tuple[float, float]: """Convert image coordinates to normalized coordinates (0-1).""" if self.original_pixmap is None: @@ -399,7 +443,11 @@ class AnnotationCanvasWidget(QWidget): return (norm_x, norm_y) def _add_polyline( - self, img_points: List[Tuple[float, float]], color: QColor, width: int + self, + img_points: List[Tuple[float, float]], + color: QColor, + width: int, + annotation_id: Optional[int] = None, ): """Store a polyline in image coordinates and redraw annotations.""" if not img_points or len(img_points) < 2: @@ -409,6 +457,7 @@ class AnnotationCanvasWidget(QWidget): normalized_points = [(float(x), float(y)) 
for x, y in img_points] self.polylines.append(normalized_points) self.stroke_meta.append({"color": QColor(color), "width": int(width)}) + self.polyline_annotation_ids.append(annotation_id) self._redraw_annotations() @@ -423,16 +472,29 @@ class AnnotationCanvasWidget(QWidget): painter = QPainter(self.annotation_pixmap) # Draw polylines - for polyline, meta in zip(self.polylines, self.stroke_meta): + for idx, (polyline, meta) in enumerate(zip(self.polylines, self.stroke_meta)): pen_color: QColor = meta.get("color", self.polyline_pen_color) width: int = meta.get("width", self.polyline_pen_width) - pen = QPen( - pen_color, - width, - Qt.SolidLine, - Qt.RoundCap, - Qt.RoundJoin, - ) + + if idx in self.selected_polyline_indices: + # Highlight selected polylines in a distinct color / width + highlight_color = QColor(255, 255, 0, 200) # yellow, semi-opaque + pen = QPen( + highlight_color, + width + 1, + Qt.SolidLine, + Qt.RoundCap, + Qt.RoundJoin, + ) + else: + pen = QPen( + pen_color, + width, + Qt.SolidLine, + Qt.RoundCap, + Qt.RoundJoin, + ) + painter.setPen(pen) for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]): painter.drawLine(int(x1), int(y1), int(x2), int(y2)) @@ -472,19 +534,58 @@ class AnnotationCanvasWidget(QWidget): self._update_display() def mousePressEvent(self, event: QMouseEvent): - """Handle mouse press events for drawing.""" - if not self.polyline_enabled or self.annotation_pixmap is None: + """Handle mouse press events for drawing and selecting polylines.""" + if self.annotation_pixmap is None: super().mousePressEvent(event) return - if event.button() == Qt.LeftButton: - # Get accurate position using global coordinates - label_pos = self.canvas_label.mapFromGlobal(event.globalPos()) - img_coords = self._canvas_to_image_coords(label_pos) + # Map click to image coordinates + label_pos = self.canvas_label.mapFromGlobal(event.globalPos()) + img_coords = self._canvas_to_image_coords(label_pos) + # Left button + drawing tool enabled -> start a new 
stroke + if event.button() == Qt.LeftButton and self.polyline_enabled: if img_coords: self.is_drawing = True self.current_stroke = [(float(img_coords[0]), float(img_coords[1]))] + return + + # Left button + drawing tool disabled -> attempt selection of existing polyline + if event.button() == Qt.LeftButton and not self.polyline_enabled: + if img_coords: + idx = self._find_polyline_at(float(img_coords[0]), float(img_coords[1])) + if idx is not None: + if event.modifiers() & Qt.ShiftModifier: + # Multi-select mode: add to current selection (if not already selected) + if idx not in self.selected_polyline_indices: + self.selected_polyline_indices.append(idx) + else: + # Single-select mode: replace current selection + self.selected_polyline_indices = [idx] + + # Build list of selected annotation IDs (ignore None entries) + selected_ids: List[int] = [] + for sel_idx in self.selected_polyline_indices: + if 0 <= sel_idx < len(self.polyline_annotation_ids): + ann_id = self.polyline_annotation_ids[sel_idx] + if isinstance(ann_id, int): + selected_ids.append(ann_id) + + if selected_ids: + self.annotation_selected.emit(selected_ids) + else: + # No valid DB-backed annotations in selection + self.annotation_selected.emit(None) + else: + # Clicked on empty space -> clear selection + self.selected_polyline_indices = [] + self.annotation_selected.emit(None) + + self._redraw_annotations() + return + + # Fallback for other buttons / cases + super().mousePressEvent(event) def mouseMoveEvent(self, event: QMouseEvent): """Handle mouse move events for drawing.""" @@ -633,7 +734,11 @@ class AnnotationCanvasWidget(QWidget): return params or None def draw_saved_polyline( - self, polyline: List[List[float]], color: str, width: int = 3 + self, + polyline: List[List[float]], + color: str, + width: int = 3, + annotation_id: Optional[int] = None, ): """ Draw a polyline from database coordinates onto the annotation canvas. 
@@ -671,7 +776,7 @@ class AnnotationCanvasWidget(QWidget): # Store and redraw using common pipeline pen_color = QColor(color) pen_color.setAlpha(128) # Add semi-transparency - self._add_polyline(img_coords, pen_color, width) + self._add_polyline(img_coords, pen_color, width, annotation_id=annotation_id) # Store in all_strokes for consistency (uses normalized coordinates) self.all_strokes.append( diff --git a/src/gui/widgets/annotation_tools_widget.py b/src/gui/widgets/annotation_tools_widget.py index 70312bc..e0f68b0 100644 --- a/src/gui/widgets/annotation_tools_widget.py +++ b/src/gui/widgets/annotation_tools_widget.py @@ -58,6 +58,8 @@ class AnnotationToolsWidget(QWidget): class_selected = Signal(dict) class_color_changed = Signal() clear_annotations_requested = Signal() + # Request deletion of the currently selected annotation on the canvas + delete_selected_annotation_requested = Signal() def __init__(self, db_manager: DatabaseManager, parent=None): """ @@ -182,6 +184,12 @@ class AnnotationToolsWidget(QWidget): self.clear_btn.clicked.connect(self._on_clear_annotations) actions_layout.addWidget(self.clear_btn) + # Delete currently selected annotation (enabled when a selection exists) + self.delete_selected_btn = QPushButton("Delete Selected Annotation") + self.delete_selected_btn.clicked.connect(self._on_delete_selected_annotation) + self.delete_selected_btn.setEnabled(False) + actions_layout.addWidget(self.delete_selected_btn) + actions_group.setLayout(actions_layout) layout.addWidget(actions_group) @@ -439,6 +447,20 @@ class AnnotationToolsWidget(QWidget): self.clear_annotations_requested.emit() logger.debug("Clear annotations requested") + def _on_delete_selected_annotation(self): + """Handle delete selected annotation button.""" + self.delete_selected_annotation_requested.emit() + logger.debug("Delete selected annotation requested") + + def set_has_selected_annotation(self, has_selection: bool): + """ + Enable/disable actions that require a selected 
annotation. + + Args: + has_selection: True if an annotation is currently selected on the canvas. + """ + self.delete_selected_btn.setEnabled(bool(has_selection)) + def get_current_class(self) -> Optional[Dict]: """Get currently selected object class.""" return self.current_class -- 2.49.1