segmentation #1

@@ -2,11 +2,11 @@
 ## Project Overview

-A desktop application for detecting organelles and membrane branching structures in microscopy images using YOLOv8s, with comprehensive training, validation, and visualization capabilities.
+A desktop application for detecting and segmenting organelles and membrane branching structures in microscopy images using YOLOv8s-seg, with comprehensive training, validation, and visualization capabilities including pixel-accurate segmentation masks.

 ## Technology Stack

-- **ML Framework**: Ultralytics YOLOv8 (YOLOv8s.pt model)
+- **ML Framework**: Ultralytics YOLOv8 (YOLOv8s-seg.pt segmentation model)
 - **GUI Framework**: PySide6 (Qt6 for Python)
 - **Visualization**: pyqtgraph
 - **Database**: SQLite3
@@ -110,6 +110,7 @@ erDiagram
         float x_max
         float y_max
         float confidence
+        text segmentation_mask
         datetime detected_at
         json metadata
     }
@@ -122,6 +123,7 @@ erDiagram
         float y_min
         float x_max
         float y_max
+        text segmentation_mask
         string annotator
         datetime created_at
         boolean verified
@@ -139,7 +141,7 @@ Stores information about trained models and their versions.
 | model_name | TEXT | NOT NULL | User-friendly model name |
 | model_version | TEXT | NOT NULL | Version string (e.g., "v1.0") |
 | model_path | TEXT | NOT NULL | Path to model weights file |
-| base_model | TEXT | NOT NULL | Base model used (e.g., "yolov8s.pt") |
+| base_model | TEXT | NOT NULL | Base model used (e.g., "yolov8s-seg.pt") |
 | created_at | TIMESTAMP | DEFAULT CURRENT_TIMESTAMP | Model creation timestamp |
 | training_params | JSON | | Training hyperparameters |
 | metrics | JSON | | Validation metrics (mAP, precision, recall) |
@@ -159,7 +161,7 @@ Stores metadata about microscopy images.
 | checksum | TEXT | | MD5 hash for integrity verification |
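The MD5 checksum stored here can be computed with a streaming helper; this is an illustrative sketch (the helper name `file_md5` is not taken from the project's code), which avoids loading large microscopy images into memory at once:

```python
import hashlib

def file_md5(path, chunk_size=1 << 20):
    """Compute the MD5 checksum of a file, reading it in 1 MiB chunks."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()
```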
 #### **detections** table
-Stores object detection results.
+Stores object detection results with optional segmentation masks.

 | Column | Type | Constraints | Description |
 |--------|------|-------------|-------------|
@@ -172,11 +174,12 @@ Stores object detection results.
 | x_max | REAL | NOT NULL | Bounding box right coordinate (normalized 0-1) |
 | y_max | REAL | NOT NULL | Bounding box bottom coordinate (normalized 0-1) |
 | confidence | REAL | NOT NULL | Detection confidence score (0-1) |
+| segmentation_mask | TEXT | | JSON array of polygon coordinates [[x1,y1], [x2,y2], ...] (normalized 0-1) |
 | detected_at | TIMESTAMP | DEFAULT CURRENT_TIMESTAMP | When detection was performed |
 | metadata | JSON | | Additional metadata (processing time, etc.) |
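A stored polygon in the format described above can be rasterized back into a binary mask. This is a minimal sketch using Pillow (already a project dependency); the helper name `polygon_to_mask` is illustrative, not part of the codebase:

```python
import json
import numpy as np
from PIL import Image, ImageDraw  # Pillow's Image, not the project's Image class

def polygon_to_mask(mask_json, width, height):
    """Rasterize a segmentation_mask value (JSON array of normalized
    [[x, y], ...] pairs) into a binary uint8 mask of the image size."""
    points = json.loads(mask_json)
    # De-normalize to pixel coordinates
    pixels = [(x * (width - 1), y * (height - 1)) for x, y in points]
    canvas = Image.new("L", (width, height), 0)
    ImageDraw.Draw(canvas).polygon(pixels, outline=255, fill=255)
    return np.array(canvas)

# A triangle covering part of a 100x100 image
mask = polygon_to_mask("[[0.1, 0.1], [0.9, 0.1], [0.5, 0.9]]", 100, 100)
```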
 #### **annotations** table
-Stores manual annotations for training data (future feature).
+Stores manual annotations for training data with optional segmentation masks (future feature).

 | Column | Type | Constraints | Description |
 |--------|------|-------------|-------------|
@@ -187,6 +190,7 @@ Stores manual annotations for training data (future feature).
 | y_min | REAL | NOT NULL | Bounding box top coordinate (normalized) |
 | x_max | REAL | NOT NULL | Bounding box right coordinate (normalized) |
 | y_max | REAL | NOT NULL | Bounding box bottom coordinate (normalized) |
+| segmentation_mask | TEXT | | JSON array of polygon coordinates [[x1,y1], [x2,y2], ...] (normalized 0-1) |
 | annotator | TEXT | | Name of person who created annotation |
 | created_at | TIMESTAMP | DEFAULT CURRENT_TIMESTAMP | Annotation timestamp |
 | verified | BOOLEAN | DEFAULT 0 | Whether annotation is verified |
@@ -245,8 +249,9 @@ graph TB
 ### Key Components

 #### 1. **YOLO Wrapper** ([`src/model/yolo_wrapper.py`](src/model/yolo_wrapper.py))
-Encapsulates YOLOv8 operations:
-- Load pre-trained YOLOv8s model
+Encapsulates YOLOv8-seg operations:
+- Load pre-trained YOLOv8s-seg segmentation model
+- Extract pixel-accurate segmentation masks
 - Fine-tune on custom microscopy dataset
 - Export trained models
 - Provide training progress callbacks

@@ -255,10 +260,10 @@ Encapsulates YOLOv8 operations:
 **Key Methods:**
 ```python
 class YOLOWrapper:
-    def __init__(self, model_path: str = "yolov8s.pt")
+    def __init__(self, model_path: str = "yolov8s-seg.pt")
     def train(self, data_yaml: str, epochs: int, callbacks: dict)
     def validate(self, data_yaml: str) -> dict
-    def predict(self, image_path: str, conf: float) -> list
+    def predict(self, image_path: str, conf: float) -> list  # Returns detections with segmentation masks
     def export_model(self, format: str, output_path: str)
 ```
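A prediction returned in pixel coordinates has to be normalized before it matches the 0-1 ranges in the detections schema. The following helper is hypothetical (its name and exact fields are illustrative, not from the codebase) and shows that mapping:

```python
import json

def detection_to_row(class_name, confidence, box_xyxy, img_w, img_h, polygon=None):
    """Map one pixel-space prediction to normalized 0-1 values matching the
    detections table. `polygon` is an optional list of already-normalized
    [x, y] pairs; it is stored as JSON text, mirroring segmentation_mask."""
    x1, y1, x2, y2 = box_xyxy
    return {
        "class_name": class_name,
        "confidence": round(confidence, 4),
        "x_min": x1 / img_w,
        "y_min": y1 / img_h,
        "x_max": x2 / img_w,
        "y_max": y2 / img_h,
        "segmentation_mask": json.dumps(polygon) if polygon is not None else None,
    }

row = detection_to_row("organelle", 0.873, (64, 32, 128, 96), 256, 128)
```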
@@ -435,7 +440,7 @@ image_repository:
   allowed_extensions: [".jpg", ".jpeg", ".png", ".tif", ".tiff"]

 models:
-  default_base_model: "yolov8s.pt"
+  default_base_model: "yolov8s-seg.pt"
   models_directory: "data/models"

 training:
BUILD.md (new file, 178 lines)
@@ -0,0 +1,178 @@
# Building and Publishing Guide

This guide explains how to build and publish the microscopy-object-detection package.

## Prerequisites

```bash
pip install build twine
```

## Building the Package

### 1. Clean Previous Builds

```bash
rm -rf build/ dist/ *.egg-info
```

### 2. Build Distribution Archives

```bash
python -m build
```

This creates both a wheel (`.whl`) and a source distribution (`.tar.gz`) in the `dist/` directory.

### 3. Verify the Build

```bash
ls dist/
# Should show:
# microscopy_object_detection-1.0.0-py3-none-any.whl
# microscopy_object_detection-1.0.0.tar.gz
```
## Testing the Package Locally

### Install in Development Mode

```bash
pip install -e .
```

### Install from Built Package

```bash
pip install dist/microscopy_object_detection-1.0.0-py3-none-any.whl
```

### Test the Installation

```bash
# Test CLI
microscopy-detect --version

# Test GUI launcher
microscopy-detect-gui
```
## Publishing to PyPI

### 1. Configure PyPI Credentials

Create or update `~/.pypirc`:

```ini
[pypi]
username = __token__
password = pypi-YOUR-API-TOKEN-HERE
```

### 2. Upload to Test PyPI (Recommended First)

```bash
python -m twine upload --repository testpypi dist/*
```

Then test installation:

```bash
pip install --index-url https://test.pypi.org/simple/ microscopy-object-detection
```

### 3. Upload to PyPI

```bash
python -m twine upload dist/*
```
## Version Management

Update the version in multiple files:
- `setup.py`: Update the `version` parameter
- `pyproject.toml`: Update the `version` field
- `src/__init__.py`: Update the `__version__` variable
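Keeping those three locations in sync can be checked mechanically. A small sketch of such a check (a hypothetical helper, not shipped with the package):

```python
import re

# Matches `version = "1.0.0"` (setup.py / pyproject.toml) and
# `__version__ = "1.0.0"` (src/__init__.py)
VERSION_RE = re.compile(r'(?:\bversion\s*=\s*|__version__\s*=\s*)["\']([^"\']+)["\']')

def find_version(text):
    """Return the first version string found in a file's text, or None."""
    m = VERSION_RE.search(text)
    return m.group(1) if m else None

# The versions agree when every file reports the same string:
versions = {find_version(t) for t in (
    'setup(name="pkg", version="1.0.0")',   # setup.py style
    'version = "1.0.0"',                    # pyproject.toml style
    '__version__ = "1.0.0"',                # src/__init__.py style
)}
consistent = len(versions) == 1
```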
## Git Tags

After publishing, tag the release:

```bash
git tag -a v1.0.0 -m "Release version 1.0.0"
git push origin v1.0.0
```

## Package Structure

The built package includes:
- All Python source files in `src/`
- Configuration files in `config/`
- Database schema file (`src/database/schema.sql`)
- Documentation files (README.md, LICENSE, etc.)
- Entry points for CLI and GUI
## Troubleshooting

### Import Errors
If you get import errors, ensure:
- All `__init__.py` files are present
- The package structure follows the setup configuration
- Dependencies are listed in `requirements.txt`

### Missing Files
If files are missing from the built package:
- Check that `MANIFEST.in` includes the required patterns
- Check the `pyproject.toml` package-data configuration
- Rebuild with `python -m build --no-isolation` for debugging

### Version Conflicts
If version conflicts occur:
- Ensure the version is consistent across all files
- Clear build artifacts and rebuild
- Check for cached installations: `pip list | grep microscopy`
## CI/CD Integration

### GitHub Actions Example

```yaml
name: Build and Publish

on:
  release:
    types: [created]

jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-python@v2
        with:
          python-version: '3.8'
      - name: Install dependencies
        run: |
          pip install build twine
      - name: Build package
        run: python -m build
      - name: Publish to PyPI
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
        run: twine upload dist/*
```
## Best Practices

1. **Version Bumping**: Use semantic versioning (MAJOR.MINOR.PATCH)
2. **Testing**: Always test on Test PyPI before publishing to PyPI
3. **Documentation**: Update README.md and CHANGELOG.md for each release
4. **Git Tags**: Tag releases in git for easy reference
5. **Dependencies**: Keep requirements.txt updated and specify version ranges

## Resources

- [Python Packaging Guide](https://packaging.python.org/)
- [setuptools Documentation](https://setuptools.pypa.io/)
- [PyPI Publishing Guide](https://packaging.python.org/tutorials/packaging-projects/)
INSTALL_TEST.md (new file, 236 lines)
@@ -0,0 +1,236 @@
# Installation Testing Guide

This guide helps you verify that the package installation works correctly.

## Clean Installation Test

### 1. Remove Any Previous Installations

```bash
# Deactivate any active virtual environment
deactivate

# Remove old virtual environment (if exists)
rm -rf venv

# Create fresh virtual environment
python3 -m venv venv
source venv/bin/activate  # On Linux/Mac
# or
venv\Scripts\activate     # On Windows
```
### 2. Install the Package

#### Option A: Editable/Development Install

```bash
pip install -e .
```

This allows you to modify source code and see changes immediately.

#### Option B: Regular Install

```bash
pip install .
```

This installs the package as if it were from PyPI.
### 3. Verify Installation

```bash
# Check package is installed
pip list | grep microscopy

# Check version
microscopy-detect --version
# Expected output: microscopy-object-detection 1.0.0

# Test Python import
python -c "import src; print(src.__version__)"
# Expected output: 1.0.0
```
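The installed version can also be queried programmatically via the standard library, which is handy in scripted checks. A sketch (the distribution name is the one used throughout this guide):

```python
from importlib import metadata

def installed_version(dist_name="microscopy-object-detection"):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None
```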
### 4. Test Entry Points

```bash
# Test CLI
microscopy-detect --help

# Test GUI launcher (will open window)
microscopy-detect-gui
```

### 5. Verify Package Contents

```python
# Run this in a Python shell
import src
import src.database
import src.model
import src.gui

# Check schema file is included
from pathlib import Path
db_path = Path(src.database.__file__).parent
schema_file = db_path / 'schema.sql'
print(f"Schema file exists: {schema_file.exists()}")
# Expected: Schema file exists: True
```
## Troubleshooting

### Issue: ModuleNotFoundError

**Error:**
```
ModuleNotFoundError: No module named 'src'
```

**Solution:**
```bash
# Reinstall with verbose output
pip install -e . -v

# Or try regular install
pip install . --force-reinstall
```

### Issue: Entry Points Not Working

**Error:**
```
microscopy-detect: command not found
```

**Solution:**
```bash
# Check if scripts are in PATH
which microscopy-detect

# If not found, check pip install location
pip show microscopy-object-detection

# You might need to add to PATH or use the full path
~/.local/bin/microscopy-detect  # Linux
```
### Issue: Import Errors for PySide6

**Error:**
```
ImportError: cannot import name 'QApplication' from 'PySide6.QtWidgets'
```

**Solution:**
```bash
# Install Qt dependencies (Linux only)
sudo apt-get install libxcb-xinerama0

# Reinstall PySide6
pip uninstall PySide6
pip install PySide6
```

### Issue: Config Files Not Found

**Error:**
```
FileNotFoundError: config/app_config.yaml
```

**Solution:**
The config file should be created automatically. If not:
```bash
# Create config directory in your home
mkdir -p ~/.microscopy-detect
cp config/app_config.yaml ~/.microscopy-detect/

# Or run from the source directory the first time
cd /home/martin/code/object_detection
python main.py
```
## Manual Testing Checklist

- [ ] Package installs without errors
- [ ] Version command works (`microscopy-detect --version`)
- [ ] Help command works (`microscopy-detect --help`)
- [ ] GUI launches (`microscopy-detect-gui`)
- [ ] Can import all modules in Python
- [ ] Database schema file is accessible
- [ ] Configuration loads correctly

## Build and Install from Wheel

```bash
# Build the package
python -m build

# Install from wheel
pip install dist/microscopy_object_detection-1.0.0-py3-none-any.whl

# Test
microscopy-detect --version
```
## Uninstall

```bash
pip uninstall microscopy-object-detection
```

## Development Workflow

### After Code Changes

If installed with `-e` (editable mode):
- Python code changes are immediately available
- No need to reinstall

If installed with regular `pip install .`:
- Reinstall after changes: `pip install . --force-reinstall`

### After Adding New Files

```bash
# Reinstall to include new files
pip install -e . --force-reinstall
```
## Expected Installation Output

```
Processing /home/martin/code/object_detection
Installing build dependencies ... done
Getting requirements to build wheel ... done
Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: microscopy-object-detection
Building wheel for microscopy-object-detection (pyproject.toml) ... done
Successfully built microscopy-object-detection
Installing collected packages: microscopy-object-detection
Successfully installed microscopy-object-detection-1.0.0
```

## Success Criteria

Installation is successful when:
1. ✅ No error messages during installation
2. ✅ `pip list` shows the package
3. ✅ `microscopy-detect --version` returns the correct version
4. ✅ GUI launches without errors
5. ✅ All Python modules can be imported
6. ✅ Database operations work
7. ✅ Detection functionality works

## Next Steps After Successful Install

1. Configure the image repository path
2. Run a first detection
3. Train a custom model
4. Export results

For usage instructions, see [QUICKSTART.md](QUICKSTART.md)
LICENSE (new file, 21 lines)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Your Name

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
MANIFEST.in (new file, 37 lines)
@@ -0,0 +1,37 @@
# Include documentation files
include README.md
include LICENSE
include ARCHITECTURE.md
include IMPLEMENTATION_GUIDE.md
include QUICKSTART.md
include PLAN_SUMMARY.md

# Include requirements
include requirements.txt

# Include configuration files
recursive-include config *.yaml
recursive-include config *.yml

# Include database schema
recursive-include src/database *.sql

# Include tests
recursive-include tests *.py

# Exclude compiled Python files
global-exclude *.pyc
global-exclude *.pyo
global-exclude __pycache__
global-exclude *.so
global-exclude .DS_Store

# Exclude git and IDE files
global-exclude .git*
global-exclude .vscode
global-exclude .idea

# Exclude build artifacts
prune build
prune dist
prune *.egg-info
@@ -38,7 +38,7 @@ This will install:
 - OpenCV and Pillow (image processing)
 - And other dependencies

-**Note:** The first run will automatically download the YOLOv8s.pt model (~22MB).
+**Note:** The first run will automatically download the YOLOv8s-seg.pt segmentation model (~23MB).

 ### 4. Verify Installation

@@ -84,11 +84,11 @@ In the Settings dialog:
 ### Single Image Detection

 1. Go to the **Detection** tab
-2. Select a model from the dropdown (default: Base Model yolov8s.pt)
+2. Select a model from the dropdown (default: Base Model yolov8s-seg.pt)
 3. Adjust the confidence threshold with the slider
 4. Click "Detect Single Image"
 5. Select an image file
-6. View results in the results panel
+6. View results with segmentation masks overlaid on the image

 ### Batch Detection

@@ -108,9 +108,18 @@ Detection results include:
 - **Class names**: Types of objects detected (e.g., organelle, membrane_branch)
 - **Confidence scores**: Detection confidence (0-1)
 - **Bounding boxes**: Object locations (stored in database)
+- **Segmentation masks**: Pixel-accurate polygon coordinates for each detected object

 All results are stored in the SQLite database at [`data/detections.db`](data/detections.db).

+### Segmentation Visualization
+
+The application automatically displays segmentation masks when available:
+- Semi-transparent colored overlay (30% opacity) showing the exact shape of detected objects
+- Polygon contours outlining each segmentation
+- Color-coded by object class
+- Toggleable in future versions
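The 30% overlay described here amounts to a per-pixel alpha blend over the mask region. A minimal NumPy sketch of that idea (illustrative only, not the application's actual rendering code):

```python
import numpy as np

def overlay_mask(image_rgb, mask, color=(255, 0, 0), alpha=0.3):
    """Blend `color` over the pixels where `mask` is nonzero at the given
    opacity, leaving the rest of the image untouched."""
    out = image_rgb.astype(np.float64)
    sel = mask > 0
    out[sel] = (1 - alpha) * out[sel] + alpha * np.asarray(color, dtype=np.float64)
    return out.astype(np.uint8)

# Blend red at 30% over the top half of a small black image
img = np.zeros((4, 4, 3), dtype=np.uint8)
mask = np.zeros((4, 4), dtype=np.uint8)
mask[:2] = 1
shaded = overlay_mask(img, mask)
```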
 ## Database

 The application uses SQLite to store:

@@ -176,7 +185,7 @@ sudo apt-get install libxcb-xinerama0
 ### Detection Not Working

 **No models available**
-- The base YOLOv8s model will be downloaded automatically on first use
+- The base YOLOv8s-seg segmentation model will be downloaded automatically on first use
 - Make sure you have an internet connection for the first run

 **Images not found**
README.md (62 lines changed)
@@ -1,6 +1,6 @@
 # Microscopy Object Detection Application

-A desktop application for detecting organelles and membrane branching structures in microscopy images using YOLOv8, featuring comprehensive training, validation, and visualization capabilities.
+A desktop application for detecting and segmenting organelles and membrane branching structures in microscopy images using YOLOv8-seg, featuring comprehensive training, validation, and visualization capabilities with pixel-accurate segmentation masks.

 
 

@@ -8,8 +8,8 @@ A desktop application for detecting organelles and membrane branching structures
 ## Features

-- **🎯 Object Detection**: Real-time and batch detection of microscopy objects
-- **🎓 Model Training**: Fine-tune YOLOv8s on custom microscopy datasets
+- **🎯 Object Detection & Segmentation**: Real-time and batch detection with pixel-accurate segmentation masks
+- **🎓 Model Training**: Fine-tune YOLOv8s-seg on custom microscopy datasets
 - **📊 Validation & Metrics**: Comprehensive model validation with visualization
 - **💾 Database Storage**: SQLite database for detection results and metadata
 - **📈 Visualization**: Interactive plots and charts using pyqtgraph
@@ -34,14 +34,24 @@ A desktop application for detecting organelles and membrane branching structures
 ## Installation

-### 1. Clone the Repository
+### Option 1: Install from PyPI (Recommended)
+
+```bash
+pip install microscopy-object-detection
+```
+
+This will install the package and all its dependencies.
+
+### Option 2: Install from Source
+
+#### 1. Clone the Repository

 ```bash
 git clone <repository-url>
 cd object_detection
 ```

-### 2. Create Virtual Environment
+#### 2. Create Virtual Environment

 ```bash
 # Linux/Mac
@@ -53,25 +63,44 @@ python -m venv venv
 venv\Scripts\activate
 ```
-### 3. Install Dependencies
+#### 3. Install in Development Mode

 ```bash
-pip install -r requirements.txt
+# Install in editable mode with dev dependencies
+pip install -e ".[dev]"
+
+# Or install just the package
+pip install .
 ```

 ### 4. Download Base Model

-The application will automatically download the YOLOv8s.pt model on first use, or you can download it manually:
+The application will automatically download the YOLOv8s-seg.pt segmentation model on first use, or you can download it manually:

 ```bash
 # The model will be downloaded automatically by ultralytics
 # Or download manually from: https://github.com/ultralytics/assets/releases
 ```

+**Note:** YOLOv8s-seg is a segmentation model that provides pixel-accurate masks for detected objects, enabling more precise analysis than standard bounding box detection.
## Quick Start

### 1. Launch the Application

After installation, you can launch the application in two ways:

**Using the GUI launcher:**
```bash
microscopy-detect-gui
```

**Or using Python directly:**
```bash
python -m microscopy_object_detection
```

**If installed from source:**
```bash
python main.py
```
@@ -85,11 +114,12 @@ python main.py
 ### 3. Perform Detection

 1. Navigate to the **Detection** tab
-2. Select a model (default: yolov8s.pt)
+2. Select a model (default: yolov8s-seg.pt)
 3. Choose an image or folder
 4. Set confidence threshold
 5. Click **Detect**
-6. View results and save to database
+6. View results with segmentation masks overlaid
+7. Save results to the database

 ### 4. Train Custom Model
@@ -212,8 +242,8 @@ The application uses SQLite with the following main tables:

 - **models**: Stores trained model information and metrics
 - **images**: Stores image metadata and paths
-- **detections**: Stores detection results with bounding boxes
-- **annotations**: Stores manual annotations (future feature)
+- **detections**: Stores detection results with bounding boxes and segmentation masks (polygon coordinates)
+- **annotations**: Stores manual annotations with optional segmentation masks (future feature)

 See [`ARCHITECTURE.md`](ARCHITECTURE.md) for detailed schema information.
@@ -230,7 +260,7 @@ image_repository:
   allowed_extensions: [".jpg", ".jpeg", ".png", ".tif", ".tiff"]

 models:
-  default_base_model: "yolov8s.pt"
+  default_base_model: "yolov8s-seg.pt"
   models_directory: "data/models"

 training:
@@ -258,7 +288,7 @@ visualization:
 from src.model.yolo_wrapper import YOLOWrapper

 # Initialize wrapper
-yolo = YOLOWrapper("yolov8s.pt")
+yolo = YOLOWrapper("yolov8s-seg.pt")

 # Train model
 results = yolo.train(
@@ -393,10 +423,10 @@ make html

 **Issue**: Model not found error

-**Solution**: Ensure YOLOv8s.pt is downloaded. Run:
+**Solution**: Ensure YOLOv8s-seg.pt is downloaded. Run:
 ```python
 from ultralytics import YOLO
-model = YOLO('yolov8s.pt')  # Will auto-download
+model = YOLO('yolov8s-seg.pt')  # Will auto-download
 ```
@@ -10,7 +10,7 @@ image_repository:
     - .tiff
     - .bmp
 models:
-  default_base_model: yolov8s.pt
+  default_base_model: yolov8s-seg.pt
   models_directory: data/models
 training:
   default_epochs: 100
docs/IMAGE_CLASS_USAGE.md (new file, 220 lines)
@@ -0,0 +1,220 @@
# Image Class Usage Guide

The `Image` class provides a convenient way to load and work with images in the microscopy object detection application.

## Supported Formats

The Image class supports the following image formats:
- `.jpg`, `.jpeg` - JPEG images
- `.png` - PNG images
- `.tif`, `.tiff` - TIFF images (commonly used in microscopy)
- `.bmp` - Bitmap images
## Basic Usage

### Loading an Image

```python
from src.utils import Image, ImageLoadError

# Load an image from a file path
try:
    img = Image("path/to/image.jpg")
    print(f"Loaded image: {img.width}x{img.height} pixels")
except ImageLoadError as e:
    print(f"Failed to load image: {e}")
```
### Accessing Image Properties

```python
img = Image("microscopy_image.tif")

# Basic properties
print(f"Width: {img.width} pixels")
print(f"Height: {img.height} pixels")
print(f"Channels: {img.channels}")
print(f"Format: {img.format}")
print(f"Shape: {img.shape}")  # (height, width, channels)

# File information
print(f"File size: {img.size_mb:.2f} MB")
print(f"File size: {img.size_bytes} bytes")

# Image type checks
print(f"Is color: {img.is_color()}")
print(f"Is grayscale: {img.is_grayscale()}")

# String representation
print(img)  # Shows a summary of image properties
```
### Working with Image Data

```python
import numpy as np
from src.utils import Image

img = Image("sample.png")

# Get image data as a numpy array (OpenCV format, BGR)
bgr_data = img.data
print(f"Data shape: {bgr_data.shape}")
print(f"Data type: {bgr_data.dtype}")

# Get the image as RGB (for display or processing)
rgb_data = img.get_rgb()

# Get a grayscale version
gray_data = img.get_grayscale()

# Create a copy (for modifications)
img_copy = img.copy()
img_copy[0, 0] = [255, 255, 255]  # Modify the copy; the original is unchanged

# Resize image (returns a new array, doesn't modify the original)
resized = img.resize(640, 640)
```
### Using PIL Image

```python
img = Image("photo.jpg")

# Access as PIL Image (RGB format)
pil_img = img.pil_image

# Use PIL methods
pil_img.show()              # Display image
pil_img.save("output.png")  # Save with PIL
```
## Integration with YOLO

```python
from src.utils import Image
from ultralytics import YOLO

# Load model and image
model = YOLO("yolov8n.pt")
img = Image("microscopy/cell_01.tif")

# Run inference (YOLO accepts file paths or numpy arrays)
results = model(img.data)

# Or use the file path directly
results = model(str(img.path))
```
## Error Handling

```python
from src.utils import Image, ImageLoadError

def process_image(image_path):
    try:
        img = Image(image_path)
        # Process the image...
        return img
    except ImageLoadError as e:
        print(f"Cannot load image: {e}")
        return None
```
## Advanced Usage

### Batch Processing

```python
from pathlib import Path
from src.utils import Image, ImageLoadError

def process_image_directory(directory):
    """Process all images in a directory."""
    image_paths = Path(directory).glob("*.tif")

    for path in image_paths:
        try:
            img = Image(path)
            print(f"Processing {img.path.name}: {img.width}x{img.height}")
            # Process the image...
        except ImageLoadError as e:
            print(f"Skipping {path}: {e}")
```
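Note that `glob("*.tif")` above only matches one extension. A small sketch that covers every supported format (the helper name `iter_image_files` is illustrative, not from the project):

```python
from pathlib import Path

SUPPORTED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"}

def iter_image_files(directory, recursive=False):
    """Return sorted image paths whose suffix is a supported format,
    matching the extension case-insensitively (e.g. both .TIF and .tif)."""
    paths = Path(directory).rglob("*") if recursive else Path(directory).glob("*")
    return sorted(p for p in paths if p.suffix.lower() in SUPPORTED_EXTENSIONS)
```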
### Using with OpenCV Operations

```python
import cv2
from src.utils import Image

img = Image("input.jpg")

# Apply OpenCV operations on the data
blurred = cv2.GaussianBlur(img.data, (5, 5), 0)
edges = cv2.Canny(img.data, 100, 200)

# Note: These operations don't modify the original img.data
```
### Memory Efficient Processing

```python
from src.utils import Image

# The Image class loads data into memory
img = Image("large_image.tif")
print(f"Image size in memory: {img.data.nbytes / (1024**2):.2f} MB")

# When processing many images, consider loading one at a time
# and releasing memory by deleting the object
del img
```
|
||||
|
## Best Practices

1. **Always use try-except** when loading images to handle errors gracefully
2. **Check image properties** before processing to ensure compatibility
3. **Use copy()** when you need to modify image data without affecting the original
4. **Path objects work too** - The class accepts both strings and Path objects
5. **Consider memory usage** when working with large images or batches
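The copy-versus-view distinction behind practice 3 can be shown with plain NumPy (the Image class wraps a NumPy array; this sketch is illustrative and does not use the Image API itself):

```python
import numpy as np

data = np.zeros((2, 2), dtype=np.uint8)
alias = data       # same buffer: writes show through to data
dup = data.copy()  # independent buffer: writes stay local

alias[0, 0] = 255
dup[1, 1] = 7

print(data[0, 0])  # 255 -> the alias modified the original
print(data[1, 1])  # 0   -> the copy did not
```

This is exactly why `copy()` is the safe choice before in-place edits.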
## Example: Complete Workflow

```python
from src.utils import Image, ImageLoadError
from src.utils.file_utils import get_image_files

def analyze_microscopy_images(directory):
    """Analyze all microscopy images in a directory."""

    # Get all image files
    image_files = get_image_files(directory, recursive=True)

    results = []
    for image_path in image_files:
        try:
            # Load image
            img = Image(image_path)

            # Analyze
            result = {
                'filename': img.path.name,
                'width': img.width,
                'height': img.height,
                'channels': img.channels,
                'format': img.format,
                'size_mb': img.size_mb,
                'is_color': img.is_color()
            }

            results.append(result)
            print(f"✓ Analyzed {img.path.name}")

        except ImageLoadError as e:
            print(f"✗ Failed to load {image_path}: {e}")

    return results

# Run analysis
results = analyze_microscopy_images("data/datasets/cells")
print(f"\nProcessed {len(results)} images")
```
examples/image_demo.py (new file, 151 lines)
@@ -0,0 +1,151 @@
"""
Example script demonstrating the Image class functionality.
"""

import sys
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.utils import Image, ImageLoadError


def demonstrate_image_loading():
    """Demonstrate basic image loading functionality."""
    print("=" * 60)
    print("Image Class Demonstration")
    print("=" * 60)

    # Example 1: Try to load an image (replace with your own path)
    example_paths = [
        "data/datasets/example.jpg",
        "data/datasets/sample.png",
        "tests/test_image.jpg",
    ]

    loaded_img = None
    for image_path in example_paths:
        if Path(image_path).exists():
            try:
                print(f"\n1. Loading image: {image_path}")
                img = Image(image_path)
                loaded_img = img
                print(f"   ✓ Successfully loaded!")
                print(f"   {img}")
                break
            except ImageLoadError as e:
                print(f"   ✗ Failed: {e}")
        else:
            print(f"\n1. Image not found: {image_path}")

    if loaded_img is None:
        print("\nNo example images found. Creating a test image...")
        create_test_image()
        return

    # Example 2: Access image properties
    print(f"\n2. Image Properties:")
    print(f"   Width: {loaded_img.width} pixels")
    print(f"   Height: {loaded_img.height} pixels")
    print(f"   Channels: {loaded_img.channels}")
    print(f"   Format: {loaded_img.format.upper()}")
    print(f"   Shape: {loaded_img.shape}")
    print(f"   File size: {loaded_img.size_mb:.2f} MB")
    print(f"   Is color: {loaded_img.is_color()}")
    print(f"   Is grayscale: {loaded_img.is_grayscale()}")

    # Example 3: Get different formats
    print(f"\n3. Accessing Image Data:")
    print(f"   BGR data shape: {loaded_img.data.shape}")
    print(f"   RGB data shape: {loaded_img.get_rgb().shape}")
    print(f"   Grayscale shape: {loaded_img.get_grayscale().shape}")
    print(f"   PIL image mode: {loaded_img.pil_image.mode}")

    # Example 4: Resizing
    print(f"\n4. Resizing Image:")
    resized = loaded_img.resize(320, 320)
    print(f"   Original size: {loaded_img.width}x{loaded_img.height}")
    print(f"   Resized to: {resized.shape[1]}x{resized.shape[0]}")

    # Example 5: Working with copies
    print(f"\n5. Creating Copies:")
    copy = loaded_img.copy()
    print(f"   Created copy with shape: {copy.shape}")
    print(f"   Original data unchanged: {(loaded_img.data == copy).all()}")

    print("\n" + "=" * 60)
    print("Demonstration Complete!")
    print("=" * 60)


def create_test_image():
    """Create a test image for demonstration purposes."""
    import cv2
    import numpy as np

    print("\nCreating a test image...")

    # Create a colorful test image
    width, height = 400, 300
    test_img = np.zeros((height, width, 3), dtype=np.uint8)

    # Add some colors
    test_img[:100, :] = [255, 0, 0]  # Blue section
    test_img[100:200, :] = [0, 255, 0]  # Green section
    test_img[200:, :] = [0, 0, 255]  # Red section

    # Save the test image
    test_path = Path("test_demo_image.png")
    cv2.imwrite(str(test_path), test_img)
    print(f"Test image created: {test_path}")

    # Now load and demonstrate with it
    try:
        img = Image(test_path)
        print(f"\nLoaded test image: {img}")
        print(f"Dimensions: {img.width}x{img.height}")
        print(f"Channels: {img.channels}")
        print(f"Format: {img.format}")

        # Clean up
        test_path.unlink()
        print(f"\nTest image cleaned up.")

    except ImageLoadError as e:
        print(f"Error loading test image: {e}")


def demonstrate_error_handling():
    """Demonstrate error handling."""
    print("\n" + "=" * 60)
    print("Error Handling Demonstration")
    print("=" * 60)

    # Try to load non-existent file
    print("\n1. Loading non-existent file:")
    try:
        img = Image("nonexistent.jpg")
    except ImageLoadError as e:
        print(f"   ✓ Caught error: {e}")

    # Try unsupported format
    print("\n2. Loading unsupported format:")
    try:
        # Create a text file
        test_file = Path("test.txt")
        test_file.write_text("not an image")
        img = Image(test_file)
    except ImageLoadError as e:
        print(f"   ✓ Caught error: {e}")
        test_file.unlink()  # Clean up

    print("\n" + "=" * 60)


if __name__ == "__main__":
    print("\n")
    demonstrate_image_loading()
    print("\n")
    demonstrate_error_handling()
    print("\n")
main.py
@@ -6,12 +6,13 @@ Main entry point for the application.
import sys
from pathlib import Path

# Add src directory to path
# Add src directory to path for development mode
sys.path.insert(0, str(Path(__file__).parent))

from PySide6.QtWidgets import QApplication
from PySide6.QtCore import Qt

from src import __version__
from src.gui.main_window import MainWindow
from src.utils.logger import setup_logging
from src.utils.config_manager import ConfigManager

@@ -37,7 +38,7 @@ def main():

    app = QApplication(sys.argv)
    app.setApplicationName("Microscopy Object Detection")
    app.setOrganizationName("MicroscopyLab")
    app.setApplicationVersion("1.0.0")
    app.setApplicationVersion(__version__)

    # Set application style
    app.setStyle("Fusion")
pyproject.toml (new file, 102 lines)
@@ -0,0 +1,102 @@
[build-system]
requires = ["setuptools>=45", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "microscopy-object-detection"
version = "1.0.0"
description = "Desktop application for detecting and segmenting organelles in microscopy images using YOLOv8-seg"
readme = "README.md"
requires-python = ">=3.8"
license = { text = "MIT" }
authors = [{ name = "Your Name", email = "your.email@example.com" }]
keywords = [
    "microscopy",
    "yolov8",
    "object-detection",
    "segmentation",
    "computer-vision",
    "deep-learning",
]
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Science/Research",
    "Topic :: Scientific/Engineering :: Image Recognition",
    "Topic :: Scientific/Engineering :: Bio-Informatics",
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Operating System :: OS Independent",
]

dependencies = [
    "ultralytics>=8.0.0",
    "PySide6>=6.5.0",
    "pyqtgraph>=0.13.0",
    "numpy>=1.24.0",
    "opencv-python>=4.8.0",
    "Pillow>=10.0.0",
    "PyYAML>=6.0",
    "pandas>=2.0.0",
    "openpyxl>=3.1.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0.0",
    "pytest-cov>=4.0.0",
    "black>=23.0.0",
    "pylint>=2.17.0",
    "mypy>=1.0.0",
]

[project.urls]
Homepage = "https://github.com/yourusername/object_detection"
Documentation = "https://github.com/yourusername/object_detection/blob/main/README.md"
Repository = "https://github.com/yourusername/object_detection"
"Bug Tracker" = "https://github.com/yourusername/object_detection/issues"

[project.scripts]
microscopy-detect = "src.cli:main"

[project.gui-scripts]
microscopy-detect-gui = "src.gui_launcher:main"

[tool.setuptools]
packages = [
    "src",
    "src.database",
    "src.model",
    "src.gui",
    "src.gui.tabs",
    "src.gui.dialogs",
    "src.gui.widgets",
    "src.utils",
]
include-package-data = true

[tool.setuptools.package-data]
"src.database" = ["*.sql"]

[tool.black]
line-length = 88
target-version = ['py38', 'py39', 'py310', 'py311']
include = '\.pyi?$'

[tool.pylint.messages_control]
max-line-length = 88

[tool.mypy]
python_version = "3.8"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = false

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
python_functions = ["test_*"]
addopts = "-v --cov=src --cov-report=term-missing"
setup.py (new file, 56 lines)
@@ -0,0 +1,56 @@
"""Setup script for Microscopy Object Detection Application."""

from setuptools import setup, find_packages
from pathlib import Path

# Read the contents of README file
this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text(encoding="utf-8")

# Read requirements, skipping blank lines and comments
requirements = (this_directory / "requirements.txt").read_text().splitlines()
requirements = [
    req.strip()
    for req in requirements
    if req.strip() and not req.strip().startswith("#")
]

setup(
    name="microscopy-object-detection",
    version="1.0.0",
    author="Your Name",
    author_email="your.email@example.com",
    description="Desktop application for detecting and segmenting organelles in microscopy images using YOLOv8-seg",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/yourusername/object_detection",
    packages=find_packages(exclude=["tests", "tests.*", "docs"]),
    include_package_data=True,
    install_requires=requirements,
    python_requires=">=3.8",
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Science/Research",
        "Topic :: Scientific/Engineering :: Image Recognition",
        "Topic :: Scientific/Engineering :: Bio-Informatics",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Operating System :: OS Independent",
    ],
    entry_points={
        "console_scripts": [
            "microscopy-detect=src.cli:main",
        ],
        "gui_scripts": [
            "microscopy-detect-gui=src.gui_launcher:main",
        ],
    },
    keywords="microscopy yolov8 object-detection segmentation computer-vision deep-learning",
    project_urls={
        "Bug Reports": "https://github.com/yourusername/object_detection/issues",
        "Source": "https://github.com/yourusername/object_detection",
        "Documentation": "https://github.com/yourusername/object_detection/blob/main/README.md",
    },
)
@@ -0,0 +1,19 @@
"""
Microscopy Object Detection Application

A desktop application for detecting and segmenting organelles and membrane
branching structures in microscopy images using YOLOv8-seg.
"""

__version__ = "1.0.0"
__author__ = "Your Name"
__email__ = "your.email@example.com"
__license__ = "MIT"

# Package metadata
__all__ = [
    "__version__",
    "__author__",
    "__email__",
    "__license__",
]
src/cli.py (new file, 61 lines)
@@ -0,0 +1,61 @@
"""
Command-line interface for microscopy object detection application.
"""

import sys
import argparse
from pathlib import Path

from src import __version__


def main():
    """Main CLI entry point."""
    parser = argparse.ArgumentParser(
        description="Microscopy Object Detection Application - CLI Interface",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Launch GUI
  microscopy-detect-gui

  # Show version
  microscopy-detect --version

  # Get help
  microscopy-detect --help
""",
    )

    parser.add_argument(
        "--version",
        action="version",
        version=f"microscopy-object-detection {__version__}",
    )

    parser.add_argument(
        "--gui",
        action="store_true",
        help="Launch the GUI application (same as microscopy-detect-gui)",
    )

    args = parser.parse_args()

    if args.gui:
        # Launch GUI
        try:
            from src.gui_launcher import main as gui_main

            gui_main()
        except Exception as e:
            print(f"Error launching GUI: {e}", file=sys.stderr)
            sys.exit(1)
    else:
        # Show help if no arguments provided
        parser.print_help()
        print("\nTo launch the GUI, use: microscopy-detect-gui")
    return 0


if __name__ == "__main__":
    sys.exit(main())
@@ -6,7 +6,7 @@ Handles all database operations including CRUD operations, queries, and exports.

import sqlite3
import json
from datetime import datetime
from typing import List, Dict, Optional, Tuple, Any
from typing import List, Dict, Optional, Tuple, Any, Union
from pathlib import Path
import csv
import hashlib
@@ -30,18 +30,48 @@ class DatabaseManager:

        # Create directory if it doesn't exist
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

        # Read schema file and execute
        schema_path = Path(__file__).parent / "schema.sql"
        with open(schema_path, "r") as f:
            schema_sql = f.read()

        conn = self.get_connection()
        try:
            # Check if annotations table needs migration
            self._migrate_annotations_table(conn)

            # Read schema file and execute
            schema_path = Path(__file__).parent / "schema.sql"
            with open(schema_path, "r") as f:
                schema_sql = f.read()

            conn.executescript(schema_sql)
            conn.commit()
        finally:
            conn.close()

    def _migrate_annotations_table(self, conn: sqlite3.Connection) -> None:
        """
        Migrate annotations table from old schema (class_name) to new schema (class_id).
        """
        cursor = conn.cursor()

        # Check if annotations table exists
        cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'"
        )
        if not cursor.fetchone():
            # Table doesn't exist yet, no migration needed
            return

        # Check if table has old schema (class_name column)
        cursor.execute("PRAGMA table_info(annotations)")
        columns = {row[1]: row for row in cursor.fetchall()}

        if "class_name" in columns and "class_id" not in columns:
            # Old schema detected, need to migrate
            print("Migrating annotations table to new schema with class_id...")

            # Drop old annotations table (assuming no critical data since this is a new feature)
            cursor.execute("DROP TABLE IF EXISTS annotations")
            conn.commit()
            print("Old annotations table dropped, will be recreated with new schema")

    def get_connection(self) -> sqlite3.Connection:
        """Get database connection with proper settings."""
        conn = sqlite3.connect(self.db_path)
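The schema probe used by `_migrate_annotations_table` can be exercised standalone against an in-memory database (a sketch, not the application's own code):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Simulate the old schema that predates the class_id migration
conn.execute("CREATE TABLE annotations (id INTEGER PRIMARY KEY, class_name TEXT)")

# PRAGMA table_info yields one row per column; index 1 is the column name
columns = {row[1] for row in conn.execute("PRAGMA table_info(annotations)")}
needs_migration = "class_name" in columns and "class_id" not in columns
print(needs_migration)
```

If the table is already on the new schema, `class_id` appears in `columns` and the migration is skipped.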
@@ -56,7 +86,7 @@ class DatabaseManager:
        model_name: str,
        model_version: str,
        model_path: str,
        base_model: str = "yolov8s.pt",
        base_model: str = "yolov8s-seg.pt",
        training_params: Optional[Dict] = None,
        metrics: Optional[Dict] = None,
    ) -> int:
@@ -243,6 +273,7 @@ class DatabaseManager:
        class_name: str,
        bbox: Tuple[float, float, float, float],  # (x_min, y_min, x_max, y_max)
        confidence: float,
        segmentation_mask: Optional[List[List[float]]] = None,
        metadata: Optional[Dict] = None,
    ) -> int:
        """
@@ -254,6 +285,7 @@ class DatabaseManager:
            class_name: Detected object class
            bbox: Bounding box coordinates (normalized 0-1)
            confidence: Detection confidence score
            segmentation_mask: Polygon coordinates for segmentation [[x1,y1], [x2,y2], ...]
            metadata: Additional metadata

        Returns:
@@ -265,8 +297,8 @@ class DatabaseManager:
        x_min, y_min, x_max, y_max = bbox
        cursor.execute(
            """
            INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, metadata)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
            INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, segmentation_mask, metadata)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                image_id,
@@ -277,6 +309,7 @@ class DatabaseManager:
                x_max,
                y_max,
                confidence,
                json.dumps(segmentation_mask) if segmentation_mask else None,
                json.dumps(metadata) if metadata else None,
            ),
        )
@@ -302,8 +335,8 @@ class DatabaseManager:
            bbox = det["bbox"]
            cursor.execute(
                """
                INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, segmentation_mask, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    det["image_id"],
@@ -314,6 +347,11 @@ class DatabaseManager:
                    bbox[2],
                    bbox[3],
                    det["confidence"],
                    (
                        json.dumps(det.get("segmentation_mask"))
                        if det.get("segmentation_mask")
                        else None
                    ),
                    (
                        json.dumps(det.get("metadata"))
                        if det.get("metadata")
@@ -385,9 +423,11 @@ class DatabaseManager:
        detections = []
        for row in cursor.fetchall():
            det = dict(row)
            # Parse JSON metadata
            # Parse JSON fields
            if det.get("metadata"):
                det["metadata"] = json.loads(det["metadata"])
            if det.get("segmentation_mask"):
                det["segmentation_mask"] = json.loads(det["segmentation_mask"])
            detections.append(det)

        return detections
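The polygon masks travel through SQLite as JSON text; for lists of floats the round-trip is lossless, which is what the write path (`json.dumps`) and read path (`json.loads`) above rely on:

```python
import json

mask = [[0.10, 0.20], [0.35, 0.40], [0.50, 0.15]]  # normalized polygon vertices
stored = json.dumps(mask)      # what goes into the segmentation_mask TEXT column
restored = json.loads(stored)  # what the read path hands back
print(restored == mask)
```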
@@ -538,6 +578,7 @@ class DatabaseManager:
                "x_max",
                "y_max",
                "confidence",
                "segmentation_mask",
                "detected_at",
            ]
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
@@ -545,6 +586,11 @@ class DatabaseManager:

            for det in detections:
                row = {k: det[k] for k in fieldnames if k in det}
                # Convert segmentation mask list to JSON string for CSV
                if row.get("segmentation_mask") and isinstance(
                    row["segmentation_mask"], list
                ):
                    row["segmentation_mask"] = json.dumps(row["segmentation_mask"])
                writer.writerow(row)

        return True
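The export above serializes list-valued masks to a JSON string before handing rows to `csv.DictWriter`. A self-contained sketch with hypothetical row data (the field names mirror the export; the values are invented):

```python
import csv
import io
import json

fieldnames = ["class_name", "confidence", "segmentation_mask"]
rows = [
    {"class_name": "nucleus", "confidence": 0.91,
     "segmentation_mask": [[0.1, 0.2], [0.3, 0.4]]},
]

buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=fieldnames)
writer.writeheader()
for row in rows:
    # Lists are not CSV-safe; store the JSON text instead
    if isinstance(row.get("segmentation_mask"), list):
        row = {**row, "segmentation_mask": json.dumps(row["segmentation_mask"])}
    writer.writerow(row)

print(buf.getvalue())
```

Because the JSON text contains commas, the csv module quotes the field automatically, so the polygon survives a later `csv.DictReader` + `json.loads` round-trip.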
@@ -577,22 +623,46 @@ class DatabaseManager:
    def add_annotation(
        self,
        image_id: int,
        class_name: str,
        class_id: int,
        bbox: Tuple[float, float, float, float],
        annotator: str,
        segmentation_mask: Optional[List[List[float]]] = None,
        verified: bool = False,
    ) -> int:
        """Add manual annotation."""
        """
        Add manual annotation.

        Args:
            image_id: ID of the image
            class_id: ID of the object class (foreign key to object_classes)
            bbox: Bounding box coordinates (normalized 0-1)
            annotator: Name of person/tool creating annotation
            segmentation_mask: Polygon coordinates for segmentation
            verified: Whether annotation has been verified

        Returns:
            ID of the inserted annotation
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            x_min, y_min, x_max, y_max = bbox
            cursor.execute(
                """
                INSERT INTO annotations (image_id, class_name, x_min, y_min, x_max, y_max, annotator, verified)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                INSERT INTO annotations (image_id, class_id, x_min, y_min, x_max, y_max, segmentation_mask, annotator, verified)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (image_id, class_name, x_min, y_min, x_max, y_max, annotator, verified),
                (
                    image_id,
                    class_id,
                    x_min,
                    y_min,
                    x_max,
                    y_max,
                    json.dumps(segmentation_mask) if segmentation_mask else None,
                    annotator,
                    verified,
                ),
            )
            conn.commit()
            return cursor.lastrowid
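The INSERT above grew from eight placeholders to nine, and the value tuple grew with it. sqlite3 enforces that the two counts match at execute time, so a mismatch fails loudly rather than silently shifting columns, as a quick in-memory sketch shows:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (a, b, c)")
try:
    # Three placeholders, only two values supplied
    conn.execute("INSERT INTO t VALUES (?, ?, ?)", (1, 2))
except sqlite3.ProgrammingError as e:
    print("caught:", e)
```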
@@ -600,15 +670,197 @@ class DatabaseManager:
            conn.close()

    def get_annotations_for_image(self, image_id: int) -> List[Dict]:
        """Get all annotations for an image."""
        """
        Get all annotations for an image with class information.

        Args:
            image_id: ID of the image

        Returns:
            List of annotation dictionaries with joined class information
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM annotations WHERE image_id = ?", (image_id,))
            cursor.execute(
                """
                SELECT
                    a.*,
                    c.class_name,
                    c.color as class_color,
                    c.description as class_description
                FROM annotations a
                JOIN object_classes c ON a.class_id = c.id
                WHERE a.image_id = ?
                ORDER BY a.created_at DESC
                """,
                (image_id,),
            )
            annotations = []
            for row in cursor.fetchall():
                ann = dict(row)
                if ann.get("segmentation_mask"):
                    ann["segmentation_mask"] = json.loads(ann["segmentation_mask"])
                annotations.append(ann)
            return annotations
        finally:
            conn.close()

    def delete_annotation(self, annotation_id: int) -> bool:
        """
        Delete a manual annotation by ID.

        Args:
            annotation_id: ID of the annotation to delete

        Returns:
            True if an annotation was deleted, False otherwise.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM annotations WHERE id = ?", (annotation_id,))
            conn.commit()
            return cursor.rowcount > 0
        finally:
            conn.close()
    # ==================== Object Class Operations ====================

    def get_object_classes(self) -> List[Dict]:
        """
        Get all object classes.

        Returns:
            List of object class dictionaries
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM object_classes ORDER BY class_name")
            return [dict(row) for row in cursor.fetchall()]
        finally:
            conn.close()

    def get_object_class_by_id(self, class_id: int) -> Optional[Dict]:
        """Get object class by ID."""
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM object_classes WHERE id = ?", (class_id,))
            row = cursor.fetchone()
            return dict(row) if row else None
        finally:
            conn.close()

    def get_object_class_by_name(self, class_name: str) -> Optional[Dict]:
        """Get object class by name."""
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT * FROM object_classes WHERE class_name = ?", (class_name,)
            )
            row = cursor.fetchone()
            return dict(row) if row else None
        finally:
            conn.close()

    def add_object_class(
        self, class_name: str, color: str, description: Optional[str] = None
    ) -> int:
        """
        Add a new object class.

        Args:
            class_name: Name of the object class
            color: Hex color code (e.g., '#FF0000')
            description: Optional description

        Returns:
            ID of the inserted object class
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO object_classes (class_name, color, description)
                VALUES (?, ?, ?)
                """,
                (class_name, color, description),
            )
            conn.commit()
            return cursor.lastrowid
        except sqlite3.IntegrityError:
            # Class already exists
            existing = self.get_object_class_by_name(class_name)
            return existing["id"] if existing else None
        finally:
            conn.close()

    def update_object_class(
        self,
        class_id: int,
        class_name: Optional[str] = None,
        color: Optional[str] = None,
        description: Optional[str] = None,
    ) -> bool:
        """
        Update an object class.

        Args:
            class_id: ID of the class to update
            class_name: New class name (optional)
            color: New color (optional)
            description: New description (optional)

        Returns:
            True if updated, False otherwise
        """
        conn = self.get_connection()
        try:
            updates = {}
            if class_name is not None:
                updates["class_name"] = class_name
            if color is not None:
                updates["color"] = color
            if description is not None:
                updates["description"] = description

            if not updates:
                return False

            set_clauses = [f"{key} = ?" for key in updates.keys()]
            params = list(updates.values()) + [class_id]

            query = f"UPDATE object_classes SET {', '.join(set_clauses)} WHERE id = ?"
            cursor = conn.cursor()
            cursor.execute(query, params)
            conn.commit()
            return cursor.rowcount > 0
        finally:
            conn.close()

    def delete_object_class(self, class_id: int) -> bool:
        """
        Delete an object class.

        Args:
            class_id: ID of the class to delete

        Returns:
            True if deleted, False otherwise
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM object_classes WHERE id = ?", (class_id,))
            conn.commit()
            return cursor.rowcount > 0
        finally:
            conn.close()

    @staticmethod
    def calculate_checksum(file_path: str) -> str:
        """Calculate MD5 checksum of a file."""
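The upsert-style fallback in `add_object_class` (insert, and on a UNIQUE violation return the existing row's id) can be exercised against an in-memory database; this sketch mirrors the logic with a reduced table:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute(
    "CREATE TABLE object_classes ("
    "id INTEGER PRIMARY KEY AUTOINCREMENT, "
    "class_name TEXT NOT NULL UNIQUE, color TEXT NOT NULL)"
)

def add_object_class(class_name, color):
    try:
        cur = conn.execute(
            "INSERT INTO object_classes (class_name, color) VALUES (?, ?)",
            (class_name, color),
        )
        conn.commit()
        return cur.lastrowid
    except sqlite3.IntegrityError:
        # UNIQUE constraint hit: hand back the existing row's id
        row = conn.execute(
            "SELECT id FROM object_classes WHERE class_name = ?", (class_name,)
        ).fetchone()
        return row["id"] if row else None

first = add_object_class("cell", "#FF0000")
again = add_object_class("cell", "#00FF00")  # duplicate name: same id returned
print(first == again)  # True
```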
@@ -5,7 +5,7 @@ These dataclasses represent the database entities.

from dataclasses import dataclass
from datetime import datetime
from typing import Optional, Dict, Tuple
from typing import Optional, Dict, Tuple, List


@dataclass
@@ -46,6 +46,9 @@ class Detection:
    class_name: str
    bbox: Tuple[float, float, float, float]  # (x_min, y_min, x_max, y_max)
    confidence: float
    segmentation_mask: Optional[
        List[List[float]]
    ]  # List of polygon coordinates [[x1,y1], [x2,y2], ...]
    detected_at: datetime
    metadata: Optional[Dict]

@@ -58,6 +61,9 @@ class Annotation:
    image_id: int
    class_name: str
    bbox: Tuple[float, float, float, float]  # (x_min, y_min, x_max, y_max)
    segmentation_mask: Optional[
        List[List[float]]
    ]  # List of polygon coordinates [[x1,y1], [x2,y2], ...]
    annotator: str
    created_at: datetime
    verified: bool
@@ -37,25 +37,44 @@ CREATE TABLE IF NOT EXISTS detections (
    x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1),
    y_max REAL NOT NULL CHECK(y_max >= 0 AND y_max <= 1),
    confidence REAL NOT NULL CHECK(confidence >= 0 AND confidence <= 1),
    segmentation_mask TEXT, -- JSON string of polygon coordinates [[x1,y1], [x2,y2], ...]
    detected_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    metadata TEXT, -- JSON string for additional metadata
    FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
    FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
);

-- Annotations table: stores manual annotations (future feature)
-- Object classes table: stores annotation class definitions with colors
CREATE TABLE IF NOT EXISTS object_classes (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    class_name TEXT NOT NULL UNIQUE,
    color TEXT NOT NULL, -- Hex color code (e.g., '#FF0000')
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    description TEXT
);

-- Insert default object classes
INSERT OR IGNORE INTO object_classes (class_name, color, description) VALUES
    ('cell', '#FF0000', 'Cell object'),
    ('nucleus', '#00FF00', 'Cell nucleus'),
    ('mitochondria', '#0000FF', 'Mitochondria'),
    ('vesicle', '#FFFF00', 'Vesicle');

-- Annotations table: stores manual annotations
CREATE TABLE IF NOT EXISTS annotations (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    image_id INTEGER NOT NULL,
    class_name TEXT NOT NULL,
    class_id INTEGER NOT NULL,
    x_min REAL NOT NULL CHECK(x_min >= 0 AND x_min <= 1),
    y_min REAL NOT NULL CHECK(y_min >= 0 AND y_min <= 1),
    x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1),
    y_max REAL NOT NULL CHECK(y_max >= 0 AND y_max <= 1),
    segmentation_mask TEXT, -- JSON string of polygon coordinates [[x1,y1], [x2,y2], ...]
    annotator TEXT,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    verified BOOLEAN DEFAULT 0,
    FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE
    FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
    FOREIGN KEY (class_id) REFERENCES object_classes (id) ON DELETE CASCADE
);

-- Create indexes for performance optimization
@@ -67,4 +86,6 @@ CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence);
CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path);
CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at);
CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id);
CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at);
CREATE INDEX IF NOT EXISTS idx_annotations_class_id ON annotations(class_id);
CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at);
CREATE INDEX IF NOT EXISTS idx_object_classes_class_name ON object_classes(class_name);
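The `segmentation_mask` TEXT columns above hold the polygon as a JSON string. A minimal sketch of the round trip, assuming an in-memory database and an illustrative column subset (this is not the app's actual `DatabaseManager` API):

```python
import json
import sqlite3

# In-memory DB with only the columns needed for this sketch (illustrative subset).
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE annotations (id INTEGER PRIMARY KEY, segmentation_mask TEXT)"
)

# Polygon in normalized coordinates, serialized to a JSON string for the TEXT column.
polygon = [[0.1, 0.2], [0.4, 0.2], [0.4, 0.6], [0.1, 0.6]]
conn.execute(
    "INSERT INTO annotations (segmentation_mask) VALUES (?)",
    (json.dumps(polygon),),
)

# Reading it back: json.loads restores the nested list structure.
(mask_json,) = conn.execute("SELECT segmentation_mask FROM annotations").fetchone()
assert json.loads(mask_json) == polygon
```

Storing JSON in a TEXT column keeps the schema simple at the cost of not being able to index or constrain individual points.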
@@ -121,7 +121,7 @@ class ConfigDialog(QDialog):
        models_layout.addRow("Models Directory:", self.models_dir_edit)

        self.base_model_edit = QLineEdit()
        self.base_model_edit.setPlaceholderText("yolov8s.pt")
        self.base_model_edit.setPlaceholderText("yolov8s-seg.pt")
        models_layout.addRow("Default Base Model:", self.base_model_edit)

        models_group.setLayout(models_layout)
@@ -232,7 +232,7 @@ class ConfigDialog(QDialog):
            self.config_manager.get("models.models_directory", "data/models")
        )
        self.base_model_edit.setText(
            self.config_manager.get("models.default_base_model", "yolov8s.pt")
            self.config_manager.get("models.default_base_model", "yolov8s-seg.pt")
        )

        # Training settings

@@ -13,7 +13,7 @@ from PySide6.QtWidgets import (
    QVBoxLayout,
    QLabel,
)
from PySide6.QtCore import Qt, QTimer
from PySide6.QtCore import Qt, QTimer, QSettings
from PySide6.QtGui import QAction, QKeySequence

from src.database.db_manager import DatabaseManager
@@ -52,8 +52,8 @@ class MainWindow(QMainWindow):
        self._create_tab_widget()
        self._create_status_bar()

        # Center window on screen
        self._center_window()
        # Restore window geometry or center window on screen
        self._restore_window_state()

        logger.info("Main window initialized")

@@ -156,6 +156,24 @@ class MainWindow(QMainWindow):
            (screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2
        )

    def _restore_window_state(self):
        """Restore window geometry from settings or center window."""
        settings = QSettings("microscopy_app", "object_detection")
        geometry = settings.value("main_window/geometry")

        if geometry:
            self.restoreGeometry(geometry)
            logger.debug("Restored window geometry from settings")
        else:
            self._center_window()
            logger.debug("Centered window on screen")

    def _save_window_state(self):
        """Save window geometry to settings."""
        settings = QSettings("microscopy_app", "object_detection")
        settings.setValue("main_window/geometry", self.saveGeometry())
        logger.debug("Saved window geometry to settings")

    def _show_settings(self):
        """Show settings dialog."""
        logger.info("Opening settings dialog")
@@ -276,6 +294,13 @@ class MainWindow(QMainWindow):
        )

        if reply == QMessageBox.Yes:
            # Save window state before closing
            self._save_window_state()

            # Save annotation tab state if it exists
            if hasattr(self, "annotation_tab"):
                self.annotation_tab.save_state()

            logger.info("Application closing")
            event.accept()
        else:

@@ -1,16 +1,33 @@
"""
Annotation tab for the microscopy object detection application.
Future feature for manual annotation.
Manual annotation with pen tool and object class management.
"""

from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
from PySide6.QtWidgets import (
    QWidget,
    QVBoxLayout,
    QHBoxLayout,
    QLabel,
    QGroupBox,
    QPushButton,
    QFileDialog,
    QMessageBox,
    QSplitter,
)
from PySide6.QtCore import Qt, QSettings
from pathlib import Path

from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.image import Image, ImageLoadError
from src.utils.logger import get_logger
from src.gui.widgets import AnnotationCanvasWidget, AnnotationToolsWidget

logger = get_logger(__name__)


class AnnotationTab(QWidget):
    """Annotation tab placeholder (future feature)."""
    """Annotation tab for manual image annotation."""

    def __init__(
        self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
@@ -18,6 +35,12 @@ class AnnotationTab(QWidget):
        super().__init__(parent)
        self.db_manager = db_manager
        self.config_manager = config_manager
        self.current_image = None
        self.current_image_path = None
        self.current_image_id = None
        self.current_annotations = []
        # IDs of annotations currently selected on the canvas (multi-select)
        self.selected_annotation_ids = []

        self._setup_ui()

@@ -25,24 +48,519 @@ class AnnotationTab(QWidget):
        """Setup user interface."""
        layout = QVBoxLayout()

        group = QGroupBox("Annotation Tool (Future Feature)")
        group_layout = QVBoxLayout()
        label = QLabel(
            "Annotation functionality will be implemented in future version.\n\n"
            "Planned Features:\n"
            "- Image browser\n"
            "- Drawing tools for bounding boxes\n"
            "- Class label assignment\n"
            "- Export annotations to YOLO format\n"
            "- Annotation verification"
        )
        group_layout.addWidget(label)
        group.setLayout(group_layout)
        # Main horizontal splitter to divide left (image) and right (controls)
        self.main_splitter = QSplitter(Qt.Horizontal)
        self.main_splitter.setHandleWidth(10)

        layout.addWidget(group)
        layout.addStretch()
        # { Left splitter for image display and zoom info
        self.left_splitter = QSplitter(Qt.Vertical)
        self.left_splitter.setHandleWidth(10)

        # Annotation canvas section
        canvas_group = QGroupBox("Annotation Canvas")
        canvas_layout = QVBoxLayout()

        # Use the AnnotationCanvasWidget
        self.annotation_canvas = AnnotationCanvasWidget()
        self.annotation_canvas.zoom_changed.connect(self._on_zoom_changed)
        self.annotation_canvas.annotation_drawn.connect(self._on_annotation_drawn)
        # Selection of existing polylines (when tool is not in drawing mode)
        self.annotation_canvas.annotation_selected.connect(self._on_annotation_selected)
        canvas_layout.addWidget(self.annotation_canvas)

        canvas_group.setLayout(canvas_layout)
        self.left_splitter.addWidget(canvas_group)

        # Controls info
        controls_info = QLabel(
            "Zoom: Mouse wheel or +/- keys | Drawing: Enable pen and drag mouse"
        )
        controls_info.setStyleSheet("QLabel { color: #888; font-style: italic; }")
        self.left_splitter.addWidget(controls_info)
        # }

        # { Right splitter for annotation tools and controls
        self.right_splitter = QSplitter(Qt.Vertical)
        self.right_splitter.setHandleWidth(10)

        # Annotation tools section
        self.annotation_tools = AnnotationToolsWidget(self.db_manager)
        self.annotation_tools.polyline_enabled_changed.connect(
            self.annotation_canvas.set_polyline_enabled
        )
        self.annotation_tools.polyline_pen_color_changed.connect(
            self.annotation_canvas.set_polyline_pen_color
        )
        self.annotation_tools.polyline_pen_width_changed.connect(
            self.annotation_canvas.set_polyline_pen_width
        )
        # Show / hide bounding boxes
        self.annotation_tools.show_bboxes_changed.connect(
            self.annotation_canvas.set_show_bboxes
        )
        # RDP simplification controls
        self.annotation_tools.simplify_on_finish_changed.connect(
            self._on_simplify_on_finish_changed
        )
        self.annotation_tools.simplify_epsilon_changed.connect(
            self._on_simplify_epsilon_changed
        )
        # Class selection and class-color changes
        self.annotation_tools.class_selected.connect(self._on_class_selected)
        self.annotation_tools.class_color_changed.connect(self._on_class_color_changed)
        self.annotation_tools.clear_annotations_requested.connect(
            self._on_clear_annotations
        )
        # Delete selected annotation on canvas
        self.annotation_tools.delete_selected_annotation_requested.connect(
            self._on_delete_selected_annotation
        )
        self.right_splitter.addWidget(self.annotation_tools)

        # Image loading section
        load_group = QGroupBox("Image Loading")
        load_layout = QVBoxLayout()

        # Load image button
        button_layout = QHBoxLayout()
        self.load_image_btn = QPushButton("Load Image")
        self.load_image_btn.clicked.connect(self._load_image)
        button_layout.addWidget(self.load_image_btn)
        button_layout.addStretch()
        load_layout.addLayout(button_layout)

        # Image info label
        self.image_info_label = QLabel("No image loaded")
        load_layout.addWidget(self.image_info_label)

        load_group.setLayout(load_layout)
        self.right_splitter.addWidget(load_group)
        # }

        # Add both splitters to the main horizontal splitter
        self.main_splitter.addWidget(self.left_splitter)
        self.main_splitter.addWidget(self.right_splitter)

        # Set initial sizes: 75% for left (image), 25% for right (controls)
        self.main_splitter.setSizes([750, 250])

        layout.addWidget(self.main_splitter)
        self.setLayout(layout)

        # Restore splitter positions from settings
        self._restore_state()

    def _load_image(self):
        """Load and display an image file."""
        # Get last opened directory from QSettings
        settings = QSettings("microscopy_app", "object_detection")
        last_dir = settings.value("annotation_tab/last_directory", None)

        # Fallback to image repository path or home directory
        if last_dir and Path(last_dir).exists():
            start_dir = last_dir
        else:
            repo_path = self.config_manager.get_image_repository_path()
            start_dir = repo_path if repo_path else str(Path.home())

        # Open file dialog
        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "Select Image",
            start_dir,
            "Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
        )

        if not file_path:
            return

        try:
            # Load image using Image class
            self.current_image = Image(file_path)
            self.current_image_path = file_path

            # Store the directory for next time
            settings.setValue(
                "annotation_tab/last_directory", str(Path(file_path).parent)
            )

            # Get or create image in database
            relative_path = str(Path(file_path).name)  # Simplified for now
            self.current_image_id = self.db_manager.get_or_create_image(
                relative_path,
                Path(file_path).name,
                self.current_image.width,
                self.current_image.height,
            )

            # Display image using the AnnotationCanvasWidget
            self.annotation_canvas.load_image(self.current_image)

            # Load and display any existing annotations for this image
            self._load_annotations_for_current_image()

            # Update info label
            self._update_image_info()

            logger.info(f"Loaded image: {file_path} (DB ID: {self.current_image_id})")

        except ImageLoadError as e:
            logger.error(f"Failed to load image: {e}")
            QMessageBox.critical(
                self, "Error Loading Image", f"Failed to load image:\n{str(e)}"
            )
        except Exception as e:
            logger.error(f"Unexpected error loading image: {e}")
            QMessageBox.critical(self, "Error", f"Unexpected error:\n{str(e)}")

    def _update_image_info(self):
        """Update the image info label with current image details."""
        if self.current_image is None:
            self.image_info_label.setText("No image loaded")
            return

        zoom_percentage = self.annotation_canvas.get_zoom_percentage()
        info_text = (
            f"File: {Path(self.current_image_path).name}\n"
            f"Size: {self.current_image.width}x{self.current_image.height} pixels\n"
            f"Channels: {self.current_image.channels}\n"
            f"Data type: {self.current_image.dtype}\n"
            f"Format: {self.current_image.format.upper()}\n"
            f"File size: {self.current_image.size_mb:.2f} MB\n"
            f"Zoom: {zoom_percentage}%"
        )
        self.image_info_label.setText(info_text)

    def _on_zoom_changed(self, zoom_scale: float):
        """Handle zoom level changes from the annotation canvas."""
        self._update_image_info()

    def _on_annotation_drawn(self, points: list):
        """
        Handle when an annotation stroke is drawn.

        Saves the new annotation directly to the database and refreshes the
        on-canvas display of annotations for the current image.
        """
        # Ensure we have an image loaded and in the DB
        if not self.current_image or not self.current_image_id:
            logger.warning("Annotation drawn but no image loaded")
            QMessageBox.warning(
                self,
                "No Image",
                "Please load an image before drawing annotations.",
            )
            return

        current_class = self.annotation_tools.get_current_class()

        if not current_class:
            logger.warning("Annotation drawn but no object class selected")
            QMessageBox.warning(
                self,
                "No Class Selected",
                "Please select an object class before drawing annotations.",
            )
            return

        if not points:
            logger.warning("Annotation drawn with no points, ignoring")
            return

        # points are [(x_norm, y_norm), ...]
        xs = [p[0] for p in points]
        ys = [p[1] for p in points]
        x_min, x_max = min(xs), max(xs)
        y_min, y_max = min(ys), max(ys)

        # Store segmentation mask in [y_norm, x_norm] format to match DB
        db_polyline = [[float(y), float(x)] for (x, y) in points]

        try:
            annotation_id = self.db_manager.add_annotation(
                image_id=self.current_image_id,
                class_id=current_class["id"],
                bbox=(x_min, y_min, x_max, y_max),
                annotator="manual",
                segmentation_mask=db_polyline,
                verified=False,
            )

            logger.info(
                f"Saved annotation (ID: {annotation_id}) for class "
                f"'{current_class['class_name']}', bounding box "
                f"({x_min:.3f}, {y_min:.3f}) to ({x_max:.3f}, {y_max:.3f}), "
                f"with {len(points)} polyline points"
            )

            # Reload annotations from DB and redraw (respecting current class filter)
            self._load_annotations_for_current_image()

        except Exception as e:
            logger.error(f"Failed to save annotation: {e}")
            QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}")

    def _on_annotation_selected(self, annotation_ids):
        """
        Handle selection of existing annotations on the canvas.

        Args:
            annotation_ids: List of selected annotation IDs, or None/empty if cleared.
        """
        if not annotation_ids:
            self.selected_annotation_ids = []
            self.annotation_tools.set_has_selected_annotation(False)
            logger.debug("Annotation selection cleared on canvas")
            return

        # Normalize to a unique, sorted list of integer IDs
        ids = sorted({int(aid) for aid in annotation_ids if isinstance(aid, int)})
        self.selected_annotation_ids = ids
        self.annotation_tools.set_has_selected_annotation(bool(ids))
        logger.debug(f"Annotations selected on canvas: IDs={ids}")

    def _on_simplify_on_finish_changed(self, enabled: bool):
        """Update canvas simplify-on-finish flag from tools widget."""
        self.annotation_canvas.simplify_on_finish = enabled
        logger.debug(f"Annotation simplification on finish set to {enabled}")

    def _on_simplify_epsilon_changed(self, epsilon: float):
        """Update canvas RDP epsilon from tools widget."""
        self.annotation_canvas.simplify_epsilon = float(epsilon)
        logger.debug(f"Annotation simplification epsilon set to {epsilon}")

    def _on_class_color_changed(self):
        """
        Handle changes to the selected object's class color.

        When the user updates a class color in the tools widget, reload the
        annotations for the current image so that all polylines are redrawn
        using the updated per-class colors.
        """
        if not self.current_image_id:
            return

        logger.debug(
            f"Class color changed; reloading annotations for image ID {self.current_image_id}"
        )
        self._load_annotations_for_current_image()

    def _on_class_selected(self, class_data):
        """
        Handle when an object class is selected or cleared.

        When a specific class is selected, only annotations of that class are drawn.
        When the selection is cleared ("-- Select Class --"), all annotations are shown.
        """
        if class_data:
            logger.debug(f"Object class selected: {class_data['class_name']}")
        else:
            logger.debug(
                'No class selected ("-- Select Class --"), showing all annotations'
            )

        # Changing the class filter invalidates any previous selection
        self.selected_annotation_ids = []
        self.annotation_tools.set_has_selected_annotation(False)

        # Whenever the selection changes, update which annotations are visible
        self._redraw_annotations_for_current_filter()

    def _on_clear_annotations(self):
        """Handle clearing all annotations."""
        self.annotation_canvas.clear_annotations()
        # Clear in-memory state and selection, but keep DB entries unchanged
        self.current_annotations = []
        self.selected_annotation_ids = []
        self.annotation_tools.set_has_selected_annotation(False)
        logger.info("Cleared all annotations")

    def _on_delete_selected_annotation(self):
        """Handle deleting the currently selected annotation(s) (if any)."""
        if not self.selected_annotation_ids:
            QMessageBox.information(
                self,
                "No Selection",
                "No annotation is currently selected.",
            )
            return

        count = len(self.selected_annotation_ids)
        if count == 1:
            question = "Are you sure you want to delete the selected annotation?"
            title = "Delete Annotation"
        else:
            question = (
                f"Are you sure you want to delete the {count} selected annotations?"
            )
            title = "Delete Annotations"

        reply = QMessageBox.question(
            self,
            title,
            question,
            QMessageBox.Yes | QMessageBox.No,
            QMessageBox.No,
        )
        if reply != QMessageBox.Yes:
            return

        failed_ids = []
        try:
            for ann_id in self.selected_annotation_ids:
                try:
                    deleted = self.db_manager.delete_annotation(ann_id)
                    if not deleted:
                        failed_ids.append(ann_id)
                except Exception as e:
                    logger.error(f"Failed to delete annotation ID {ann_id}: {e}")
                    failed_ids.append(ann_id)

            if failed_ids:
                QMessageBox.warning(
                    self,
                    "Partial Failure",
                    "Some annotations could not be deleted:\n"
                    + ", ".join(str(a) for a in failed_ids),
                )
            else:
                logger.info(
                    f"Deleted {count} annotation(s): "
                    + ", ".join(str(a) for a in self.selected_annotation_ids)
                )

            # Clear selection and reload annotations for the current image from DB
            self.selected_annotation_ids = []
            self.annotation_tools.set_has_selected_annotation(False)
            self._load_annotations_for_current_image()

        except Exception as e:
            logger.error(f"Failed to delete annotations: {e}")
            QMessageBox.critical(
                self,
                "Error",
                f"Failed to delete annotations:\n{str(e)}",
            )

    def _load_annotations_for_current_image(self):
        """
        Load all annotations for the current image from the database and
        redraw them on the canvas, honoring the currently selected class
        filter (if any).
        """
        if not self.current_image_id:
            self.current_annotations = []
            self.annotation_canvas.clear_annotations()
            self.selected_annotation_ids = []
            self.annotation_tools.set_has_selected_annotation(False)
            return

        try:
            self.current_annotations = self.db_manager.get_annotations_for_image(
                self.current_image_id
            )
            # New annotations loaded; reset any selection
            self.selected_annotation_ids = []
            self.annotation_tools.set_has_selected_annotation(False)
            self._redraw_annotations_for_current_filter()
        except Exception as e:
            logger.error(
                f"Failed to load annotations for image {self.current_image_id}: {e}"
            )
            QMessageBox.critical(
                self,
                "Error",
                f"Failed to load annotations for this image:\n{str(e)}",
            )

    def _redraw_annotations_for_current_filter(self):
        """
        Redraw annotations for the current image, optionally filtered by the
        currently selected object class.
        """
        # Clear current on-canvas annotations but keep the image
        self.annotation_canvas.clear_annotations()

        if not self.current_annotations:
            return

        current_class = self.annotation_tools.get_current_class()
        selected_class_id = current_class["id"] if current_class else None

        drawn_count = 0
        for ann in self.current_annotations:
            # Filter by class if one is selected
            if (
                selected_class_id is not None
                and ann.get("class_id") != selected_class_id
            ):
                continue

            if ann.get("segmentation_mask"):
                polyline = ann["segmentation_mask"]
                color = ann.get("class_color", "#FF0000")

                self.annotation_canvas.draw_saved_polyline(
                    polyline,
                    color,
                    width=3,
                    annotation_id=ann["id"],
                )
                self.annotation_canvas.draw_saved_bbox(
                    [ann["x_min"], ann["y_min"], ann["x_max"], ann["y_max"]],
                    color,
                    width=3,
                )
                drawn_count += 1

        logger.info(
            f"Displayed {drawn_count} annotation(s) for current image with "
            f"{'no class filter' if selected_class_id is None else f'class_id={selected_class_id}'}"
        )

    def _restore_state(self):
        """Restore splitter positions from settings."""
        settings = QSettings("microscopy_app", "object_detection")

        # Restore main splitter state
        main_state = settings.value("annotation_tab/main_splitter_state")
        if main_state:
            self.main_splitter.restoreState(main_state)
            logger.debug("Restored main splitter state")

        # Restore left splitter state
        left_state = settings.value("annotation_tab/left_splitter_state")
        if left_state:
            self.left_splitter.restoreState(left_state)
            logger.debug("Restored left splitter state")

        # Restore right splitter state
        right_state = settings.value("annotation_tab/right_splitter_state")
        if right_state:
            self.right_splitter.restoreState(right_state)
            logger.debug("Restored right splitter state")

    def save_state(self):
        """Save splitter positions to settings."""
        settings = QSettings("microscopy_app", "object_detection")

        # Save main splitter state
        settings.setValue(
            "annotation_tab/main_splitter_state", self.main_splitter.saveState()
        )

        # Save left splitter state
        settings.setValue(
            "annotation_tab/left_splitter_state", self.left_splitter.saveState()
        )

        # Save right splitter state
        settings.setValue(
            "annotation_tab/right_splitter_state", self.right_splitter.saveState()
        )

        logger.debug("Saved annotation tab splitter states")

    def refresh(self):
        """Refresh the tab."""
        pass

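`_on_annotation_drawn` above derives the bounding box from the stroke's extrema and flips each stored point to `[y_norm, x_norm]` order before writing it to the database. A standalone sketch of that conversion (the helper name is illustrative, not part of the app):

```python
from typing import List, Tuple


def stroke_to_db_fields(points: List[Tuple[float, float]]):
    """Compute a normalized bbox and the [y, x]-ordered polyline for storage.

    `points` are (x_norm, y_norm) pairs, as emitted by the canvas's
    annotation_drawn signal in the tab above.
    """
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    # Bounding box as (x_min, y_min, x_max, y_max), matching the DB columns.
    bbox = (min(xs), min(ys), max(xs), max(ys))
    # Match the DB convention used above: each point stored as [y_norm, x_norm].
    db_polyline = [[float(y), float(x)] for (x, y) in points]
    return bbox, db_polyline


bbox, poly = stroke_to_db_fields([(0.2, 0.1), (0.5, 0.3), (0.4, 0.7)])
# bbox == (0.2, 0.1, 0.5, 0.7); poly[0] == [0.1, 0.2]
```

The axis flip is easy to miss when reading the mask back, since the canvas works in `(x, y)` while the stored polyline is `[y, x]`.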
@@ -159,7 +159,7 @@ class DetectionTab(QWidget):

        # Add base model option
        base_model = self.config_manager.get(
            "models.default_base_model", "yolov8s.pt"
            "models.default_base_model", "yolov8s-seg.pt"
        )
        self.model_combo.addItem(
            f"Base Model ({base_model})", {"id": 0, "path": base_model}
@@ -256,7 +256,7 @@ class DetectionTab(QWidget):
        if model_id == 0:
            # Create database entry for base model
            base_model = self.config_manager.get(
                "models.default_base_model", "yolov8s.pt"
                "models.default_base_model", "yolov8s-seg.pt"
            )
            model_id = self.db_manager.add_model(
                model_name="Base Model",

@@ -0,0 +1,7 @@
"""GUI widgets for the microscopy object detection application."""

from src.gui.widgets.image_display_widget import ImageDisplayWidget
from src.gui.widgets.annotation_canvas_widget import AnnotationCanvasWidget
from src.gui.widgets.annotation_tools_widget import AnnotationToolsWidget

__all__ = ["ImageDisplayWidget", "AnnotationCanvasWidget", "AnnotationToolsWidget"]

887
src/gui/widgets/annotation_canvas_widget.py
Normal file
@@ -0,0 +1,887 @@
"""
Annotation canvas widget for drawing annotations on images.
Currently supports polyline drawing tool with color selection for manual annotation.
"""

import math

import numpy as np

from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea
from PySide6.QtGui import (
    QPixmap,
    QImage,
    QPainter,
    QPen,
    QColor,
    QKeyEvent,
    QMouseEvent,
    QPaintEvent,
)
from PySide6.QtCore import Qt, QEvent, Signal, QPoint
from typing import Any, Dict, List, Optional, Tuple

from src.utils.image import Image, ImageLoadError
from src.utils.logger import get_logger

logger = get_logger(__name__)

|
||||
def perpendicular_distance(
|
||||
point: Tuple[float, float],
|
||||
start: Tuple[float, float],
|
||||
end: Tuple[float, float],
|
||||
) -> float:
|
||||
"""Perpendicular distance from `point` to the line defined by `start`->`end`."""
|
||||
(x, y), (x1, y1), (x2, y2) = point, start, end
|
||||
dx = x2 - x1
|
||||
dy = y2 - y1
|
||||
if dx == 0.0 and dy == 0.0:
|
||||
return math.hypot(x - x1, y - y1)
|
||||
num = abs(dy * x - dx * y + x2 * y1 - y2 * x1)
|
||||
den = math.hypot(dx, dy)
|
||||
return num / den
|
||||
|
||||
|
||||
def rdp(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float, float]]:
|
||||
"""
|
||||
Recursive Ramer-Douglas-Peucker (RDP) polyline simplification.
|
||||
|
||||
Args:
|
||||
points: List of (x, y) points.
|
||||
epsilon: Maximum allowed perpendicular distance in pixels.
|
||||
|
||||
Returns:
|
||||
Simplified list of (x, y) points including first and last.
|
||||
"""
|
||||
if len(points) <= 2:
|
||||
return list(points)
|
||||
|
||||
start = points[0]
|
||||
end = points[-1]
|
||||
max_dist = -1.0
|
||||
index = -1
|
||||
|
||||
for i in range(1, len(points) - 1):
|
||||
d = perpendicular_distance(points[i], start, end)
|
||||
if d > max_dist:
|
||||
max_dist = d
|
||||
index = i
|
||||
|
||||
if max_dist > epsilon:
|
||||
# Recursive split
|
||||
left = rdp(points[: index + 1], epsilon)
|
||||
right = rdp(points[index:], epsilon)
|
||||
# Concatenate but avoid duplicate at split point
|
||||
return left[:-1] + right
|
||||
|
||||
# Keep only start and end
|
||||
return [start, end]
|
||||
|
||||
|
||||
def simplify_polyline(
    points: List[Tuple[float, float]], epsilon: float
) -> List[Tuple[float, float]]:
    """
    Simplify a polyline with RDP while preserving closure semantics.

    If the polyline is closed (first == last), the duplicate last point is removed
    before simplification and then re-added after simplification.
    """
    if not points:
        return []

    pts = [(float(x), float(y)) for x, y in points]
    closed = False

    if len(pts) >= 2 and pts[0] == pts[-1]:
        closed = True
        pts = pts[:-1]  # remove duplicate last for simplification

    if len(pts) <= 2:
        simplified = list(pts)
    else:
        simplified = rdp(pts, epsilon)

    if closed and simplified:
        if simplified[0] != simplified[-1]:
            simplified.append(simplified[0])

    return simplified

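# Worked example (hypothetical input, not from the app): with epsilon=1.0, the
# near-collinear point (1, 0.05) and the interior point (2, 1) fall within the
# tolerance band and are dropped, so only the path's corners survive:
#
#     simplify_polyline([(0, 0), (1, 0.05), (2, 0), (2, 1), (2, 2)], epsilon=1.0)
#     # -> [(0.0, 0.0), (2.0, 0.0), (2.0, 2.0)]
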
class AnnotationCanvasWidget(QWidget):
    """
    Widget for displaying images and drawing annotations with zoom and drawing tools.

    Features:
    - Display images with zoom functionality
    - Polyline tool for drawing annotations
    - Configurable pen color and width
    - Mouse-based drawing interface
    - Zoom in/out with mouse wheel and keyboard

    Signals:
        zoom_changed: Emitted when zoom level changes (float zoom_scale)
        annotation_drawn: Emitted when a new stroke is completed (list of points)
    """

    zoom_changed = Signal(float)
    annotation_drawn = Signal(list)  # List of (x, y) points in normalized coordinates
    # Emitted when the user selects an existing polyline on the canvas.
    # Carries the associated annotation_id (int) or None if selection is cleared
    annotation_selected = Signal(object)

    def __init__(self, parent=None):
        """Initialize the annotation canvas widget."""
        super().__init__(parent)

        self.current_image = None
        self.original_pixmap = None
        self.annotation_pixmap = None  # Overlay for annotations
        self.zoom_scale = 1.0
        self.zoom_min = 0.1
        self.zoom_max = 10.0
        self.zoom_step = 0.1
        self.zoom_wheel_step = 0.15

        # Drawing / interaction state
        self.is_drawing = False
        self.polyline_enabled = False
        self.polyline_pen_color = QColor(255, 0, 0, 128)  # Default red with 50% alpha
        self.polyline_pen_width = 3
        self.show_bboxes: bool = True  # Control visibility of bounding boxes

        # Current stroke and stored polylines (in image coordinates, pixel units)
        self.current_stroke: List[Tuple[float, float]] = []
        self.polylines: List[List[Tuple[float, float]]] = []
        self.stroke_meta: List[Dict[str, Any]] = []  # per-polyline style (color, width)
        # Optional DB annotation_id for each stored polyline (None for temporary / unsaved)
        self.polyline_annotation_ids: List[Optional[int]] = []
        # Indices in self.polylines of the currently selected polylines (multi-select)
        self.selected_polyline_indices: List[int] = []

        # Stored bounding boxes in normalized coordinates (x_min, y_min, x_max, y_max)
        self.bboxes: List[List[float]] = []
        self.bbox_meta: List[Dict[str, Any]] = []  # per-bbox style (color, width)

        # Legacy collection of strokes in normalized coordinates (kept for API compatibility)
        self.all_strokes: List[dict] = []

        # RDP simplification parameters (in pixels)
        self.simplify_on_finish: bool = True
        self.simplify_epsilon: float = 2.0
        self.sample_threshold: float = 2.0  # minimum movement to sample a new point

        self._setup_ui()

    def _setup_ui(self):
        """Set up the user interface."""
        layout = QVBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)

        # Scroll area for canvas
        self.scroll_area = QScrollArea()
        self.scroll_area.setWidgetResizable(True)
        self.scroll_area.setMinimumHeight(400)

        self.canvas_label = QLabel("No image loaded")
        self.canvas_label.setAlignment(Qt.AlignCenter)
        self.canvas_label.setStyleSheet(
            "QLabel { background-color: #2b2b2b; color: #888; }"
        )
        self.canvas_label.setScaledContents(False)
        self.canvas_label.setMouseTracking(True)

        self.scroll_area.setWidget(self.canvas_label)
        self.scroll_area.viewport().installEventFilter(self)

        layout.addWidget(self.scroll_area)
        self.setLayout(layout)

        self.setFocusPolicy(Qt.StrongFocus)

    def load_image(self, image: Image):
        """
        Load and display an image.

        Args:
            image: Image object to display
        """
        self.current_image = image
        self.zoom_scale = 1.0
        self.clear_annotations()
        self._display_image()
        logger.debug(
            f"Loaded image into annotation canvas: {image.width}x{image.height}"
        )

    def clear(self):
        """Clear the displayed image and all annotations."""
        self.current_image = None
        self.original_pixmap = None
        self.annotation_pixmap = None
        self.zoom_scale = 1.0
        self.clear_annotations()
        self.canvas_label.setText("No image loaded")
        self.canvas_label.setPixmap(QPixmap())

    def clear_annotations(self):
        """Clear all drawn annotations."""
        self.all_strokes = []
        self.current_stroke = []
        self.polylines = []
        self.stroke_meta = []
        self.polyline_annotation_ids = []
        self.selected_polyline_indices = []
        self.bboxes = []
        self.bbox_meta = []
        self.is_drawing = False
        if self.annotation_pixmap:
            self.annotation_pixmap.fill(Qt.transparent)
        self._update_display()

    def _display_image(self):
        """Display the current image in the canvas."""
        if self.current_image is None:
            return

        try:
            # Get RGB image data
            if self.current_image.channels == 3:
                image_data = self.current_image.get_rgb()
                height, width, channels = image_data.shape
            else:
                image_data = self.current_image.get_grayscale()
                height, width = image_data.shape

            image_data = np.ascontiguousarray(image_data)
            bytes_per_line = image_data.strides[0]

            qimage = QImage(
                image_data.data,
                width,
                height,
                bytes_per_line,
                self.current_image.qtimage_format,
            )

            self.original_pixmap = QPixmap.fromImage(qimage)

            # Create transparent overlay for annotations
            self.annotation_pixmap = QPixmap(self.original_pixmap.size())
            self.annotation_pixmap.fill(Qt.transparent)

            self._apply_zoom()

        except Exception as e:
            logger.error(f"Error displaying image: {e}")
            raise ImageLoadError(f"Failed to display image: {str(e)}")

    def _apply_zoom(self):
        """Apply current zoom level to the displayed image."""
        if self.original_pixmap is None:
            return

        scaled_width = int(self.original_pixmap.width() * self.zoom_scale)
        scaled_height = int(self.original_pixmap.height() * self.zoom_scale)

        # Scale both image and annotations
        scaled_image = self.original_pixmap.scaled(
            scaled_width,
            scaled_height,
            Qt.KeepAspectRatio,
            (
                Qt.SmoothTransformation
                if self.zoom_scale >= 1.0
                else Qt.FastTransformation
            ),
        )

        scaled_annotations = self.annotation_pixmap.scaled(
            scaled_width,
            scaled_height,
            Qt.KeepAspectRatio,
            (
                Qt.SmoothTransformation
                if self.zoom_scale >= 1.0
                else Qt.FastTransformation
            ),
        )

        # Composite image and annotations
        combined = QPixmap(scaled_image.size())
        painter = QPainter(combined)
        painter.drawPixmap(0, 0, scaled_image)
        painter.drawPixmap(0, 0, scaled_annotations)
        painter.end()

        self.canvas_label.setPixmap(combined)
        self.canvas_label.setScaledContents(False)
        self.canvas_label.adjustSize()

        self.zoom_changed.emit(self.zoom_scale)

    def _update_display(self):
        """Update display after drawing."""
        self._apply_zoom()

    def set_polyline_enabled(self, enabled: bool):
        """Enable or disable polyline tool."""
        self.polyline_enabled = enabled
        if enabled:
            self.canvas_label.setCursor(Qt.CrossCursor)
        else:
            self.canvas_label.setCursor(Qt.ArrowCursor)

    def set_polyline_pen_color(self, color: QColor):
        """Set polyline pen color."""
        self.polyline_pen_color = color

    def set_polyline_pen_width(self, width: int):
        """Set polyline pen width."""
        self.polyline_pen_width = max(1, width)

    def get_zoom_percentage(self) -> int:
        """Get current zoom level as percentage."""
        return int(self.zoom_scale * 100)

    def zoom_in(self):
        """Zoom in on the image."""
        if self.original_pixmap is None:
            return
        new_scale = self.zoom_scale + self.zoom_step
        if new_scale <= self.zoom_max:
            self.zoom_scale = new_scale
            self._apply_zoom()

    def zoom_out(self):
        """Zoom out from the image."""
        if self.original_pixmap is None:
            return
        new_scale = self.zoom_scale - self.zoom_step
        if new_scale >= self.zoom_min:
            self.zoom_scale = new_scale
            self._apply_zoom()

    def reset_zoom(self):
        """Reset zoom to 100%."""
        if self.original_pixmap is None:
            return
        self.zoom_scale = 1.0
        self._apply_zoom()

    def _canvas_to_image_coords(self, pos: QPoint) -> Optional[Tuple[int, int]]:
        """Convert canvas coordinates to image coordinates, accounting for zoom and centering."""
        if self.original_pixmap is None or self.canvas_label.pixmap() is None:
            return None

        # Get the displayed pixmap size (after zoom)
        displayed_pixmap = self.canvas_label.pixmap()
        displayed_width = displayed_pixmap.width()
        displayed_height = displayed_pixmap.height()

        # Calculate offset due to label centering (label might be larger than pixmap)
        label_width = self.canvas_label.width()
        label_height = self.canvas_label.height()
        offset_x = max(0, (label_width - displayed_width) // 2)
        offset_y = max(0, (label_height - displayed_height) // 2)

        # Adjust position for offset and convert to image coordinates
        x = (pos.x() - offset_x) / self.zoom_scale
        y = (pos.y() - offset_y) / self.zoom_scale

        # Check bounds
        if (
            0 <= x < self.original_pixmap.width()
            and 0 <= y < self.original_pixmap.height()
        ):
            return (int(x), int(y))
        return None

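The mapping above (subtract the centering offset, then divide by the zoom factor) can be exercised in isolation. A pure-function sketch of the same arithmetic, with illustrative names and no Qt dependency:

```python
def canvas_to_image(pos, img_size, label_size, zoom):
    """Map a click on a centered, zoomed pixmap back to image pixels.

    pos, img_size, label_size are (x, y) / (w, h) tuples; zoom is a float.
    Returns integer image coordinates, or None if the click lands outside.
    """
    disp_w, disp_h = img_size[0] * zoom, img_size[1] * zoom
    # The label centers the pixmap, so undo the centering offset first.
    off_x = max(0, (label_size[0] - disp_w) // 2)
    off_y = max(0, (label_size[1] - disp_h) // 2)
    x = (pos[0] - off_x) / zoom
    y = (pos[1] - off_y) / zoom
    if 0 <= x < img_size[0] and 0 <= y < img_size[1]:
        return (int(x), int(y))
    return None
```

For a 200x100 image shown at 2x inside a 400x300 label, a click at (150, 100) maps to image pixel (75, 25); clicks in the letterbox area map to None.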
    def _find_polyline_at(
        self, img_x: float, img_y: float, threshold_px: float = 5.0
    ) -> Optional[int]:
        """
        Find index of polyline whose geometry is within threshold_px of (img_x, img_y).
        Returns the index in self.polylines, or None if none is close enough.
        """
        best_index: Optional[int] = None
        best_dist: float = float("inf")

        for idx, polyline in enumerate(self.polylines):
            if len(polyline) < 2:
                continue

            # Quick bounding-box check to skip obviously distant polylines
            xs = [p[0] for p in polyline]
            ys = [p[1] for p in polyline]
            if img_x < min(xs) - threshold_px or img_x > max(xs) + threshold_px:
                continue
            if img_y < min(ys) - threshold_px or img_y > max(ys) + threshold_px:
                continue

            # Precise distance to all segments
            for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]):
                d = perpendicular_distance(
                    (img_x, img_y), (float(x1), float(y1)), (float(x2), float(y2))
                )
                if d < best_dist:
                    best_dist = d
                    best_index = idx

        if best_index is not None and best_dist <= threshold_px:
            return best_index
        return None

    def _image_to_normalized_coords(self, x: int, y: int) -> Tuple[float, float]:
        """Convert image coordinates to normalized coordinates (0-1)."""
        if self.original_pixmap is None:
            return (0.0, 0.0)

        norm_x = x / self.original_pixmap.width()
        norm_y = y / self.original_pixmap.height()
        return (norm_x, norm_y)

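`perpendicular_distance`, called by `_find_polyline_at()`, is imported from elsewhere in the project. A minimal point-to-segment distance helper in the same spirit (an assumed sketch, not necessarily the project's exact implementation):

```python
import math


def perpendicular_distance(point, seg_start, seg_end):
    """Distance from `point` to the segment seg_start-seg_end (all (x, y) tuples)."""
    (px, py), (x1, y1), (x2, y2) = point, seg_start, seg_end
    dx, dy = x2 - x1, y2 - y1
    if dx == 0 and dy == 0:
        # Degenerate segment: distance to its single point
        return math.hypot(px - x1, py - y1)
    # Project the point onto the segment, clamping to the endpoints.
    t = max(0.0, min(1.0, ((px - x1) * dx + (py - y1) * dy) / (dx * dx + dy * dy)))
    return math.hypot(px - (x1 + t * dx), py - (y1 + t * dy))
```

Clamping `t` to [0, 1] means a click beyond an endpoint is measured to that endpoint rather than to the infinite line, which matches hit-testing on finite strokes.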
    def _add_polyline(
        self,
        img_points: List[Tuple[float, float]],
        color: QColor,
        width: int,
        annotation_id: Optional[int] = None,
    ):
        """Store a polyline in image coordinates and redraw annotations."""
        if not img_points or len(img_points) < 2:
            return

        # Ensure all points are (float, float) tuples; these stay in pixel units.
        point_tuples = [(float(x), float(y)) for x, y in img_points]
        self.polylines.append(point_tuples)
        self.stroke_meta.append({"color": QColor(color), "width": int(width)})
        self.polyline_annotation_ids.append(annotation_id)

        self._redraw_annotations()

    def _redraw_annotations(self):
        """Redraw all stored polylines and (optionally) bounding boxes onto the annotation pixmap."""
        if self.annotation_pixmap is None:
            return

        # Clear existing overlay
        self.annotation_pixmap.fill(Qt.transparent)

        painter = QPainter(self.annotation_pixmap)

        # Draw polylines
        for idx, (polyline, meta) in enumerate(zip(self.polylines, self.stroke_meta)):
            pen_color: QColor = meta.get("color", self.polyline_pen_color)
            width: int = meta.get("width", self.polyline_pen_width)

            if idx in self.selected_polyline_indices:
                # Highlight selected polylines in a distinct color / width
                highlight_color = QColor(255, 255, 0, 200)  # yellow, semi-opaque
                pen = QPen(
                    highlight_color,
                    width + 1,
                    Qt.SolidLine,
                    Qt.RoundCap,
                    Qt.RoundJoin,
                )
            else:
                pen = QPen(
                    pen_color,
                    width,
                    Qt.SolidLine,
                    Qt.RoundCap,
                    Qt.RoundJoin,
                )

            painter.setPen(pen)
            for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]):
                painter.drawLine(int(x1), int(y1), int(x2), int(y2))

        # Draw bounding boxes (dashed) if enabled
        if self.show_bboxes and self.original_pixmap is not None and self.bboxes:
            img_width = float(self.original_pixmap.width())
            img_height = float(self.original_pixmap.height())

            for bbox, meta in zip(self.bboxes, self.bbox_meta):
                if len(bbox) != 4:
                    continue

                x_min_norm, y_min_norm, x_max_norm, y_max_norm = bbox
                x_min = int(x_min_norm * img_width)
                y_min = int(y_min_norm * img_height)
                x_max = int(x_max_norm * img_width)
                y_max = int(y_max_norm * img_height)

                rect_width = x_max - x_min
                rect_height = y_max - y_min

                pen_color = meta.get("color", QColor(255, 0, 0, 128))
                width = meta.get("width", self.polyline_pen_width)
                pen = QPen(
                    pen_color,
                    width,
                    Qt.DashLine,
                    Qt.SquareCap,
                    Qt.MiterJoin,
                )
                painter.setPen(pen)
                painter.drawRect(x_min, y_min, rect_width, rect_height)

        painter.end()

        self._update_display()

    def mousePressEvent(self, event: QMouseEvent):
        """Handle mouse press events for drawing and selecting polylines."""
        if self.annotation_pixmap is None:
            super().mousePressEvent(event)
            return

        # Map click to image coordinates
        label_pos = self.canvas_label.mapFromGlobal(event.globalPos())
        img_coords = self._canvas_to_image_coords(label_pos)

        # Left button + drawing tool enabled -> start a new stroke
        if event.button() == Qt.LeftButton and self.polyline_enabled:
            if img_coords:
                self.is_drawing = True
                self.current_stroke = [(float(img_coords[0]), float(img_coords[1]))]
            return

        # Left button + drawing tool disabled -> attempt selection of existing polyline
        if event.button() == Qt.LeftButton and not self.polyline_enabled:
            if img_coords:
                idx = self._find_polyline_at(float(img_coords[0]), float(img_coords[1]))
                if idx is not None:
                    if event.modifiers() & Qt.ShiftModifier:
                        # Multi-select mode: add to current selection (if not already selected)
                        if idx not in self.selected_polyline_indices:
                            self.selected_polyline_indices.append(idx)
                    else:
                        # Single-select mode: replace current selection
                        self.selected_polyline_indices = [idx]

                    # Build list of selected annotation IDs (ignore None entries)
                    selected_ids: List[int] = []
                    for sel_idx in self.selected_polyline_indices:
                        if 0 <= sel_idx < len(self.polyline_annotation_ids):
                            ann_id = self.polyline_annotation_ids[sel_idx]
                            if isinstance(ann_id, int):
                                selected_ids.append(ann_id)

                    if selected_ids:
                        self.annotation_selected.emit(selected_ids)
                    else:
                        # No valid DB-backed annotations in selection
                        self.annotation_selected.emit(None)
                else:
                    # Clicked on empty space -> clear selection
                    self.selected_polyline_indices = []
                    self.annotation_selected.emit(None)

            self._redraw_annotations()
            return

        # Fallback for other buttons / cases
        super().mousePressEvent(event)

    def mouseMoveEvent(self, event: QMouseEvent):
        """Handle mouse move events for drawing."""
        if (
            not self.is_drawing
            or not self.polyline_enabled
            or self.annotation_pixmap is None
        ):
            super().mouseMoveEvent(event)
            return

        # Get accurate position using global coordinates
        label_pos = self.canvas_label.mapFromGlobal(event.globalPos())
        img_coords = self._canvas_to_image_coords(label_pos)

        if img_coords and len(self.current_stroke) > 0:
            last_point = self.current_stroke[-1]
            dx = img_coords[0] - last_point[0]
            dy = img_coords[1] - last_point[1]

            # Only sample a new point if we moved enough pixels
            if math.hypot(dx, dy) < self.sample_threshold:
                return

            # Draw line from last point to current point for interactive feedback
            painter = QPainter(self.annotation_pixmap)
            pen = QPen(
                self.polyline_pen_color,
                self.polyline_pen_width,
                Qt.SolidLine,
                Qt.RoundCap,
                Qt.RoundJoin,
            )
            painter.setPen(pen)
            painter.drawLine(
                int(last_point[0]),
                int(last_point[1]),
                int(img_coords[0]),
                int(img_coords[1]),
            )
            painter.end()

            self.current_stroke.append((float(img_coords[0]), float(img_coords[1])))
            self._update_display()

    def mouseReleaseEvent(self, event: QMouseEvent):
        """Handle mouse release events to complete a stroke."""
        if not self.is_drawing or event.button() != Qt.LeftButton:
            super().mouseReleaseEvent(event)
            return

        self.is_drawing = False

        if len(self.current_stroke) > 1 and self.original_pixmap is not None:
            # Ensure the stroke is closed by connecting end -> start
            raw_points = list(self.current_stroke)
            if raw_points[0] != raw_points[-1]:
                raw_points.append(raw_points[0])

            # Optional RDP simplification (in image pixel space)
            if self.simplify_on_finish:
                simplified = simplify_polyline(raw_points, self.simplify_epsilon)
            else:
                simplified = raw_points

            if len(simplified) >= 2:
                # Store polyline and redraw all annotations
                self._add_polyline(
                    simplified, self.polyline_pen_color, self.polyline_pen_width
                )

                # Convert to normalized coordinates for metadata + signal
                normalized_stroke = [
                    self._image_to_normalized_coords(int(x), int(y))
                    for (x, y) in simplified
                ]
                self.all_strokes.append(
                    {
                        "points": normalized_stroke,
                        "color": self.polyline_pen_color.name(),
                        "alpha": self.polyline_pen_color.alpha(),
                        "width": self.polyline_pen_width,
                    }
                )

                # Emit signal with normalized coordinates
                self.annotation_drawn.emit(normalized_stroke)
                logger.debug(
                    f"Completed stroke with {len(simplified)} points "
                    f"(normalized len={len(normalized_stroke)})"
                )

        self.current_stroke = []

    def get_all_strokes(self) -> List[dict]:
        """Get all drawn strokes with metadata."""
        return self.all_strokes

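`simplify_polyline` is imported from a utility module. The Ramer-Douglas-Peucker recursion it presumably implements (the signature is assumed from the call site `simplify_polyline(points, epsilon)`) can be sketched as:

```python
import math


def _line_distance(p, a, b):
    # Distance from p to the infinite line through a-b (RDP measures against the chord).
    (px, py), (ax, ay), (bx, by) = p, a, b
    if (ax, ay) == (bx, by):
        return math.hypot(px - ax, py - ay)
    num = abs((bx - ax) * (ay - py) - (ax - px) * (by - ay))
    return num / math.hypot(bx - ax, by - ay)


def simplify_polyline(points, epsilon):
    """Ramer-Douglas-Peucker: drop points closer than `epsilon` to the chord."""
    if len(points) < 3:
        return list(points)
    # Find the point farthest from the chord between the endpoints.
    dists = [_line_distance(p, points[0], points[-1]) for p in points[1:-1]]
    i = max(range(len(dists)), key=dists.__getitem__) + 1
    if dists[i - 1] <= epsilon:
        return [points[0], points[-1]]
    # Split at the farthest point and recurse on both halves.
    left = simplify_polyline(points[: i + 1], epsilon)
    return left[:-1] + simplify_polyline(points[i:], epsilon)
```

With `epsilon` around 2 px (the widget's default `simplify_epsilon`), densely sampled mouse strokes collapse to a handful of vertices while sharp corners survive.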
    def get_annotation_parameters(self) -> Optional[List[Dict[str, Any]]]:
        """
        Get all annotation parameters including bounding box and polyline.

        Returns:
            List of dictionaries, each containing:
            - 'bbox': [x_min, y_min, x_max, y_max] in normalized image coordinates
            - 'polyline': List of [y_norm, x_norm] points describing the polygon
        """
        if self.original_pixmap is None or not self.polylines:
            return None

        img_width = float(self.original_pixmap.width())
        img_height = float(self.original_pixmap.height())

        params: List[Dict[str, Any]] = []

        for idx, polyline in enumerate(self.polylines):
            if len(polyline) < 2:
                continue

            xs = [p[0] for p in polyline]
            ys = [p[1] for p in polyline]

            x_min_norm = min(xs) / img_width
            x_max_norm = max(xs) / img_width
            y_min_norm = min(ys) / img_height
            y_max_norm = max(ys) / img_height

            # Store polyline as [y_norm, x_norm] to match DB convention and
            # the expectations of draw_saved_polyline().
            normalized_polyline = [
                [y / img_height, x / img_width] for (x, y) in polyline
            ]

            logger.debug(
                f"Polyline {idx}: {len(polyline)} points, "
                f"bbox=({x_min_norm:.3f}, {y_min_norm:.3f})-({x_max_norm:.3f}, {y_max_norm:.3f})"
            )

            params.append(
                {
                    "bbox": [x_min_norm, y_min_norm, x_max_norm, y_max_norm],
                    "polyline": normalized_polyline,
                }
            )

        return params or None

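The [y_norm, x_norm] (row, column) convention used above is easy to get backwards when round-tripping through the database. A minimal sketch of both directions (function names are illustrative, not part of the widget API):

```python
def pixels_to_db(polyline, img_w, img_h):
    """Pixel (x, y) points -> normalized [y, x] rows, the DB convention."""
    return [[y / img_h, x / img_w] for (x, y) in polyline]


def db_to_pixels(rows, img_w, img_h):
    """Normalized [y, x] rows -> pixel (x, y) points."""
    return [(x * img_w, y * img_h) for (y, x) in rows]
```

Note the swap: the DB rows lead with the normalized row (y) while the canvas works in (x, y) pixel tuples.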
    def draw_saved_polyline(
        self,
        polyline: List[List[float]],
        color: str,
        width: int = 3,
        annotation_id: Optional[int] = None,
    ):
        """
        Draw a polyline from database coordinates onto the annotation canvas.

        Args:
            polyline: List of [y_norm, x_norm] coordinate pairs in normalized
                coordinates (0-1), i.e. the (row, column) DB convention
            color: Color hex string (e.g., '#FF0000')
            width: Line width in pixels
            annotation_id: Optional database ID to associate with the polyline
        """
        if not self.annotation_pixmap or not self.original_pixmap:
            logger.warning("Cannot draw polyline: no image loaded")
            return

        if len(polyline) < 2:
            logger.warning("Polyline has fewer than 2 points, cannot draw")
            return

        # Convert normalized coordinates to image coordinates
        # Polyline is stored as [[y_norm, x_norm], ...] (row_norm, col_norm format)
        img_width = self.original_pixmap.width()
        img_height = self.original_pixmap.height()

        logger.debug(f"Loading polyline with {len(polyline)} points")
        logger.debug(f"  Image size: {img_width}x{img_height}")
        logger.debug(f"  First 3 normalized points from DB: {polyline[:3]}")

        img_coords: List[Tuple[float, float]] = []
        for y_norm, x_norm in polyline:
            x = float(x_norm * img_width)
            y = float(y_norm * img_height)
            img_coords.append((x, y))

        logger.debug(f"  First 3 pixel coords: {img_coords[:3]}")

        # Store and redraw using common pipeline
        pen_color = QColor(color)
        pen_color.setAlpha(128)  # Add semi-transparency
        self._add_polyline(img_coords, pen_color, width, annotation_id=annotation_id)

        # Store in all_strokes for consistency (uses normalized coordinates)
        self.all_strokes.append(
            {"points": polyline, "color": color, "alpha": 128, "width": width}
        )

        logger.debug(
            f"Drew saved polyline with {len(polyline)} points in color {color}"
        )

    def draw_saved_bbox(self, bbox: List[float], color: str, width: int = 3):
        """
        Draw a bounding box from database coordinates onto the annotation canvas.

        Args:
            bbox: Bounding box as [x_min_norm, y_min_norm, x_max_norm, y_max_norm]
                in normalized coordinates (0-1)
            color: Color hex string (e.g., '#FF0000')
            width: Line width in pixels
        """
        if not self.annotation_pixmap or not self.original_pixmap:
            logger.warning("Cannot draw bounding box: no image loaded")
            return

        if len(bbox) != 4:
            logger.warning(
                f"Invalid bounding box format: expected 4 values, got {len(bbox)}"
            )
            return

        # Convert normalized coordinates to image coordinates (for logging/debug)
        img_width = self.original_pixmap.width()
        img_height = self.original_pixmap.height()

        x_min_norm, y_min_norm, x_max_norm, y_max_norm = bbox
        x_min = int(x_min_norm * img_width)
        y_min = int(y_min_norm * img_height)
        x_max = int(x_max_norm * img_width)
        y_max = int(y_max_norm * img_height)

        logger.debug(f"Drawing bounding box: {bbox}")
        logger.debug(f"  Image size: {img_width}x{img_height}")
        logger.debug(f"  Pixel coords: ({x_min}, {y_min}) to ({x_max}, {y_max})")

        # Store bounding box (normalized) and its style; actual drawing happens
        # in _redraw_annotations() together with all polylines.
        pen_color = QColor(color)
        pen_color.setAlpha(128)  # Add semi-transparency
        self.bboxes.append(
            [float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)]
        )
        self.bbox_meta.append({"color": pen_color, "width": int(width)})

        # Store in all_strokes for consistency
        self.all_strokes.append(
            {"bbox": bbox, "color": color, "alpha": 128, "width": width}
        )

        # Redraw overlay (polylines + all bounding boxes)
        self._redraw_annotations()
        logger.debug(f"Drew saved bounding box in color {color}")

    def set_show_bboxes(self, show: bool):
        """
        Enable or disable drawing of bounding boxes.

        Args:
            show: If True, draw bounding boxes; if False, hide them.
        """
        self.show_bboxes = bool(show)
        logger.debug(f"Set show_bboxes to {self.show_bboxes}")
        self._redraw_annotations()

    def keyPressEvent(self, event: QKeyEvent):
        """Handle keyboard events for zooming."""
        if event.key() in (Qt.Key_Plus, Qt.Key_Equal):
            self.zoom_in()
            event.accept()
        elif event.key() == Qt.Key_Minus:
            self.zoom_out()
            event.accept()
        elif event.key() == Qt.Key_0 and event.modifiers() == Qt.ControlModifier:
            self.reset_zoom()
            event.accept()
        else:
            super().keyPressEvent(event)

    def eventFilter(self, obj, event: QEvent) -> bool:
        """Event filter to capture wheel events for zooming."""
        if event.type() == QEvent.Wheel:
            wheel_event = event
            if self.original_pixmap is not None:
                delta = wheel_event.angleDelta().y()

                if delta > 0:
                    new_scale = self.zoom_scale + self.zoom_wheel_step
                    if new_scale <= self.zoom_max:
                        self.zoom_scale = new_scale
                        self._apply_zoom()
                else:
                    new_scale = self.zoom_scale - self.zoom_wheel_step
                    if new_scale >= self.zoom_min:
                        self.zoom_scale = new_scale
                        self._apply_zoom()

            return True

        return super().eventFilter(obj, event)
src/gui/widgets/annotation_tools_widget.py (new file, 478 lines)
@@ -0,0 +1,478 @@
"""
|
||||
Annotation tools widget for controlling annotation parameters.
|
||||
Includes polyline tool, color picker, class selection, and annotation management.
|
||||
"""
|
||||
|
||||
from PySide6.QtWidgets import (
|
||||
QWidget,
|
||||
QVBoxLayout,
|
||||
QHBoxLayout,
|
||||
QLabel,
|
||||
QGroupBox,
|
||||
QPushButton,
|
||||
QComboBox,
|
||||
QSpinBox,
|
||||
QDoubleSpinBox,
|
||||
QCheckBox,
|
||||
QColorDialog,
|
||||
QInputDialog,
|
||||
QMessageBox,
|
||||
)
|
||||
from PySide6.QtGui import QColor, QIcon, QPixmap, QPainter
|
||||
from PySide6.QtCore import Qt, Signal
|
||||
from typing import Optional, Dict
|
||||
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnnotationToolsWidget(QWidget):
    """
    Widget for annotation tool controls.

    Features:
    - Enable/disable polyline tool
    - Color selection for polyline pen
    - Object class selection
    - Add new object classes
    - Pen width control
    - Clear annotations

    Signals:
        polyline_enabled_changed: Emitted when polyline tool is enabled/disabled (bool)
        polyline_pen_color_changed: Emitted when polyline pen color changes (QColor)
        polyline_pen_width_changed: Emitted when polyline pen width changes (int)
        simplify_on_finish_changed: Emitted when RDP simplification is toggled (bool)
        simplify_epsilon_changed: Emitted when the RDP epsilon value changes (float)
        show_bboxes_changed: Emitted when bounding-box visibility is toggled (bool)
        class_selected: Emitted when object class is selected (dict)
        class_color_changed: Emitted when the selected class's color changes
        clear_annotations_requested: Emitted when clear button is pressed
        delete_selected_annotation_requested: Emitted to request deletion of the
            currently selected annotation
    """

    polyline_enabled_changed = Signal(bool)
    polyline_pen_color_changed = Signal(QColor)
    polyline_pen_width_changed = Signal(int)
    simplify_on_finish_changed = Signal(bool)
    simplify_epsilon_changed = Signal(float)
    # Toggle visibility of bounding boxes on the canvas
    show_bboxes_changed = Signal(bool)
    class_selected = Signal(dict)
    class_color_changed = Signal()
    clear_annotations_requested = Signal()
    # Request deletion of the currently selected annotation on the canvas
    delete_selected_annotation_requested = Signal()

    def __init__(self, db_manager: DatabaseManager, parent=None):
        """
        Initialize annotation tools widget.

        Args:
            db_manager: Database manager instance
            parent: Parent widget
        """
        super().__init__(parent)
        self.db_manager = db_manager
        self.polyline_enabled = False
        self.current_color = QColor(255, 0, 0, 128)  # Red with 50% alpha
        self.current_class = None

        self._setup_ui()
        self._load_object_classes()

    def _setup_ui(self):
        """Set up the user interface."""
        layout = QVBoxLayout()

        # Polyline Tool Group
        polyline_group = QGroupBox("Polyline Tool")
        polyline_layout = QVBoxLayout()

        # Enable/Disable polyline tool
        button_layout = QHBoxLayout()
        self.polyline_toggle_btn = QPushButton("Start Drawing Polyline")
        self.polyline_toggle_btn.setCheckable(True)
        self.polyline_toggle_btn.clicked.connect(self._on_polyline_toggle)
        button_layout.addWidget(self.polyline_toggle_btn)
        polyline_layout.addLayout(button_layout)

        # Polyline pen width control
        width_layout = QHBoxLayout()
        width_layout.addWidget(QLabel("Pen Width:"))
        self.polyline_pen_width_spin = QSpinBox()
        self.polyline_pen_width_spin.setMinimum(1)
        self.polyline_pen_width_spin.setMaximum(20)
        self.polyline_pen_width_spin.setValue(3)
        self.polyline_pen_width_spin.valueChanged.connect(
            self._on_polyline_pen_width_changed
        )
        width_layout.addWidget(self.polyline_pen_width_spin)
        width_layout.addStretch()
        polyline_layout.addLayout(width_layout)

        # Simplification controls (RDP)
        simplify_layout = QHBoxLayout()
        self.simplify_checkbox = QCheckBox("Simplify on finish")
        self.simplify_checkbox.setChecked(True)
        self.simplify_checkbox.stateChanged.connect(self._on_simplify_toggle)
        simplify_layout.addWidget(self.simplify_checkbox)

        simplify_layout.addWidget(QLabel("epsilon (px):"))
        self.eps_spin = QDoubleSpinBox()
        self.eps_spin.setRange(0.0, 1000.0)
        self.eps_spin.setSingleStep(0.5)
        self.eps_spin.setValue(2.0)
        self.eps_spin.valueChanged.connect(self._on_eps_change)
        simplify_layout.addWidget(self.eps_spin)
        simplify_layout.addStretch()
        polyline_layout.addLayout(simplify_layout)

        polyline_group.setLayout(polyline_layout)
        layout.addWidget(polyline_group)

        # Object Class Group
        class_group = QGroupBox("Object Class")
        class_layout = QVBoxLayout()

        # Class selection dropdown
        self.class_combo = QComboBox()
        self.class_combo.currentIndexChanged.connect(self._on_class_selected)
        class_layout.addWidget(self.class_combo)

        # Add / manage classes
        class_button_layout = QHBoxLayout()
        self.add_class_btn = QPushButton("Add New Class")
        self.add_class_btn.clicked.connect(self._on_add_class)
        class_button_layout.addWidget(self.add_class_btn)

        self.refresh_classes_btn = QPushButton("Refresh")
        self.refresh_classes_btn.clicked.connect(self._load_object_classes)
        class_button_layout.addWidget(self.refresh_classes_btn)
        class_layout.addLayout(class_button_layout)

        # Class color (associated with selected object class)
        color_layout = QHBoxLayout()
        color_layout.addWidget(QLabel("Class Color:"))
        self.color_btn = QPushButton()
        self.color_btn.setFixedSize(40, 30)
        self.color_btn.clicked.connect(self._on_color_picker)
        self._update_color_button()
        color_layout.addWidget(self.color_btn)
        color_layout.addStretch()
        class_layout.addLayout(color_layout)

        # Selected class info
        self.class_info_label = QLabel("No class selected")
        self.class_info_label.setWordWrap(True)
        self.class_info_label.setStyleSheet(
            "QLabel { color: #888; font-style: italic; }"
        )
        class_layout.addWidget(self.class_info_label)

        class_group.setLayout(class_layout)
        layout.addWidget(class_group)

        # Actions Group
        actions_group = QGroupBox("Actions")
        actions_layout = QVBoxLayout()

        # Show / hide bounding boxes
        self.show_bboxes_checkbox = QCheckBox("Show bounding boxes")
        self.show_bboxes_checkbox.setChecked(True)
        self.show_bboxes_checkbox.stateChanged.connect(self._on_show_bboxes_toggle)
        actions_layout.addWidget(self.show_bboxes_checkbox)

        self.clear_btn = QPushButton("Clear All Annotations")
        self.clear_btn.clicked.connect(self._on_clear_annotations)
        actions_layout.addWidget(self.clear_btn)

        # Delete currently selected annotation (enabled when a selection exists)
        self.delete_selected_btn = QPushButton("Delete Selected Annotation")
        self.delete_selected_btn.clicked.connect(self._on_delete_selected_annotation)
        self.delete_selected_btn.setEnabled(False)
        actions_layout.addWidget(self.delete_selected_btn)

        actions_group.setLayout(actions_layout)
        layout.addWidget(actions_group)

        layout.addStretch()
        self.setLayout(layout)

def _update_color_button(self):
|
||||
"""Update the color button appearance with current color."""
|
||||
pixmap = QPixmap(40, 30)
|
||||
pixmap.fill(self.current_color)
|
||||
|
||||
# Add border
|
||||
painter = QPainter(pixmap)
|
||||
painter.setPen(Qt.black)
|
||||
painter.drawRect(0, 0, pixmap.width() - 1, pixmap.height() - 1)
|
||||
painter.end()
|
||||
|
||||
self.color_btn.setIcon(QIcon(pixmap))
|
||||
self.color_btn.setStyleSheet(f"background-color: {self.current_color.name()};")
|
||||
|
||||
def _load_object_classes(self):
|
||||
"""Load object classes from database and populate combo box."""
|
||||
try:
|
||||
classes = self.db_manager.get_object_classes()
|
||||
|
||||
# Clear and repopulate combo box
|
||||
self.class_combo.clear()
|
||||
self.class_combo.addItem("-- Select Class / Show All --", None)
|
||||
|
||||
for cls in classes:
|
||||
self.class_combo.addItem(cls["class_name"], cls)
|
||||
|
||||
logger.debug(f"Loaded {len(classes)} object classes")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading object classes: {e}")
|
||||
QMessageBox.warning(
|
||||
self, "Error", f"Failed to load object classes:\n{str(e)}"
|
||||
)
|
||||
|
||||
def _on_polyline_toggle(self, checked: bool):
|
||||
"""Handle polyline tool enable/disable."""
|
||||
self.polyline_enabled = checked
|
||||
|
||||
if checked:
|
||||
self.polyline_toggle_btn.setText("Stop Drawing Polyline")
|
||||
self.polyline_toggle_btn.setStyleSheet(
|
||||
"QPushButton { background-color: #4CAF50; }"
|
||||
)
|
||||
else:
|
||||
self.polyline_toggle_btn.setText("Start Drawing Polyline")
|
||||
self.polyline_toggle_btn.setStyleSheet("")
|
||||
|
||||
self.polyline_enabled_changed.emit(self.polyline_enabled)
|
||||
logger.debug(f"Polyline tool {'enabled' if checked else 'disabled'}")
|
||||
|
||||
def _on_polyline_pen_width_changed(self, width: int):
|
||||
"""Handle polyline pen width changes."""
|
||||
self.polyline_pen_width_changed.emit(width)
|
||||
logger.debug(f"Polyline pen width changed to {width}")
|
||||
|
||||
def _on_simplify_toggle(self, state: int):
|
||||
"""Handle simplify-on-finish checkbox toggle."""
|
||||
enabled = bool(state)
|
||||
self.simplify_on_finish_changed.emit(enabled)
|
||||
logger.debug(f"Simplify on finish set to {enabled}")
|
||||
|
||||
def _on_eps_change(self, val: float):
|
||||
"""Handle epsilon (RDP tolerance) value changes."""
|
||||
epsilon = float(val)
|
||||
self.simplify_epsilon_changed.emit(epsilon)
|
||||
logger.debug(f"Simplification epsilon changed to {epsilon}")
|
||||
|
||||
def _on_show_bboxes_toggle(self, state: int):
|
||||
"""Handle 'Show bounding boxes' checkbox toggle."""
|
||||
show = bool(state)
|
||||
self.show_bboxes_changed.emit(show)
|
||||
logger.debug(f"Show bounding boxes set to {show}")
|
||||
|
||||
def _on_color_picker(self):
|
||||
"""Open color picker dialog and update the selected object's class color."""
|
||||
if not self.current_class:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"No Class Selected",
|
||||
"Please select an object class before changing its color.",
|
||||
)
|
||||
return
|
||||
|
||||
# Use current class color (without alpha) as the base
|
||||
base_color = QColor(self.current_class.get("color", self.current_color.name()))
|
||||
color = QColorDialog.getColor(
|
||||
base_color,
|
||||
self,
|
||||
"Select Class Color",
|
||||
QColorDialog.ShowAlphaChannel, # Allow alpha in UI, but store RGB in DB
|
||||
)
|
||||
|
||||
if not color.isValid():
|
||||
return
|
||||
|
||||
# Normalize to opaque RGB for storage
|
||||
new_color = QColor(color)
|
||||
new_color.setAlpha(255)
|
||||
hex_color = new_color.name()
|
||||
|
||||
try:
|
||||
# Update in database
|
||||
self.db_manager.update_object_class(
|
||||
class_id=self.current_class["id"], color=hex_color
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update class color in database: {e}")
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"Error",
|
||||
f"Failed to update class color in database:\n{str(e)}",
|
||||
)
|
||||
return
|
||||
|
||||
# Update local class data and combo box item data
|
||||
self.current_class["color"] = hex_color
|
||||
current_index = self.class_combo.currentIndex()
|
||||
if current_index >= 0:
|
||||
self.class_combo.setItemData(current_index, dict(self.current_class))
|
||||
|
||||
# Update info label text
|
||||
info_text = f"Class: {self.current_class['class_name']}\nColor: {hex_color}"
|
||||
if self.current_class.get("description"):
|
||||
info_text += f"\nDescription: {self.current_class['description']}"
|
||||
self.class_info_label.setText(info_text)
|
||||
|
||||
# Use semi-transparent version for polyline pen / button preview
|
||||
class_color = QColor(hex_color)
|
||||
class_color.setAlpha(128)
|
||||
self.current_color = class_color
|
||||
self._update_color_button()
|
||||
self.polyline_pen_color_changed.emit(class_color)
|
||||
|
||||
logger.debug(
|
||||
f"Updated class '{self.current_class['class_name']}' color to "
|
||||
f"{hex_color} (polyline pen alpha={class_color.alpha()})"
|
||||
)
|
||||
|
||||
# Notify listeners (e.g., AnnotationTab) so they can reload/redraw
|
||||
self.class_color_changed.emit()
|
||||
|
||||
def _on_class_selected(self, index: int):
|
||||
"""Handle object class selection (including '-- Select Class --')."""
|
||||
class_data = self.class_combo.currentData()
|
||||
|
||||
if class_data:
|
||||
self.current_class = class_data
|
||||
|
||||
# Update info label
|
||||
info_text = (
|
||||
f"Class: {class_data['class_name']}\n" f"Color: {class_data['color']}"
|
||||
)
|
||||
if class_data.get("description"):
|
||||
info_text += f"\nDescription: {class_data['description']}"
|
||||
|
||||
self.class_info_label.setText(info_text)
|
||||
|
||||
# Update polyline pen color to match class color with semi-transparency
|
||||
class_color = QColor(class_data["color"])
|
||||
if class_color.isValid():
|
||||
# Add 50% alpha for semi-transparency
|
||||
class_color.setAlpha(128)
|
||||
self.current_color = class_color
|
||||
self._update_color_button()
|
||||
self.polyline_pen_color_changed.emit(class_color)
|
||||
|
||||
self.class_selected.emit(class_data)
|
||||
logger.debug(f"Selected class: {class_data['class_name']}")
|
||||
else:
|
||||
# "-- Select Class --" chosen: clear current class and show all annotations
|
||||
self.current_class = None
|
||||
self.class_info_label.setText("No class selected")
|
||||
self.class_selected.emit(None)
|
||||
logger.debug("Class selection cleared: showing annotations for all classes")
|
||||
|
||||
def _on_add_class(self):
|
||||
"""Handle adding a new object class."""
|
||||
# Get class name
|
||||
class_name, ok = QInputDialog.getText(
|
||||
self, "Add Object Class", "Enter class name:"
|
||||
)
|
||||
|
||||
if not ok or not class_name.strip():
|
||||
return
|
||||
|
||||
class_name = class_name.strip()
|
||||
|
||||
# Check if class already exists
|
||||
existing = self.db_manager.get_object_class_by_name(class_name)
|
||||
if existing:
|
||||
QMessageBox.warning(
|
||||
self, "Class Exists", f"A class named '{class_name}' already exists."
|
||||
)
|
||||
return
|
||||
|
||||
# Get color
|
||||
color = QColorDialog.getColor(self.current_color, self, "Select Class Color")
|
||||
|
||||
if not color.isValid():
|
||||
return
|
||||
|
||||
# Get optional description
|
||||
description, ok = QInputDialog.getText(
|
||||
self, "Class Description", "Enter class description (optional):"
|
||||
)
|
||||
|
||||
if not ok:
|
||||
description = None
|
||||
|
||||
# Add to database
|
||||
try:
|
||||
class_id = self.db_manager.add_object_class(
|
||||
class_name, color.name(), description.strip() if description else None
|
||||
)
|
||||
|
||||
logger.info(f"Added new object class: {class_name} (ID: {class_id})")
|
||||
|
||||
# Reload classes and select the new one
|
||||
self._load_object_classes()
|
||||
|
||||
# Find and select the newly added class
|
||||
for i in range(self.class_combo.count()):
|
||||
class_data = self.class_combo.itemData(i)
|
||||
if class_data and class_data.get("id") == class_id:
|
||||
self.class_combo.setCurrentIndex(i)
|
||||
break
|
||||
|
||||
QMessageBox.information(
|
||||
self, "Success", f"Class '{class_name}' added successfully!"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding object class: {e}")
|
||||
QMessageBox.critical(
|
||||
self, "Error", f"Failed to add object class:\n{str(e)}"
|
||||
)
|
||||
|
||||
def _on_clear_annotations(self):
|
||||
"""Handle clear annotations button."""
|
||||
reply = QMessageBox.question(
|
||||
self,
|
||||
"Clear Annotations",
|
||||
"Are you sure you want to clear all annotations?",
|
||||
QMessageBox.Yes | QMessageBox.No,
|
||||
QMessageBox.No,
|
||||
)
|
||||
|
||||
if reply == QMessageBox.Yes:
|
||||
self.clear_annotations_requested.emit()
|
||||
logger.debug("Clear annotations requested")
|
||||
|
||||
def _on_delete_selected_annotation(self):
|
||||
"""Handle delete selected annotation button."""
|
||||
self.delete_selected_annotation_requested.emit()
|
||||
logger.debug("Delete selected annotation requested")
|
||||
|
||||
def set_has_selected_annotation(self, has_selection: bool):
|
||||
"""
|
||||
Enable/disable actions that require a selected annotation.
|
||||
|
||||
Args:
|
||||
has_selection: True if an annotation is currently selected on the canvas.
|
||||
"""
|
||||
self.delete_selected_btn.setEnabled(bool(has_selection))
|
||||
|
||||
def get_current_class(self) -> Optional[Dict]:
|
||||
"""Get currently selected object class."""
|
||||
return self.current_class
|
||||
|
||||
def get_polyline_pen_color(self) -> QColor:
|
||||
"""Get current polyline pen color."""
|
||||
return self.current_color
|
||||
|
||||
def get_polyline_pen_width(self) -> int:
|
||||
"""Get current polyline pen width."""
|
||||
return self.polyline_pen_width_spin.value()
|
||||
|
||||
def is_polyline_enabled(self) -> bool:
|
||||
"""Check if polyline tool is enabled."""
|
||||
return self.polyline_enabled
|
||||
src/gui/widgets/image_display_widget.py (new file, 282 lines)
@@ -0,0 +1,282 @@
"""
Image display widget with zoom functionality for the microscopy object detection application.
Reusable widget for displaying images with zoom controls.
"""

from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea
from PySide6.QtGui import QPixmap, QImage, QKeyEvent
from PySide6.QtCore import Qt, QEvent, Signal
from pathlib import Path
import numpy as np

from src.utils.image import Image, ImageLoadError
from src.utils.logger import get_logger

logger = get_logger(__name__)


class ImageDisplayWidget(QWidget):
    """
    Reusable widget for displaying images with zoom functionality.

    Features:
    - Display images from Image objects
    - Zoom in/out with mouse wheel
    - Zoom in/out with +/- keyboard keys
    - Reset zoom with Ctrl+0
    - Scroll area for large images

    Signals:
        zoom_changed: Emitted when zoom level changes (float zoom_scale)
    """

    zoom_changed = Signal(float)  # Emitted when zoom level changes

    def __init__(self, parent=None):
        """
        Initialize the image display widget.

        Args:
            parent: Parent widget
        """
        super().__init__(parent)

        self.current_image = None
        self.original_pixmap = None  # Store original pixmap for zoom
        self.zoom_scale = 1.0  # Current zoom scale
        self.zoom_min = 0.1  # Minimum zoom (10%)
        self.zoom_max = 10.0  # Maximum zoom (1000%)
        self.zoom_step = 0.1  # Zoom step for +/- keys
        self.zoom_wheel_step = 0.15  # Zoom step for mouse wheel

        self._setup_ui()

    def _setup_ui(self):
        """Setup user interface."""
        layout = QVBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)

        # Scroll area for image
        self.scroll_area = QScrollArea()
        self.scroll_area.setWidgetResizable(True)
        self.scroll_area.setMinimumHeight(400)

        self.image_label = QLabel("No image loaded")
        self.image_label.setAlignment(Qt.AlignCenter)
        self.image_label.setStyleSheet(
            "QLabel { background-color: #2b2b2b; color: #888; }"
        )
        self.image_label.setScaledContents(False)

        # Enable mouse tracking for wheel events
        self.image_label.setMouseTracking(True)
        self.scroll_area.setWidget(self.image_label)

        # Install event filter to capture wheel events on scroll area
        self.scroll_area.viewport().installEventFilter(self)

        layout.addWidget(self.scroll_area)
        self.setLayout(layout)

        # Set focus policy to receive keyboard events
        self.setFocusPolicy(Qt.StrongFocus)

    def load_image(self, image: Image):
        """
        Load and display an image.

        Args:
            image: Image object to display

        Raises:
            ImageLoadError: If image cannot be displayed
        """
        self.current_image = image

        # Reset zoom when loading new image
        self.zoom_scale = 1.0

        # Convert to QPixmap and display
        self._display_image()

        logger.debug(f"Loaded image into display widget: {image.width}x{image.height}")

    def clear(self):
        """Clear the displayed image."""
        self.current_image = None
        self.original_pixmap = None
        self.zoom_scale = 1.0
        self.image_label.setText("No image loaded")
        self.image_label.setPixmap(QPixmap())
        logger.debug("Cleared image display")

    def _display_image(self):
        """Display the current image in the image label."""
        if self.current_image is None:
            return

        try:
            # Get RGB image data
            if self.current_image.channels == 3:
                image_data = self.current_image.get_rgb()
                height, width, channels = image_data.shape
            else:
                image_data = self.current_image.get_grayscale()
                height, width = image_data.shape
                channels = 1

            # Ensure data is contiguous for proper QImage display
            image_data = np.ascontiguousarray(image_data)

            # Use actual stride from numpy array for correct display
            bytes_per_line = image_data.strides[0]

            qimage = QImage(
                image_data.data,
                width,
                height,
                bytes_per_line,
                self.current_image.qtimage_format,
            )

            # Convert to pixmap
            pixmap = QPixmap.fromImage(qimage)

            # Store original pixmap for zooming
            self.original_pixmap = pixmap

            # Apply zoom and display
            self._apply_zoom()

        except Exception as e:
            logger.error(f"Error displaying image: {e}")
            raise ImageLoadError(f"Failed to display image: {str(e)}")

    def _apply_zoom(self):
        """Apply current zoom level to the displayed image."""
        if self.original_pixmap is None:
            return

        # Calculate scaled size
        scaled_width = int(self.original_pixmap.width() * self.zoom_scale)
        scaled_height = int(self.original_pixmap.height() * self.zoom_scale)

        # Scale pixmap
        scaled_pixmap = self.original_pixmap.scaled(
            scaled_width,
            scaled_height,
            Qt.KeepAspectRatio,
            (
                Qt.SmoothTransformation
                if self.zoom_scale >= 1.0
                else Qt.FastTransformation
            ),
        )

        # Display in label
        self.image_label.setPixmap(scaled_pixmap)
        self.image_label.setScaledContents(False)
        self.image_label.adjustSize()

        # Emit zoom changed signal
        self.zoom_changed.emit(self.zoom_scale)

    def zoom_in(self):
        """Zoom in on the image."""
        if self.original_pixmap is None:
            return

        new_scale = self.zoom_scale + self.zoom_step
        if new_scale <= self.zoom_max:
            self.zoom_scale = new_scale
            self._apply_zoom()
            logger.debug(f"Zoomed in to {int(self.zoom_scale * 100)}%")

    def zoom_out(self):
        """Zoom out from the image."""
        if self.original_pixmap is None:
            return

        new_scale = self.zoom_scale - self.zoom_step
        if new_scale >= self.zoom_min:
            self.zoom_scale = new_scale
            self._apply_zoom()
            logger.debug(f"Zoomed out to {int(self.zoom_scale * 100)}%")

    def reset_zoom(self):
        """Reset zoom to 100%."""
        if self.original_pixmap is None:
            return

        self.zoom_scale = 1.0
        self._apply_zoom()
        logger.debug("Reset zoom to 100%")

    def set_zoom(self, scale: float):
        """
        Set zoom to a specific scale.

        Args:
            scale: Zoom scale (1.0 = 100%)
        """
        if self.original_pixmap is None:
            return

        # Clamp to min/max
        scale = max(self.zoom_min, min(self.zoom_max, scale))

        self.zoom_scale = scale
        self._apply_zoom()
        logger.debug(f"Set zoom to {int(self.zoom_scale * 100)}%")

    def get_zoom_percentage(self) -> int:
        """
        Get current zoom level as percentage.

        Returns:
            Zoom level as integer percentage (e.g., 100 for 100%)
        """
        return int(self.zoom_scale * 100)

    def keyPressEvent(self, event: QKeyEvent):
        """Handle keyboard events for zooming."""
        if event.key() in (Qt.Key_Plus, Qt.Key_Equal):
            # + or = key (= is the unshifted + on many keyboards)
            self.zoom_in()
            event.accept()
        elif event.key() == Qt.Key_Minus:
            # - key
            self.zoom_out()
            event.accept()
        elif event.key() == Qt.Key_0 and event.modifiers() == Qt.ControlModifier:
            # Ctrl+0 to reset zoom
            self.reset_zoom()
            event.accept()
        else:
            super().keyPressEvent(event)

    def eventFilter(self, obj, event: QEvent) -> bool:
        """Event filter to capture wheel events for zooming."""
        if event.type() == QEvent.Wheel:
            wheel_event = event
            if self.original_pixmap is not None:
                # Get wheel angle delta
                delta = wheel_event.angleDelta().y()

                # Zoom in/out based on wheel direction
                if delta > 0:
                    # Scroll up = zoom in
                    new_scale = self.zoom_scale + self.zoom_wheel_step
                    if new_scale <= self.zoom_max:
                        self.zoom_scale = new_scale
                        self._apply_zoom()
                else:
                    # Scroll down = zoom out
                    new_scale = self.zoom_scale - self.zoom_wheel_step
                    if new_scale >= self.zoom_min:
                        self.zoom_scale = new_scale
                        self._apply_zoom()

            return True  # Event handled

        return super().eventFilter(obj, event)
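The wheel and keyboard handlers above share one pattern: propose a new scale, accept it only if it stays inside `[zoom_min, zoom_max]`, then re-apply. A minimal, framework-free sketch of that arithmetic (the function name and defaults are illustrative, not part of the widget's API):

```python
def step_zoom(scale: float, delta: int, step: float = 0.15,
              lo: float = 0.1, hi: float = 10.0) -> float:
    """Return the next zoom scale for a wheel delta, keeping it in [lo, hi].

    delta > 0 (scroll up) zooms in; delta < 0 zooms out. Steps that would
    leave the allowed range are rejected, as in the widget above.
    """
    proposed = scale + step if delta > 0 else scale - step
    if lo <= proposed <= hi:
        return proposed
    return scale


# One wheel notch in from 100%; a notch out near the minimum is rejected.
assert abs(step_zoom(1.0, 120) - 1.15) < 1e-9
assert step_zoom(0.2, -120) == 0.2
```

Note this rejects out-of-range steps outright rather than clamping to the boundary, matching the widget's `if new_scale <= self.zoom_max` checks (whereas `set_zoom` clamps).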
src/gui_launcher.py (new file, 49 lines)
@@ -0,0 +1,49 @@
"""GUI launcher module for microscopy object detection application."""

import sys
from pathlib import Path

from PySide6.QtWidgets import QApplication
from PySide6.QtCore import Qt

from src import __version__
from src.gui.main_window import MainWindow
from src.utils.logger import setup_logging
from src.utils.config_manager import ConfigManager


def main():
    """Launch the GUI application."""
    # Setup logging
    config_manager = ConfigManager()
    log_config = config_manager.get_section("logging")
    setup_logging(
        log_file=log_config.get("file", "logs/app.log"),
        level=log_config.get("level", "INFO"),
        log_format=log_config.get("format"),
    )

    # Enable High DPI scaling
    QApplication.setHighDpiScaleFactorRoundingPolicy(
        Qt.HighDpiScaleFactorRoundingPolicy.PassThrough
    )

    # Create Qt application
    app = QApplication(sys.argv)
    app.setApplicationName("Microscopy Object Detection")
    app.setOrganizationName("MicroscopyLab")
    app.setApplicationVersion(__version__)

    # Set application style
    app.setStyle("Fusion")

    # Create and show main window
    window = MainWindow()
    window.show()

    # Run application
    sys.exit(app.exec())


if __name__ == "__main__":
    main()
@@ -87,6 +87,7 @@ class InferenceEngine:
                "class_name": det["class_name"],
                "bbox": tuple(bbox_normalized),
                "confidence": det["confidence"],
                "segmentation_mask": det.get("segmentation_mask"),
                "metadata": {"class_id": det["class_id"]},
            }
            detection_records.append(record)
@@ -160,6 +161,7 @@ class InferenceEngine:
        conf: float = 0.25,
        bbox_thickness: int = 2,
        bbox_colors: Optional[Dict[str, str]] = None,
        draw_masks: bool = True,
    ) -> tuple:
        """
        Detect objects and return annotated image.
@@ -169,6 +171,7 @@ class InferenceEngine:
            conf: Confidence threshold
            bbox_thickness: Thickness of bounding boxes
            bbox_colors: Dictionary mapping class names to hex colors
            draw_masks: Whether to draw segmentation masks (if available)

        Returns:
            Tuple of (detections, annotated_image_array)
@@ -189,12 +192,8 @@ class InferenceEngine:
            bbox_colors = {}
        default_color = self._hex_to_bgr(bbox_colors.get("default", "#00FF00"))

        # Draw bounding boxes
        # Draw detections
        for det in detections:
            # Get absolute coordinates
            bbox_abs = det["bbox_absolute"]
            x1, y1, x2, y2 = [int(v) for v in bbox_abs]

            # Get color for this class
            class_name = det["class_name"]
            color_hex = bbox_colors.get(
@@ -202,7 +201,33 @@ class InferenceEngine:
            )
            color = self._hex_to_bgr(color_hex)

            # Draw box
            # Draw segmentation mask if available and requested
            if draw_masks and det.get("segmentation_mask"):
                mask_normalized = det["segmentation_mask"]
                if mask_normalized and len(mask_normalized) > 0:
                    # Convert normalized coordinates to absolute pixels
                    mask_points = np.array(
                        [
                            [int(pt[0] * width), int(pt[1] * height)]
                            for pt in mask_normalized
                        ],
                        dtype=np.int32,
                    )

                    # Create a semi-transparent overlay
                    overlay = img.copy()
                    cv2.fillPoly(overlay, [mask_points], color)
                    # Blend with original image (30% opacity)
                    cv2.addWeighted(overlay, 0.3, img, 0.7, 0, img)

                    # Draw mask contour
                    cv2.polylines(img, [mask_points], True, color, bbox_thickness)

            # Get absolute coordinates for bounding box
            bbox_abs = det["bbox_absolute"]
            x1, y1, x2, y2 = [int(v) for v in bbox_abs]

            # Draw bounding box
            cv2.rectangle(img, (x1, y1), (x2, y2), color, bbox_thickness)

            # Prepare label
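The 30%-opacity mask overlay above is a per-pixel weighted sum: `cv2.addWeighted(overlay, 0.3, img, 0.7, 0, img)` computes `dst = 0.3*overlay + 0.7*img` and saturates to uint8. A toy numpy sketch of that arithmetic (toy arrays, not the engine's data; exact rounding at half-values may differ slightly from OpenCV's saturate cast):

```python
import numpy as np

img = np.full((2, 2, 3), 200, dtype=np.uint8)  # background pixels
overlay = np.zeros((2, 2, 3), dtype=np.uint8)  # mask painted at full strength
overlay[:] = (0, 255, 0)                       # green fill, BGR order

# dst = src1*alpha + src2*beta + gamma, clipped to the uint8 range
blended = np.clip(
    overlay.astype(np.float64) * 0.3 + img.astype(np.float64) * 0.7, 0, 255
).astype(np.uint8)

# Blue and red channels keep 70% of the background (0.7 * 200 = 140);
# green gains 30% of full intensity on top (0.3 * 255 + 0.7 * 200 = 216.5).
assert blended[0, 0].tolist() == [140, 216, 140]
```

Blending a copy and writing back into `img`, as the engine does, keeps the mask fill translucent while the contour and bounding box are drawn at full opacity afterward.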
@@ -16,7 +16,7 @@ logger = get_logger(__name__)
class YOLOWrapper:
    """Wrapper for YOLOv8 model operations."""

    def __init__(self, model_path: str = "yolov8s.pt"):
    def __init__(self, model_path: str = "yolov8s-seg.pt"):
        """
        Initialize YOLO model.

@@ -282,6 +282,10 @@ class YOLOWrapper:
        boxes = result.boxes
        image_path = str(result.path)
        orig_shape = result.orig_shape  # (height, width)
        height, width = orig_shape

        # Check if this is a segmentation model with masks
        has_masks = hasattr(result, "masks") and result.masks is not None

        for i in range(len(boxes)):
            # Get normalized coordinates
@@ -299,6 +303,33 @@ class YOLOWrapper:
                    float(v) for v in boxes.xyxy[i].cpu().numpy()
                ],  # Absolute pixels
            }

            # Extract segmentation mask if available
            if has_masks:
                try:
                    # Get the mask polygon for this detection (absolute pixels)
                    mask_data = result.masks.xy[i]

                    # Convert to normalized coordinates
                    if len(mask_data) > 0:
                        mask_normalized = []
                        for point in mask_data:
                            x_norm = float(point[0]) / width
                            y_norm = float(point[1]) / height
                            mask_normalized.append([x_norm, y_norm])
                        detection["segmentation_mask"] = mask_normalized
                    else:
                        detection["segmentation_mask"] = None
                except Exception as mask_error:
                    logger.warning(
                        f"Error extracting mask for detection {i}: {mask_error}"
                    )
                    detection["segmentation_mask"] = None
            else:
                detection["segmentation_mask"] = None

            detections.append(detection)

        return detections
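The wrapper stores each mask as normalized `[x, y]` pairs so the polygon is resolution-independent, and the inference engine multiplies back by the image size before drawing. A small self-contained sketch of that round trip (standalone helpers for illustration, not the wrapper's API):

```python
from typing import List


def normalize_polygon(points_px: List[List[float]], width: int,
                      height: int) -> List[List[float]]:
    """Convert absolute pixel coordinates to normalized [0, 1] pairs."""
    return [[x / width, y / height] for x, y in points_px]


def denormalize_polygon(points_norm: List[List[float]], width: int,
                        height: int) -> List[List[int]]:
    """Convert normalized pairs back to integer pixel coordinates for drawing."""
    return [[int(x * width), int(y * height)] for x, y in points_norm]


# A triangle mask on a 400x200 image survives the round trip.
mask_px = [[100.0, 50.0], [200.0, 50.0], [150.0, 120.0]]
norm = normalize_polygon(mask_px, width=400, height=200)
assert denormalize_polygon(norm, 400, 200) == [[100, 50], [200, 50], [150, 120]]
```

Storing normalized coordinates also matches how the `segmentation_mask` text column in the schema is described, so the same polygon can be re-projected onto thumbnails or rescaled images.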
@@ -0,0 +1,7 @@
"""
Utility modules for the microscopy object detection application.
"""

from src.utils.image import Image, ImageLoadError

__all__ = ["Image", "ImageLoadError"]
@@ -56,7 +56,7 @@ class ConfigManager:
            ],
        },
        "models": {
            "default_base_model": "yolov8s.pt",
            "default_base_model": "yolov8s-seg.pt",
            "models_directory": "data/models",
        },
        "training": {
src/utils/image.py (new file, 291 lines)
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
Image loading and management utilities for the microscopy object detection application.
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.file_utils import validate_file_path, is_image_file
|
||||
|
||||
from PySide6.QtGui import QImage
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ImageLoadError(Exception):
|
||||
"""Exception raised when an image cannot be loaded."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Image:
|
||||
"""
|
||||
A class for loading and managing images from file paths.
|
||||
|
||||
Supports multiple image formats: .jpg, .jpeg, .png, .tif, .tiff, .bmp
|
||||
Provides access to image data in multiple formats (OpenCV/numpy, PIL).
|
||||
|
||||
Attributes:
|
||||
path: Path to the image file
|
||||
data: Image data as numpy array (OpenCV format, BGR)
|
||||
pil_image: Image data as PIL Image (RGB)
|
||||
width: Image width in pixels
|
||||
height: Image height in pixels
|
||||
channels: Number of color channels
|
||||
format: Image file format
|
||||
size_bytes: File size in bytes
|
||||
"""
|
||||
|
||||
SUPPORTED_EXTENSIONS = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
|
||||
|
||||
def __init__(self, image_path: Union[str, Path]):
|
||||
"""
|
||||
Initialize an Image object by loading from a file path.
|
||||
|
||||
Args:
|
||||
image_path: Path to the image file (string or Path object)
|
||||
|
||||
Raises:
|
||||
ImageLoadError: If the image cannot be loaded or is invalid
|
||||
"""
|
||||
self.path = Path(image_path)
|
||||
self._data: Optional[np.ndarray] = None
|
||||
self._pil_image: Optional[PILImage.Image] = None
|
||||
self._width: int = 0
|
||||
self._height: int = 0
|
||||
self._channels: int = 0
|
||||
self._format: str = ""
|
||||
self._size_bytes: int = 0
|
||||
self._dtype: Optional[np.dtype] = None
|
||||
|
||||
# Load the image
|
||||
self._load()
|
||||
|
||||
def _load(self) -> None:
|
||||
"""
|
||||
Load the image from disk.
|
||||
|
||||
Raises:
|
||||
ImageLoadError: If the image cannot be loaded
|
||||
"""
|
||||
# Validate path
|
||||
if not validate_file_path(str(self.path), must_exist=True):
|
||||
raise ImageLoadError(f"Invalid or non-existent file path: {self.path}")
|
||||
|
||||
# Check file extension
|
||||
if not is_image_file(str(self.path), self.SUPPORTED_EXTENSIONS):
|
||||
ext = self.path.suffix.lower()
|
||||
raise ImageLoadError(
|
||||
f"Unsupported image format: {ext}. "
|
||||
f"Supported formats: {', '.join(self.SUPPORTED_EXTENSIONS)}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load with OpenCV (returns BGR format)
|
||||
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
|
||||
|
||||
if self._data is None:
|
||||
raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")
|
||||
|
||||
# Extract metadata
|
||||
self._height, self._width = self._data.shape[:2]
|
||||
self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
|
||||
self._format = self.path.suffix.lower().lstrip(".")
|
||||
self._size_bytes = self.path.stat().st_size
|
||||
self._dtype = self._data.dtype
|
||||
|
||||
# Load PIL version for compatibility (convert BGR to RGB)
|
||||
if self._channels == 3:
|
||||
rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
|
||||
self._pil_image = PILImage.fromarray(rgb_data)
|
||||
elif self._channels == 4:
|
||||
rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
|
||||
self._pil_image = PILImage.fromarray(rgba_data)
|
||||
else:
|
||||
# Grayscale
|
||||
self._pil_image = PILImage.fromarray(self._data)
|
||||
|
||||
logger.info(
|
||||
f"Successfully loaded image: {self.path.name} "
|
||||
f"({self._width}x{self._height}, {self._channels} channels, "
|
||||
f"{self._format.upper()})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading image {self.path}: {e}")
|
||||
raise ImageLoadError(f"Failed to load image: {e}") from e
|
||||
|
||||
@property
|
||||
def data(self) -> np.ndarray:
|
||||
"""
|
||||
Get image data as numpy array (OpenCV format, BGR or grayscale).
|
||||
|
||||
Returns:
|
||||
Image data as numpy array
|
||||
"""
|
||||
if self._data is None:
|
||||
raise ImageLoadError("Image data not available")
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def pil_image(self) -> PILImage.Image:
|
||||
"""
|
||||
Get image data as PIL Image (RGB or grayscale).
|
||||
|
||||
Returns:
|
||||
PIL Image object
|
||||
"""
|
||||
if self._pil_image is None:
|
||||
raise ImageLoadError("PIL image not available")
|
||||
return self._pil_image
|
||||
|
||||
@property
|
||||
def width(self) -> int:
|
||||
"""Get image width in pixels."""
|
||||
return self._width
|
||||
|
||||
@property
|
||||
def height(self) -> int:
|
||||
"""Get image height in pixels."""
|
||||
return self._height
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, int, int]:
|
||||
"""
|
||||
Get image shape as (height, width, channels).
|
||||
|
||||
Returns:
|
||||
Tuple of (height, width, channels)
|
||||
"""
|
||||
print("shape", self._height, self._width, self._channels)
|
||||
return (self._height, self._width, self._channels)
|
||||
|
||||
    @property
    def channels(self) -> int:
        """Get number of color channels."""
        return self._channels

    @property
    def format(self) -> str:
        """Get image file format (e.g., 'jpg', 'png')."""
        return self._format

    @property
    def size_bytes(self) -> int:
        """Get file size in bytes."""
        return self._size_bytes

    @property
    def size_mb(self) -> float:
        """Get file size in megabytes."""
        return self._size_bytes / (1024 * 1024)

    @property
    def dtype(self) -> np.dtype:
        """Get the data type of the image array."""
        if self._dtype is None:
            raise ImageLoadError("Image dtype not available")
        return self._dtype

    @property
    def qtimage_format(self) -> QImage.Format:
        """
        Get the appropriate QImage format for the image.

        Returns:
            QImage.Format enum value
        """
        if self._channels == 3:
            return QImage.Format_RGB888
        elif self._channels == 4:
            return QImage.Format_RGBA8888
        elif self._channels == 1:
            if self._dtype == np.uint16:
                return QImage.Format_Grayscale16
            else:
                return QImage.Format_Grayscale8
        else:
            raise ImageLoadError(f"Unsupported number of channels: {self._channels}")

    def get_rgb(self) -> np.ndarray:
        """
        Get image data as RGB numpy array.

        Returns:
            Image data in RGB format as numpy array
        """
        if self._channels == 3:
            return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
        elif self._channels == 4:
            return cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
        else:
            return self._data

    def get_grayscale(self) -> np.ndarray:
        """
        Get image as grayscale numpy array.

        Returns:
            Grayscale image as numpy array
        """
        if self._channels == 1:
            return self._data
        else:
            return cv2.cvtColor(self._data, cv2.COLOR_BGR2GRAY)

    def copy(self) -> np.ndarray:
        """
        Get a copy of the image data.

        Returns:
            Copy of image data as numpy array
        """
        return self._data.copy()

    def resize(self, width: int, height: int) -> np.ndarray:
        """
        Resize the image to specified dimensions.

        Args:
            width: Target width in pixels
            height: Target height in pixels

        Returns:
            Resized image as numpy array (does not modify original)
        """
        return cv2.resize(self._data, (width, height))

    def is_grayscale(self) -> bool:
        """
        Check if image is grayscale.

        Returns:
            True if image is grayscale (1 channel)
        """
        return self._channels == 1

    def is_color(self) -> bool:
        """
        Check if image is color.

        Returns:
            True if image has 3 or more channels
        """
        return self._channels >= 3

    def __repr__(self) -> str:
        """String representation of the Image object."""
        return (
            f"Image(path='{self.path.name}', "
            f"shape=({self._width}x{self._height}x{self._channels}), "
            f"format={self._format}, "
            f"size={self.size_mb:.2f}MB)"
        )

    def __str__(self) -> str:
        """String representation of the Image object."""
        return self.__repr__()
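
The BGR-vs-RGB ordering that `get_rgb()` handles is worth illustrating: OpenCV stores color images with channels in BGR order, so converting to RGB is just a reversal of the last axis. A minimal numpy-only sketch of that reordering (assuming 3-channel `uint8` data; `cv2.cvtColor(..., cv2.COLOR_BGR2RGB)` performs the equivalent channel swap):

```python
import numpy as np

def bgr_to_rgb(bgr: np.ndarray) -> np.ndarray:
    """Reverse the channel axis: BGR -> RGB (the same trick inverts itself)."""
    return bgr[..., ::-1]

# A 1x1 "pure blue" pixel as OpenCV would store it: B=255, G=0, R=0.
bgr_pixel = np.array([[[255, 0, 0]]], dtype=np.uint8)
rgb_pixel = bgr_to_rgb(bgr_pixel)
print(rgb_pixel[0, 0].tolist())  # [0, 0, 255] — blue now sits in the last (B) slot
```

Applying the function twice returns the original array, which is why the same reversal works for RGB-to-BGR as well.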

tests/test_image.py (new file, 145 lines)
@@ -0,0 +1,145 @@
"""
Tests for the Image class.
"""

import pytest
import numpy as np
from pathlib import Path

from src.utils.image import Image, ImageLoadError


class TestImage:
    """Test cases for the Image class."""

    def test_load_nonexistent_file(self):
        """Test loading a non-existent file raises ImageLoadError."""
        with pytest.raises(ImageLoadError):
            Image("nonexistent_file.jpg")

    def test_load_unsupported_format(self, tmp_path):
        """Test loading an unsupported format raises ImageLoadError."""
        # Create a dummy file with an unsupported extension
        test_file = tmp_path / "test.txt"
        test_file.write_text("not an image")

        with pytest.raises(ImageLoadError):
            Image(test_file)

    def test_supported_extensions(self):
        """Test that supported extensions are correctly defined."""
        expected_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
        assert Image.SUPPORTED_EXTENSIONS == expected_extensions

    def test_image_properties(self, tmp_path):
        """Test image properties after loading."""
        # Create a simple test image using numpy and cv2
        import cv2

        test_img = np.zeros((100, 200, 3), dtype=np.uint8)
        test_img[:, :] = [255, 0, 0]  # Blue in BGR

        test_file = tmp_path / "test.jpg"
        cv2.imwrite(str(test_file), test_img)

        # Load the image
        img = Image(test_file)

        # Check properties
        assert img.width == 200
        assert img.height == 100
        assert img.channels == 3
        assert img.format == "jpg"
        assert img.shape == (100, 200, 3)
        assert img.size_bytes > 0
        assert img.is_color()
        assert not img.is_grayscale()

    def test_get_rgb(self, tmp_path):
        """Test RGB conversion."""
        import cv2

        # Create BGR image
        test_img = np.zeros((50, 50, 3), dtype=np.uint8)
        test_img[:, :] = [255, 0, 0]  # Blue in BGR

        test_file = tmp_path / "test_rgb.png"
        cv2.imwrite(str(test_file), test_img)

        img = Image(test_file)
        rgb_data = img.get_rgb()

        # After BGR -> RGB conversion, blue lives in the last channel
        assert rgb_data[0, 0, 0] == 0  # R
        assert rgb_data[0, 0, 1] == 0  # G
        assert rgb_data[0, 0, 2] == 255  # B (was BGR blue)

    def test_get_grayscale(self, tmp_path):
        """Test grayscale conversion."""
        import cv2

        test_img = np.zeros((50, 50, 3), dtype=np.uint8)
        test_img[:, :] = [128, 128, 128]

        test_file = tmp_path / "test_gray.png"
        cv2.imwrite(str(test_file), test_img)

        img = Image(test_file)
        gray_data = img.get_grayscale()

        assert len(gray_data.shape) == 2  # Should be 2D
        assert gray_data.shape == (50, 50)

    def test_copy(self, tmp_path):
        """Test copying image data."""
        import cv2

        test_img = np.zeros((50, 50, 3), dtype=np.uint8)

        test_file = tmp_path / "test_copy.png"
        cv2.imwrite(str(test_file), test_img)

        img = Image(test_file)
        copied = img.copy()

        # Modify the copy
        copied[0, 0] = [255, 255, 255]

        # Original should be unchanged
        assert not np.array_equal(img.data[0, 0], copied[0, 0])

    def test_resize(self, tmp_path):
        """Test image resizing."""
        import cv2

        test_img = np.zeros((100, 100, 3), dtype=np.uint8)

        test_file = tmp_path / "test_resize.png"
        cv2.imwrite(str(test_file), test_img)

        img = Image(test_file)
        resized = img.resize(50, 50)

        assert resized.shape == (50, 50, 3)
        # Original should be unchanged
        assert img.width == 100
        assert img.height == 100

    def test_str_repr(self, tmp_path):
        """Test string representation."""
        import cv2

        test_img = np.zeros((100, 200, 3), dtype=np.uint8)

        test_file = tmp_path / "test_str.jpg"
        cv2.imwrite(str(test_file), test_img)

        img = Image(test_file)

        str_repr = str(img)
        assert "test_str.jpg" in str_repr
        assert "100x200x3" in str_repr
        assert "jpg" in str_repr

        repr_str = repr(img)
        assert "Image" in repr_str
        assert "test_str.jpg" in repr_str
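
For reference, the `cv2.COLOR_BGR2GRAY` conversion exercised by `test_get_grayscale` computes ITU-R BT.601 luma, Y = 0.299 R + 0.587 G + 0.114 B. A numpy-only sketch of that weighting (an illustration of the formula, not a byte-exact replica of OpenCV's fixed-point rounding):

```python
import numpy as np

def bgr_to_gray(bgr: np.ndarray) -> np.ndarray:
    """Approximate cv2.COLOR_BGR2GRAY: BT.601 luma over a BGR uint8 array."""
    b, g, r = bgr[..., 0], bgr[..., 1], bgr[..., 2]
    gray = 0.114 * b + 0.587 * g + 0.299 * r  # weights sum to 1.0
    return np.round(gray).astype(np.uint8)

# A uniform mid-gray pixel maps to itself, matching the [128, 128, 128]
# fixture used in test_get_grayscale; the channel axis is dropped.
pixel = np.full((1, 1, 3), 128, dtype=np.uint8)
print(bgr_to_gray(pixel).shape)  # (1, 1)
```

Because the weights sum to 1.0, any achromatic pixel (equal B, G, R) is unchanged by the conversion, which is what makes the uniform-gray fixture a convenient test input.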