Compare commits
71 Commits
6bd2b100ca
...
monkey-pat
| Author | SHA1 | Date | |
|---|---|---|---|
| d03ffdc4d0 | |||
| 8d30e6bb7a | |||
| f810fec4d8 | |||
| 9c8931e6f3 | |||
| 20578c1fdf | |||
| 2c494dac49 | |||
| 506c74e53a | |||
| eefda5b878 | |||
| 31cb6a6c8e | |||
| 0c19ea2557 | |||
| 89e47591db | |||
| 69cde09e53 | |||
| fcbd5fb16d | |||
| ca52312925 | |||
| 0a93bf797a | |||
| d998c65665 | |||
| 510eabfa94 | |||
| 395d263900 | |||
| e98d287b8a | |||
| d25101de2d | |||
| f88beef188 | |||
| 2fd9a2acf4 | |||
| 2bcd18cc75 | |||
| 5d25378c46 | |||
| 2b0b48921e | |||
| b0c05f0225 | |||
| 97badaa390 | |||
| 8f8132ce61 | |||
| 6ae7481e25 | |||
| 061f8b3ca2 | |||
| a8e5db3135 | |||
| 268ed5175e | |||
| 5e9d3b1dc4 | |||
| 7d83e9b9b1 | |||
| e364d06217 | |||
| e5036c10cf | |||
| c7e388d9ae | |||
| 6b995e7325 | |||
| 0e0741d323 | |||
| dd99a0677c | |||
| 9c4c39fb39 | |||
| 20a87c9040 | |||
| 9f7d2be1ac | |||
| dbde07c0e8 | |||
| b3c5a51dbb | |||
| 9a221acb63 | |||
| 32a6a122bd | |||
| 9ba44043ef | |||
| 8eb1cc8c86 | |||
| e4ce882a18 | |||
| 6b6d6fad03 | |||
| c0684a9c14 | |||
| 221c80aa8c | |||
| 833b222fad | |||
| 5370d31dce | |||
| 5d196c3a4a | |||
| f719c7ec40 | |||
| e6a5e74fa1 | |||
| 35e2398e95 | |||
| c3d44ac945 | |||
| dad5c2bf74 | |||
| 73cb698488 | |||
| 12f2bf94d5 | |||
| 710b684456 | |||
| fc22479621 | |||
| f84dea0bff | |||
| bb26d43dd7 | |||
| 4b5d2a7c45 | |||
| 42fb2b782d | |||
| 310e0b2285 | |||
| 9011276584 |
@@ -2,11 +2,11 @@
|
|||||||
|
|
||||||
## Project Overview
|
## Project Overview
|
||||||
|
|
||||||
A desktop application for detecting organelles and membrane branching structures in microscopy images using YOLOv8s, with comprehensive training, validation, and visualization capabilities.
|
A desktop application for detecting and segmenting organelles and membrane branching structures in microscopy images using YOLOv8s-seg, with comprehensive training, validation, and visualization capabilities including pixel-accurate segmentation masks.
|
||||||
|
|
||||||
## Technology Stack
|
## Technology Stack
|
||||||
|
|
||||||
- **ML Framework**: Ultralytics YOLOv8 (YOLOv8s.pt model)
|
- **ML Framework**: Ultralytics YOLOv8 (YOLOv8s-seg.pt segmentation model)
|
||||||
- **GUI Framework**: PySide6 (Qt6 for Python)
|
- **GUI Framework**: PySide6 (Qt6 for Python)
|
||||||
- **Visualization**: pyqtgraph
|
- **Visualization**: pyqtgraph
|
||||||
- **Database**: SQLite3
|
- **Database**: SQLite3
|
||||||
@@ -110,6 +110,7 @@ erDiagram
|
|||||||
float x_max
|
float x_max
|
||||||
float y_max
|
float y_max
|
||||||
float confidence
|
float confidence
|
||||||
|
text segmentation_mask
|
||||||
datetime detected_at
|
datetime detected_at
|
||||||
json metadata
|
json metadata
|
||||||
}
|
}
|
||||||
@@ -122,6 +123,7 @@ erDiagram
|
|||||||
float y_min
|
float y_min
|
||||||
float x_max
|
float x_max
|
||||||
float y_max
|
float y_max
|
||||||
|
text segmentation_mask
|
||||||
string annotator
|
string annotator
|
||||||
datetime created_at
|
datetime created_at
|
||||||
boolean verified
|
boolean verified
|
||||||
@@ -139,7 +141,7 @@ Stores information about trained models and their versions.
|
|||||||
| model_name | TEXT | NOT NULL | User-friendly model name |
|
| model_name | TEXT | NOT NULL | User-friendly model name |
|
||||||
| model_version | TEXT | NOT NULL | Version string (e.g., "v1.0") |
|
| model_version | TEXT | NOT NULL | Version string (e.g., "v1.0") |
|
||||||
| model_path | TEXT | NOT NULL | Path to model weights file |
|
| model_path | TEXT | NOT NULL | Path to model weights file |
|
||||||
| base_model | TEXT | NOT NULL | Base model used (e.g., "yolov8s.pt") |
|
| base_model | TEXT | NOT NULL | Base model used (e.g., "yolov8s-seg.pt") |
|
||||||
| created_at | TIMESTAMP | DEFAULT CURRENT_TIMESTAMP | Model creation timestamp |
|
| created_at | TIMESTAMP | DEFAULT CURRENT_TIMESTAMP | Model creation timestamp |
|
||||||
| training_params | JSON | | Training hyperparameters |
|
| training_params | JSON | | Training hyperparameters |
|
||||||
| metrics | JSON | | Validation metrics (mAP, precision, recall) |
|
| metrics | JSON | | Validation metrics (mAP, precision, recall) |
|
||||||
@@ -159,7 +161,7 @@ Stores metadata about microscopy images.
|
|||||||
| checksum | TEXT | | MD5 hash for integrity verification |
|
| checksum | TEXT | | MD5 hash for integrity verification |
|
||||||
|
|
||||||
#### **detections** table
|
#### **detections** table
|
||||||
Stores object detection results.
|
Stores object detection results with optional segmentation masks.
|
||||||
|
|
||||||
| Column | Type | Constraints | Description |
|
| Column | Type | Constraints | Description |
|
||||||
|--------|------|-------------|-------------|
|
|--------|------|-------------|-------------|
|
||||||
@@ -172,11 +174,12 @@ Stores object detection results.
|
|||||||
| x_max | REAL | NOT NULL | Bounding box right coordinate (normalized 0-1) |
|
| x_max | REAL | NOT NULL | Bounding box right coordinate (normalized 0-1) |
|
||||||
| y_max | REAL | NOT NULL | Bounding box bottom coordinate (normalized 0-1) |
|
| y_max | REAL | NOT NULL | Bounding box bottom coordinate (normalized 0-1) |
|
||||||
| confidence | REAL | NOT NULL | Detection confidence score (0-1) |
|
| confidence | REAL | NOT NULL | Detection confidence score (0-1) |
|
||||||
|
| segmentation_mask | TEXT | | JSON array of polygon coordinates [[x1,y1], [x2,y2], ...] (normalized 0-1) |
|
||||||
| detected_at | TIMESTAMP | DEFAULT CURRENT_TIMESTAMP | When detection was performed |
|
| detected_at | TIMESTAMP | DEFAULT CURRENT_TIMESTAMP | When detection was performed |
|
||||||
| metadata | JSON | | Additional metadata (processing time, etc.) |
|
| metadata | JSON | | Additional metadata (processing time, etc.) |
|
||||||
|
|
||||||
#### **annotations** table
|
#### **annotations** table
|
||||||
Stores manual annotations for training data (future feature).
|
Stores manual annotations for training data with optional segmentation masks (future feature).
|
||||||
|
|
||||||
| Column | Type | Constraints | Description |
|
| Column | Type | Constraints | Description |
|
||||||
|--------|------|-------------|-------------|
|
|--------|------|-------------|-------------|
|
||||||
@@ -187,6 +190,7 @@ Stores manual annotations for training data (future feature).
|
|||||||
| y_min | REAL | NOT NULL | Bounding box top coordinate (normalized) |
|
| y_min | REAL | NOT NULL | Bounding box top coordinate (normalized) |
|
||||||
| x_max | REAL | NOT NULL | Bounding box right coordinate (normalized) |
|
| x_max | REAL | NOT NULL | Bounding box right coordinate (normalized) |
|
||||||
| y_max | REAL | NOT NULL | Bounding box bottom coordinate (normalized) |
|
| y_max | REAL | NOT NULL | Bounding box bottom coordinate (normalized) |
|
||||||
|
| segmentation_mask | TEXT | | JSON array of polygon coordinates [[x1,y1], [x2,y2], ...] (normalized 0-1) |
|
||||||
| annotator | TEXT | | Name of person who created annotation |
|
| annotator | TEXT | | Name of person who created annotation |
|
||||||
| created_at | TIMESTAMP | DEFAULT CURRENT_TIMESTAMP | Annotation timestamp |
|
| created_at | TIMESTAMP | DEFAULT CURRENT_TIMESTAMP | Annotation timestamp |
|
||||||
| verified | BOOLEAN | DEFAULT 0 | Whether annotation is verified |
|
| verified | BOOLEAN | DEFAULT 0 | Whether annotation is verified |
|
||||||
@@ -245,8 +249,9 @@ graph TB
|
|||||||
### Key Components
|
### Key Components
|
||||||
|
|
||||||
#### 1. **YOLO Wrapper** ([`src/model/yolo_wrapper.py`](src/model/yolo_wrapper.py))
|
#### 1. **YOLO Wrapper** ([`src/model/yolo_wrapper.py`](src/model/yolo_wrapper.py))
|
||||||
Encapsulates YOLOv8 operations:
|
Encapsulates YOLOv8-seg operations:
|
||||||
- Load pre-trained YOLOv8s model
|
- Load pre-trained YOLOv8s-seg segmentation model
|
||||||
|
- Extract pixel-accurate segmentation masks
|
||||||
- Fine-tune on custom microscopy dataset
|
- Fine-tune on custom microscopy dataset
|
||||||
- Export trained models
|
- Export trained models
|
||||||
- Provide training progress callbacks
|
- Provide training progress callbacks
|
||||||
@@ -255,10 +260,10 @@ Encapsulates YOLOv8 operations:
|
|||||||
**Key Methods:**
|
**Key Methods:**
|
||||||
```python
|
```python
|
||||||
class YOLOWrapper:
|
class YOLOWrapper:
|
||||||
def __init__(self, model_path: str = "yolov8s.pt")
|
def __init__(self, model_path: str = "yolov8s-seg.pt")
|
||||||
def train(self, data_yaml: str, epochs: int, callbacks: dict)
|
def train(self, data_yaml: str, epochs: int, callbacks: dict)
|
||||||
def validate(self, data_yaml: str) -> dict
|
def validate(self, data_yaml: str) -> dict
|
||||||
def predict(self, image_path: str, conf: float) -> list
|
def predict(self, image_path: str, conf: float) -> list # Returns detections with segmentation masks
|
||||||
def export_model(self, format: str, output_path: str)
|
def export_model(self, format: str, output_path: str)
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -435,7 +440,7 @@ image_repository:
|
|||||||
allowed_extensions: [".jpg", ".jpeg", ".png", ".tif", ".tiff"]
|
allowed_extensions: [".jpg", ".jpeg", ".png", ".tif", ".tiff"]
|
||||||
|
|
||||||
models:
|
models:
|
||||||
default_base_model: "yolov8s.pt"
|
default_base_model: "yolov8s-seg.pt"
|
||||||
models_directory: "data/models"
|
models_directory: "data/models"
|
||||||
|
|
||||||
training:
|
training:
|
||||||
|
|||||||
178
BUILD.md
Normal file
178
BUILD.md
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
# Building and Publishing Guide
|
||||||
|
|
||||||
|
This guide explains how to build and publish the microscopy-object-detection package.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install build twine
|
||||||
|
```
|
||||||
|
|
||||||
|
## Building the Package
|
||||||
|
|
||||||
|
### 1. Clean Previous Builds
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rm -rf build/ dist/ *.egg-info
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Build Distribution Archives
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m build
|
||||||
|
```
|
||||||
|
|
||||||
|
This will create both wheel (`.whl`) and source distribution (`.tar.gz`) in the `dist/` directory.
|
||||||
|
|
||||||
|
### 3. Verify the Build
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ls dist/
|
||||||
|
# Should show:
|
||||||
|
# microscopy_object_detection-1.0.0-py3-none-any.whl
|
||||||
|
# microscopy_object_detection-1.0.0.tar.gz
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing the Package Locally
|
||||||
|
|
||||||
|
### Install in Development Mode
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
### Install from Built Package
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install dist/microscopy_object_detection-1.0.0-py3-none-any.whl
|
||||||
|
```
|
||||||
|
|
||||||
|
### Test the Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Test CLI
|
||||||
|
microscopy-detect --version
|
||||||
|
|
||||||
|
# Test GUI launcher
|
||||||
|
microscopy-detect-gui
|
||||||
|
```
|
||||||
|
|
||||||
|
## Publishing to PyPI
|
||||||
|
|
||||||
|
### 1. Configure PyPI Credentials
|
||||||
|
|
||||||
|
Create or update `~/.pypirc`:
|
||||||
|
|
||||||
|
```ini
|
||||||
|
[pypi]
|
||||||
|
username = __token__
|
||||||
|
password = pypi-YOUR-API-TOKEN-HERE
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Upload to Test PyPI (Recommended First)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m twine upload --repository testpypi dist/*
|
||||||
|
```
|
||||||
|
|
||||||
|
Then test installation:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install --index-url https://test.pypi.org/simple/ microscopy-object-detection
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Upload to PyPI
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m twine upload dist/*
|
||||||
|
```
|
||||||
|
|
||||||
|
## Version Management
|
||||||
|
|
||||||
|
Update version in multiple files:
|
||||||
|
- `setup.py`: Update `version` parameter
|
||||||
|
- `pyproject.toml`: Update `version` field
|
||||||
|
- `src/__init__.py`: Update `__version__` variable
|
||||||
|
|
||||||
|
## Git Tags
|
||||||
|
|
||||||
|
After publishing, tag the release:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git tag -a v1.0.0 -m "Release version 1.0.0"
|
||||||
|
git push origin v1.0.0
|
||||||
|
```
|
||||||
|
|
||||||
|
## Package Structure
|
||||||
|
|
||||||
|
The built package includes:
|
||||||
|
- All Python source files in `src/`
|
||||||
|
- Configuration files in `config/`
|
||||||
|
- Database schema file (`src/database/schema.sql`)
|
||||||
|
- Documentation files (README.md, LICENSE, etc.)
|
||||||
|
- Entry points for CLI and GUI
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Import Errors
|
||||||
|
If you get import errors, ensure:
|
||||||
|
- All `__init__.py` files are present
|
||||||
|
- Package structure follows the setup configuration
|
||||||
|
- Dependencies are listed in `requirements.txt`
|
||||||
|
|
||||||
|
### Missing Files
|
||||||
|
If files are missing in the built package:
|
||||||
|
- Check `MANIFEST.in` includes the required patterns
|
||||||
|
- Check `pyproject.toml` package-data configuration
|
||||||
|
- Rebuild with `python -m build --no-isolation` for debugging
|
||||||
|
|
||||||
|
### Version Conflicts
|
||||||
|
If version conflicts occur:
|
||||||
|
- Ensure version is consistent across all files
|
||||||
|
- Clear build artifacts and rebuild
|
||||||
|
- Check for cached installations: `pip list | grep microscopy`
|
||||||
|
|
||||||
|
## CI/CD Integration
|
||||||
|
|
||||||
|
### GitHub Actions Example
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
name: Build and Publish
|
||||||
|
|
||||||
|
on:
|
||||||
|
release:
|
||||||
|
types: [created]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
deploy:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: '3.8'
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
pip install build twine
|
||||||
|
- name: Build package
|
||||||
|
run: python -m build
|
||||||
|
- name: Publish to PyPI
|
||||||
|
env:
|
||||||
|
TWINE_USERNAME: __token__
|
||||||
|
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
|
||||||
|
run: twine upload dist/*
|
||||||
|
```
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Version Bumping**: Use semantic versioning (MAJOR.MINOR.PATCH)
|
||||||
|
2. **Testing**: Always test on Test PyPI before publishing to PyPI
|
||||||
|
3. **Documentation**: Update README.md and CHANGELOG.md for each release
|
||||||
|
4. **Git Tags**: Tag releases in git for easy reference
|
||||||
|
5. **Dependencies**: Keep requirements.txt updated and specify version ranges
|
||||||
|
|
||||||
|
## Resources
|
||||||
|
|
||||||
|
- [Python Packaging Guide](https://packaging.python.org/)
|
||||||
|
- [setuptools Documentation](https://setuptools.pypa.io/)
|
||||||
|
- [PyPI Publishing Guide](https://packaging.python.org/tutorials/packaging-projects/)
|
||||||
236
INSTALL_TEST.md
Normal file
236
INSTALL_TEST.md
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
# Installation Testing Guide
|
||||||
|
|
||||||
|
This guide helps you verify that the package installation works correctly.
|
||||||
|
|
||||||
|
## Clean Installation Test
|
||||||
|
|
||||||
|
### 1. Remove Any Previous Installations
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Deactivate any active virtual environment
|
||||||
|
deactivate
|
||||||
|
|
||||||
|
# Remove old virtual environment (if exists)
|
||||||
|
rm -rf venv
|
||||||
|
|
||||||
|
# Create fresh virtual environment
|
||||||
|
python3 -m venv venv
|
||||||
|
source venv/bin/activate # On Linux/Mac
|
||||||
|
# or
|
||||||
|
venv\Scripts\activate # On Windows
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Install the Package
|
||||||
|
|
||||||
|
#### Option A: Editable/Development Install
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
This allows you to modify source code and see changes immediately.
|
||||||
|
|
||||||
|
#### Option B: Regular Install
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install .
|
||||||
|
```
|
||||||
|
|
||||||
|
This installs the package as if it were from PyPI.
|
||||||
|
|
||||||
|
### 3. Verify Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Check package is installed
|
||||||
|
pip list | grep microscopy
|
||||||
|
|
||||||
|
# Check version
|
||||||
|
microscopy-detect --version
|
||||||
|
# Expected output: microscopy-object-detection 1.0.0
|
||||||
|
|
||||||
|
# Test Python import
|
||||||
|
python -c "import src; print(src.__version__)"
|
||||||
|
# Expected output: 1.0.0
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Test Entry Points
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Test CLI
|
||||||
|
microscopy-detect --help
|
||||||
|
|
||||||
|
# Test GUI launcher (will open window)
|
||||||
|
microscopy-detect-gui
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Verify Package Contents
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Run this in Python shell
|
||||||
|
import src
|
||||||
|
import src.database
|
||||||
|
import src.model
|
||||||
|
import src.gui
|
||||||
|
|
||||||
|
# Check schema file is included
|
||||||
|
from pathlib import Path
|
||||||
|
import src.database
|
||||||
|
db_path = Path(src.database.__file__).parent
|
||||||
|
schema_file = db_path / 'schema.sql'
|
||||||
|
print(f"Schema file exists: {schema_file.exists()}")
|
||||||
|
# Expected: Schema file exists: True
|
||||||
|
```
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Issue: ModuleNotFoundError
|
||||||
|
|
||||||
|
**Error:**
|
||||||
|
```
|
||||||
|
ModuleNotFoundError: No module named 'src'
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solution:**
|
||||||
|
```bash
|
||||||
|
# Reinstall with verbose output
|
||||||
|
pip install -e . -v
|
||||||
|
|
||||||
|
# Or try regular install
|
||||||
|
pip install . --force-reinstall
|
||||||
|
```
|
||||||
|
|
||||||
|
### Issue: Entry Points Not Working
|
||||||
|
|
||||||
|
**Error:**
|
||||||
|
```
|
||||||
|
microscopy-detect: command not found
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solution:**
|
||||||
|
```bash
|
||||||
|
# Check if scripts are in PATH
|
||||||
|
which microscopy-detect
|
||||||
|
|
||||||
|
# If not found, check pip install location
|
||||||
|
pip show microscopy-object-detection
|
||||||
|
|
||||||
|
# You might need to add to PATH or use full path
|
||||||
|
~/.local/bin/microscopy-detect # Linux
|
||||||
|
```
|
||||||
|
|
||||||
|
### Issue: Import Errors for PySide6
|
||||||
|
|
||||||
|
**Error:**
|
||||||
|
```
|
||||||
|
ImportError: cannot import name 'QApplication' from 'PySide6.QtWidgets'
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solution:**
|
||||||
|
```bash
|
||||||
|
# Install Qt dependencies (Linux only)
|
||||||
|
sudo apt-get install libxcb-xinerama0
|
||||||
|
|
||||||
|
# Reinstall PySide6
|
||||||
|
pip uninstall PySide6
|
||||||
|
pip install PySide6
|
||||||
|
```
|
||||||
|
|
||||||
|
### Issue: Config Files Not Found
|
||||||
|
|
||||||
|
**Error:**
|
||||||
|
```
|
||||||
|
FileNotFoundError: config/app_config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solution:**
|
||||||
|
The config file should be created automatically. If not:
|
||||||
|
```bash
|
||||||
|
# Create config directory in your home
|
||||||
|
mkdir -p ~/.microscopy-detect
|
||||||
|
cp config/app_config.yaml ~/.microscopy-detect/
|
||||||
|
|
||||||
|
# Or run from source directory first time
|
||||||
|
cd /home/martin/code/object_detection
|
||||||
|
python main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Manual Testing Checklist
|
||||||
|
|
||||||
|
- [ ] Package installs without errors
|
||||||
|
- [ ] Version command works (`microscopy-detect --version`)
|
||||||
|
- [ ] Help command works (`microscopy-detect --help`)
|
||||||
|
- [ ] GUI launches (`microscopy-detect-gui`)
|
||||||
|
- [ ] Can import all modules in Python
|
||||||
|
- [ ] Database schema file is accessible
|
||||||
|
- [ ] Configuration loads correctly
|
||||||
|
|
||||||
|
## Build and Install from Wheel
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build the package
|
||||||
|
python -m build
|
||||||
|
|
||||||
|
# Install from wheel
|
||||||
|
pip install dist/microscopy_object_detection-1.0.0-py3-none-any.whl
|
||||||
|
|
||||||
|
# Test
|
||||||
|
microscopy-detect --version
|
||||||
|
```
|
||||||
|
|
||||||
|
## Uninstall
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip uninstall microscopy-object-detection
|
||||||
|
```
|
||||||
|
|
||||||
|
## Development Workflow
|
||||||
|
|
||||||
|
### After Code Changes
|
||||||
|
|
||||||
|
If installed with `-e` (editable mode):
|
||||||
|
- Python code changes are immediately available
|
||||||
|
- No need to reinstall
|
||||||
|
|
||||||
|
If installed with regular `pip install .`:
|
||||||
|
- Reinstall after changes: `pip install . --force-reinstall`
|
||||||
|
|
||||||
|
### After Adding New Files
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Reinstall to include new files
|
||||||
|
pip install -e . --force-reinstall
|
||||||
|
```
|
||||||
|
|
||||||
|
## Expected Installation Output
|
||||||
|
|
||||||
|
```
|
||||||
|
Processing /home/martin/code/object_detection
|
||||||
|
Installing build dependencies ... done
|
||||||
|
Getting requirements to build wheel ... done
|
||||||
|
Preparing metadata (pyproject.toml) ... done
|
||||||
|
Building wheels for collected packages: microscopy-object-detection
|
||||||
|
Building wheel for microscopy-object-detection (pyproject.toml) ... done
|
||||||
|
Successfully built microscopy-object-detection
|
||||||
|
Installing collected packages: microscopy-object-detection
|
||||||
|
Successfully installed microscopy-object-detection-1.0.0
|
||||||
|
```
|
||||||
|
|
||||||
|
## Success Criteria
|
||||||
|
|
||||||
|
Installation is successful when:
|
||||||
|
1. ✅ No error messages during installation
|
||||||
|
2. ✅ `pip list` shows the package
|
||||||
|
3. ✅ `microscopy-detect --version` returns correct version
|
||||||
|
4. ✅ GUI launches without errors
|
||||||
|
5. ✅ All Python modules can be imported
|
||||||
|
6. ✅ Database operations work
|
||||||
|
7. ✅ Detection functionality works
|
||||||
|
|
||||||
|
## Next Steps After Successful Install
|
||||||
|
|
||||||
|
1. Configure image repository path
|
||||||
|
2. Run first detection
|
||||||
|
3. Train a custom model
|
||||||
|
4. Export results
|
||||||
|
|
||||||
|
For usage instructions, see [QUICKSTART.md](QUICKSTART.md)
|
||||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2024 Your Name
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
37
MANIFEST.in
Normal file
37
MANIFEST.in
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
# Include documentation files
|
||||||
|
include README.md
|
||||||
|
include LICENSE
|
||||||
|
include ARCHITECTURE.md
|
||||||
|
include IMPLEMENTATION_GUIDE.md
|
||||||
|
include QUICKSTART.md
|
||||||
|
include PLAN_SUMMARY.md
|
||||||
|
|
||||||
|
# Include requirements
|
||||||
|
include requirements.txt
|
||||||
|
|
||||||
|
# Include configuration files
|
||||||
|
recursive-include config *.yaml
|
||||||
|
recursive-include config *.yml
|
||||||
|
|
||||||
|
# Include database schema
|
||||||
|
recursive-include src/database *.sql
|
||||||
|
|
||||||
|
# Include tests
|
||||||
|
recursive-include tests *.py
|
||||||
|
|
||||||
|
# Exclude compiled Python files
|
||||||
|
global-exclude *.pyc
|
||||||
|
global-exclude *.pyo
|
||||||
|
global-exclude __pycache__
|
||||||
|
global-exclude *.so
|
||||||
|
global-exclude .DS_Store
|
||||||
|
|
||||||
|
# Exclude git and IDE files
|
||||||
|
global-exclude .git*
|
||||||
|
global-exclude .vscode
|
||||||
|
global-exclude .idea
|
||||||
|
|
||||||
|
# Exclude build artifacts
|
||||||
|
prune build
|
||||||
|
prune dist
|
||||||
|
prune *.egg-info
|
||||||
@@ -38,7 +38,7 @@ This will install:
|
|||||||
- OpenCV and Pillow (image processing)
|
- OpenCV and Pillow (image processing)
|
||||||
- And other dependencies
|
- And other dependencies
|
||||||
|
|
||||||
**Note:** The first run will automatically download the YOLOv8s.pt model (~22MB).
|
**Note:** The first run will automatically download the YOLOv8s-seg.pt segmentation model (~23MB).
|
||||||
|
|
||||||
### 4. Verify Installation
|
### 4. Verify Installation
|
||||||
|
|
||||||
@@ -84,11 +84,11 @@ In the Settings dialog:
|
|||||||
### Single Image Detection
|
### Single Image Detection
|
||||||
|
|
||||||
1. Go to the **Detection** tab
|
1. Go to the **Detection** tab
|
||||||
2. Select a model from the dropdown (default: Base Model yolov8s.pt)
|
2. Select a model from the dropdown (default: Base Model yolov8s-seg.pt)
|
||||||
3. Adjust confidence threshold with the slider
|
3. Adjust confidence threshold with the slider
|
||||||
4. Click "Detect Single Image"
|
4. Click "Detect Single Image"
|
||||||
5. Select an image file
|
5. Select an image file
|
||||||
6. View results in the results panel
|
6. View results with segmentation masks overlaid on the image
|
||||||
|
|
||||||
### Batch Detection
|
### Batch Detection
|
||||||
|
|
||||||
@@ -108,9 +108,18 @@ Detection results include:
|
|||||||
- **Class names**: Types of objects detected (e.g., organelle, membrane_branch)
|
- **Class names**: Types of objects detected (e.g., organelle, membrane_branch)
|
||||||
- **Confidence scores**: Detection confidence (0-1)
|
- **Confidence scores**: Detection confidence (0-1)
|
||||||
- **Bounding boxes**: Object locations (stored in database)
|
- **Bounding boxes**: Object locations (stored in database)
|
||||||
|
- **Segmentation masks**: Pixel-accurate polygon coordinates for each detected object
|
||||||
|
|
||||||
All results are stored in the SQLite database at [`data/detections.db`](data/detections.db).
|
All results are stored in the SQLite database at [`data/detections.db`](data/detections.db).
|
||||||
|
|
||||||
|
### Segmentation Visualization
|
||||||
|
|
||||||
|
The application automatically displays segmentation masks when available:
|
||||||
|
- Semi-transparent colored overlay (30% opacity) showing the exact shape of detected objects
|
||||||
|
- Polygon contours outlining each segmentation
|
||||||
|
- Color-coded by object class
|
||||||
|
- Toggle-able in future versions
|
||||||
|
|
||||||
## Database
|
## Database
|
||||||
|
|
||||||
The application uses SQLite to store:
|
The application uses SQLite to store:
|
||||||
@@ -176,7 +185,7 @@ sudo apt-get install libxcb-xinerama0
|
|||||||
### Detection Not Working
|
### Detection Not Working
|
||||||
|
|
||||||
**No models available**
|
**No models available**
|
||||||
- The base YOLOv8s model will be downloaded automatically on first use
|
- The base YOLOv8s-seg segmentation model will be downloaded automatically on first use
|
||||||
- Make sure you have internet connection for the first run
|
- Make sure you have internet connection for the first run
|
||||||
|
|
||||||
**Images not found**
|
**Images not found**
|
||||||
|
|||||||
62
README.md
62
README.md
@@ -1,6 +1,6 @@
|
|||||||
# Microscopy Object Detection Application
|
# Microscopy Object Detection Application
|
||||||
|
|
||||||
A desktop application for detecting organelles and membrane branching structures in microscopy images using YOLOv8, featuring comprehensive training, validation, and visualization capabilities.
|
A desktop application for detecting and segmenting organelles and membrane branching structures in microscopy images using YOLOv8-seg, featuring comprehensive training, validation, and visualization capabilities with pixel-accurate segmentation masks.
|
||||||
|
|
||||||

|

|
||||||

|

|
||||||
@@ -8,8 +8,8 @@ A desktop application for detecting organelles and membrane branching structures
|
|||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- **🎯 Object Detection**: Real-time and batch detection of microscopy objects
|
- **🎯 Object Detection & Segmentation**: Real-time and batch detection with pixel-accurate segmentation masks
|
||||||
- **🎓 Model Training**: Fine-tune YOLOv8s on custom microscopy datasets
|
- **🎓 Model Training**: Fine-tune YOLOv8s-seg on custom microscopy datasets
|
||||||
- **📊 Validation & Metrics**: Comprehensive model validation with visualization
|
- **📊 Validation & Metrics**: Comprehensive model validation with visualization
|
||||||
- **💾 Database Storage**: SQLite database for detection results and metadata
|
- **💾 Database Storage**: SQLite database for detection results and metadata
|
||||||
- **📈 Visualization**: Interactive plots and charts using pyqtgraph
|
- **📈 Visualization**: Interactive plots and charts using pyqtgraph
|
||||||
@@ -34,14 +34,24 @@ A desktop application for detecting organelles and membrane branching structures
|
|||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
### 1. Clone the Repository
|
### Option 1: Install from PyPI (Recommended)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install microscopy-object-detection
|
||||||
|
```
|
||||||
|
|
||||||
|
This will install the package and all its dependencies.
|
||||||
|
|
||||||
|
### Option 2: Install from Source
|
||||||
|
|
||||||
|
#### 1. Clone the Repository
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone <repository-url>
|
git clone <repository-url>
|
||||||
cd object_detection
|
cd object_detection
|
||||||
```
|
```
|
||||||
|
|
||||||
### 2. Create Virtual Environment
|
#### 2. Create Virtual Environment
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Linux/Mac
|
# Linux/Mac
|
||||||
@@ -53,25 +63,44 @@ python -m venv venv
|
|||||||
venv\Scripts\activate
|
venv\Scripts\activate
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3. Install Dependencies
|
#### 3. Install in Development Mode
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install -r requirements.txt
|
# Install in editable mode with dev dependencies
|
||||||
|
pip install -e ".[dev]"
|
||||||
|
|
||||||
|
# Or install just the package
|
||||||
|
pip install .
|
||||||
```
|
```
|
||||||
|
|
||||||
### 4. Download Base Model
|
### 4. Download Base Model
|
||||||
|
|
||||||
The application will automatically download the YOLOv8s.pt model on first use, or you can download it manually:
|
The application will automatically download the YOLOv8s-seg.pt segmentation model on first use, or you can download it manually:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# The model will be downloaded automatically by ultralytics
|
# The model will be downloaded automatically by ultralytics
|
||||||
# Or download manually from: https://github.com/ultralytics/assets/releases
|
# Or download manually from: https://github.com/ultralytics/assets/releases
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Note:** YOLOv8s-seg is a segmentation model that provides pixel-accurate masks for detected objects, enabling more precise analysis than standard bounding box detection.
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
### 1. Launch the Application
|
### 1. Launch the Application
|
||||||
|
|
||||||
|
After installation, you can launch the application in two ways:
|
||||||
|
|
||||||
|
**Using the GUI launcher:**
|
||||||
|
```bash
|
||||||
|
microscopy-detect-gui
|
||||||
|
```
|
||||||
|
|
||||||
|
**Or using Python directly:**
|
||||||
|
```bash
|
||||||
|
python -m microscopy_object_detection
|
||||||
|
```
|
||||||
|
|
||||||
|
**If installed from source:**
|
||||||
```bash
|
```bash
|
||||||
python main.py
|
python main.py
|
||||||
```
|
```
|
||||||
@@ -85,11 +114,12 @@ python main.py
|
|||||||
### 3. Perform Detection
|
### 3. Perform Detection
|
||||||
|
|
||||||
1. Navigate to the **Detection** tab
|
1. Navigate to the **Detection** tab
|
||||||
2. Select a model (default: yolov8s.pt)
|
2. Select a model (default: yolov8s-seg.pt)
|
||||||
3. Choose an image or folder
|
3. Choose an image or folder
|
||||||
4. Set confidence threshold
|
4. Set confidence threshold
|
||||||
5. Click **Detect**
|
5. Click **Detect**
|
||||||
6. View results and save to database
|
6. View results with segmentation masks overlaid
|
||||||
|
7. Save results to database
|
||||||
|
|
||||||
### 4. Train Custom Model
|
### 4. Train Custom Model
|
||||||
|
|
||||||
@@ -212,8 +242,8 @@ The application uses SQLite with the following main tables:
|
|||||||
|
|
||||||
- **models**: Stores trained model information and metrics
|
- **models**: Stores trained model information and metrics
|
||||||
- **images**: Stores image metadata and paths
|
- **images**: Stores image metadata and paths
|
||||||
- **detections**: Stores detection results with bounding boxes
|
- **detections**: Stores detection results with bounding boxes and segmentation masks (polygon coordinates)
|
||||||
- **annotations**: Stores manual annotations (future feature)
|
- **annotations**: Stores manual annotations with optional segmentation masks (future feature)
|
||||||
|
|
||||||
See [`ARCHITECTURE.md`](ARCHITECTURE.md) for detailed schema information.
|
See [`ARCHITECTURE.md`](ARCHITECTURE.md) for detailed schema information.
|
||||||
|
|
||||||
@@ -230,7 +260,7 @@ image_repository:
|
|||||||
allowed_extensions: [".jpg", ".jpeg", ".png", ".tif", ".tiff"]
|
allowed_extensions: [".jpg", ".jpeg", ".png", ".tif", ".tiff"]
|
||||||
|
|
||||||
models:
|
models:
|
||||||
default_base_model: "yolov8s.pt"
|
default_base_model: "yolov8s-seg.pt"
|
||||||
models_directory: "data/models"
|
models_directory: "data/models"
|
||||||
|
|
||||||
training:
|
training:
|
||||||
@@ -258,7 +288,7 @@ visualization:
|
|||||||
from src.model.yolo_wrapper import YOLOWrapper
|
from src.model.yolo_wrapper import YOLOWrapper
|
||||||
|
|
||||||
# Initialize wrapper
|
# Initialize wrapper
|
||||||
yolo = YOLOWrapper("yolov8s.pt")
|
yolo = YOLOWrapper("yolov8s-seg.pt")
|
||||||
|
|
||||||
# Train model
|
# Train model
|
||||||
results = yolo.train(
|
results = yolo.train(
|
||||||
@@ -393,10 +423,10 @@ make html
|
|||||||
|
|
||||||
**Issue**: Model not found error
|
**Issue**: Model not found error
|
||||||
|
|
||||||
**Solution**: Ensure YOLOv8s.pt is downloaded. Run:
|
**Solution**: Ensure YOLOv8s-seg.pt is downloaded. Run:
|
||||||
```python
|
```python
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
model = YOLO('yolov8s.pt') # Will auto-download
|
model = YOLO('yolov8s-seg.pt') # Will auto-download
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,48 +0,0 @@
|
|||||||
database:
|
|
||||||
path: "data/detections.db"
|
|
||||||
|
|
||||||
image_repository:
|
|
||||||
base_path: "" # Set by user through GUI
|
|
||||||
allowed_extensions:
|
|
||||||
- ".jpg"
|
|
||||||
- ".jpeg"
|
|
||||||
- ".png"
|
|
||||||
- ".tif"
|
|
||||||
- ".tiff"
|
|
||||||
- ".bmp"
|
|
||||||
|
|
||||||
models:
|
|
||||||
default_base_model: "yolov8s.pt"
|
|
||||||
models_directory: "data/models"
|
|
||||||
|
|
||||||
training:
|
|
||||||
default_epochs: 100
|
|
||||||
default_batch_size: 16
|
|
||||||
default_imgsz: 640
|
|
||||||
default_patience: 50
|
|
||||||
default_lr0: 0.01
|
|
||||||
|
|
||||||
detection:
|
|
||||||
default_confidence: 0.25
|
|
||||||
default_iou: 0.45
|
|
||||||
max_batch_size: 100
|
|
||||||
|
|
||||||
visualization:
|
|
||||||
bbox_colors:
|
|
||||||
organelle: "#FF6B6B"
|
|
||||||
membrane_branch: "#4ECDC4"
|
|
||||||
default: "#00FF00"
|
|
||||||
bbox_thickness: 2
|
|
||||||
font_size: 12
|
|
||||||
|
|
||||||
export:
|
|
||||||
formats:
|
|
||||||
- csv
|
|
||||||
- json
|
|
||||||
- excel
|
|
||||||
default_format: "csv"
|
|
||||||
|
|
||||||
logging:
|
|
||||||
level: "INFO"
|
|
||||||
file: "logs/app.log"
|
|
||||||
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
||||||
220
docs/IMAGE_CLASS_USAGE.md
Normal file
220
docs/IMAGE_CLASS_USAGE.md
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
# Image Class Usage Guide
|
||||||
|
|
||||||
|
The `Image` class provides a convenient way to load and work with images in the microscopy object detection application.
|
||||||
|
|
||||||
|
## Supported Formats
|
||||||
|
|
||||||
|
The Image class supports the following image formats:
|
||||||
|
- `.jpg`, `.jpeg` - JPEG images
|
||||||
|
- `.png` - PNG images
|
||||||
|
- `.tif`, `.tiff` - TIFF images (commonly used in microscopy)
|
||||||
|
- `.bmp` - Bitmap images
|
||||||
|
|
||||||
|
## Basic Usage
|
||||||
|
|
||||||
|
### Loading an Image
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.utils import Image, ImageLoadError
|
||||||
|
|
||||||
|
# Load an image from a file path
|
||||||
|
try:
|
||||||
|
img = Image("path/to/image.jpg")
|
||||||
|
print(f"Loaded image: {img.width}x{img.height} pixels")
|
||||||
|
except ImageLoadError as e:
|
||||||
|
print(f"Failed to load image: {e}")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Accessing Image Properties
|
||||||
|
|
||||||
|
```python
|
||||||
|
img = Image("microscopy_image.tif")
|
||||||
|
|
||||||
|
# Basic properties
|
||||||
|
print(f"Width: {img.width} pixels")
|
||||||
|
print(f"Height: {img.height} pixels")
|
||||||
|
print(f"Channels: {img.channels}")
|
||||||
|
print(f"Format: {img.format}")
|
||||||
|
print(f"Shape: {img.shape}") # (height, width, channels)
|
||||||
|
|
||||||
|
# File information
|
||||||
|
print(f"File size: {img.size_mb:.2f} MB")
|
||||||
|
print(f"File size: {img.size_bytes} bytes")
|
||||||
|
|
||||||
|
# Image type checks
|
||||||
|
print(f"Is color: {img.is_color()}")
|
||||||
|
print(f"Is grayscale: {img.is_grayscale()}")
|
||||||
|
|
||||||
|
# String representation
|
||||||
|
print(img) # Shows summary of image properties
|
||||||
|
```
|
||||||
|
|
||||||
|
### Working with Image Data
|
||||||
|
|
||||||
|
```python
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
img = Image("sample.png")
|
||||||
|
|
||||||
|
# Get image data as numpy array (OpenCV format, BGR)
|
||||||
|
bgr_data = img.data
|
||||||
|
print(f"Data shape: {bgr_data.shape}")
|
||||||
|
print(f"Data type: {bgr_data.dtype}")
|
||||||
|
|
||||||
|
# Get image as RGB (for display or processing)
|
||||||
|
rgb_data = img.get_rgb()
|
||||||
|
|
||||||
|
# Get grayscale version
|
||||||
|
gray_data = img.get_grayscale()
|
||||||
|
|
||||||
|
# Create a copy (for modifications)
|
||||||
|
img_copy = img.copy()
|
||||||
|
img_copy[0, 0] = [255, 255, 255] # Modify copy, original unchanged
|
||||||
|
|
||||||
|
# Resize image (returns new array, doesn't modify original)
|
||||||
|
resized = img.resize(640, 640)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using PIL Image
|
||||||
|
|
||||||
|
```python
|
||||||
|
img = Image("photo.jpg")
|
||||||
|
|
||||||
|
# Access as PIL Image (RGB format)
|
||||||
|
pil_img = img.pil_image
|
||||||
|
|
||||||
|
# Use PIL methods
|
||||||
|
pil_img.show() # Display image
|
||||||
|
pil_img.save("output.png") # Save with PIL
|
||||||
|
```
|
||||||
|
|
||||||
|
## Integration with YOLO
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.utils import Image
|
||||||
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
# Load model and image
|
||||||
|
model = YOLO("yolov8n.pt")
|
||||||
|
img = Image("microscopy/cell_01.tif")
|
||||||
|
|
||||||
|
# Run inference (YOLO accepts file paths or numpy arrays)
|
||||||
|
results = model(img.data)
|
||||||
|
|
||||||
|
# Or use the file path directly
|
||||||
|
results = model(str(img.path))
|
||||||
|
```
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.utils import Image, ImageLoadError
|
||||||
|
|
||||||
|
def process_image(image_path):
|
||||||
|
try:
|
||||||
|
img = Image(image_path)
|
||||||
|
# Process the image...
|
||||||
|
return img
|
||||||
|
except ImageLoadError as e:
|
||||||
|
print(f"Cannot load image: {e}")
|
||||||
|
return None
|
||||||
|
```
|
||||||
|
|
||||||
|
## Advanced Usage
|
||||||
|
|
||||||
|
### Batch Processing
|
||||||
|
|
||||||
|
```python
|
||||||
|
from pathlib import Path
|
||||||
|
from src.utils import Image, ImageLoadError
|
||||||
|
|
||||||
|
def process_image_directory(directory):
|
||||||
|
"""Process all images in a directory."""
|
||||||
|
image_paths = Path(directory).glob("*.tif")
|
||||||
|
|
||||||
|
for path in image_paths:
|
||||||
|
try:
|
||||||
|
img = Image(path)
|
||||||
|
print(f"Processing {img.path.name}: {img.width}x{img.height}")
|
||||||
|
# Process the image...
|
||||||
|
except ImageLoadError as e:
|
||||||
|
print(f"Skipping {path}: {e}")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using with OpenCV Operations
|
||||||
|
|
||||||
|
```python
|
||||||
|
import cv2
|
||||||
|
from src.utils import Image
|
||||||
|
|
||||||
|
img = Image("input.jpg")
|
||||||
|
|
||||||
|
# Apply OpenCV operations on the data
|
||||||
|
blurred = cv2.GaussianBlur(img.data, (5, 5), 0)
|
||||||
|
edges = cv2.Canny(img.data, 100, 200)
|
||||||
|
|
||||||
|
# Note: These operations don't modify the original img.data
|
||||||
|
```
|
||||||
|
|
||||||
|
### Memory Efficient Processing
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.utils import Image
|
||||||
|
|
||||||
|
# The Image class loads data into memory
|
||||||
|
img = Image("large_image.tif")
|
||||||
|
print(f"Image size in memory: {img.data.nbytes / (1024**2):.2f} MB")
|
||||||
|
|
||||||
|
# When processing many images, consider loading one at a time
|
||||||
|
# and releasing memory by deleting the object
|
||||||
|
del img
|
||||||
|
```
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Always use try-except** when loading images to handle errors gracefully
|
||||||
|
2. **Check image properties** before processing to ensure compatibility
|
||||||
|
3. **Use copy()** when you need to modify image data without affecting the original
|
||||||
|
4. **Path objects work too** - The class accepts both strings and Path objects
|
||||||
|
5. **Consider memory usage** when working with large images or batches
|
||||||
|
|
||||||
|
## Example: Complete Workflow
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.utils import Image, ImageLoadError
|
||||||
|
from src.utils.file_utils import get_image_files
|
||||||
|
|
||||||
|
def analyze_microscopy_images(directory):
|
||||||
|
"""Analyze all microscopy images in a directory."""
|
||||||
|
|
||||||
|
# Get all image files
|
||||||
|
image_files = get_image_files(directory, recursive=True)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for image_path in image_files:
|
||||||
|
try:
|
||||||
|
# Load image
|
||||||
|
img = Image(image_path)
|
||||||
|
|
||||||
|
# Analyze
|
||||||
|
result = {
|
||||||
|
'filename': img.path.name,
|
||||||
|
'width': img.width,
|
||||||
|
'height': img.height,
|
||||||
|
'channels': img.channels,
|
||||||
|
'format': img.format,
|
||||||
|
'size_mb': img.size_mb,
|
||||||
|
'is_color': img.is_color()
|
||||||
|
}
|
||||||
|
|
||||||
|
results.append(result)
|
||||||
|
print(f"✓ Analyzed {img.path.name}")
|
||||||
|
|
||||||
|
except ImageLoadError as e:
|
||||||
|
print(f"✗ Failed to load {image_path}: {e}")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
# Run analysis
|
||||||
|
results = analyze_microscopy_images("data/datasets/cells")
|
||||||
|
print(f"\nProcessed {len(results)} images")
|
||||||
151
examples/image_demo.py
Normal file
151
examples/image_demo.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
"""
|
||||||
|
Example script demonstrating the Image class functionality.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
from src.utils import Image, ImageLoadError
|
||||||
|
|
||||||
|
|
||||||
|
def demonstrate_image_loading():
|
||||||
|
"""Demonstrate basic image loading functionality."""
|
||||||
|
print("=" * 60)
|
||||||
|
print("Image Class Demonstration")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Example 1: Try to load an image (replace with your own path)
|
||||||
|
example_paths = [
|
||||||
|
"data/datasets/example.jpg",
|
||||||
|
"data/datasets/sample.png",
|
||||||
|
"tests/test_image.jpg",
|
||||||
|
]
|
||||||
|
|
||||||
|
loaded_img = None
|
||||||
|
for image_path in example_paths:
|
||||||
|
if Path(image_path).exists():
|
||||||
|
try:
|
||||||
|
print(f"\n1. Loading image: {image_path}")
|
||||||
|
img = Image(image_path)
|
||||||
|
loaded_img = img
|
||||||
|
print(f" ✓ Successfully loaded!")
|
||||||
|
print(f" {img}")
|
||||||
|
break
|
||||||
|
except ImageLoadError as e:
|
||||||
|
print(f" ✗ Failed: {e}")
|
||||||
|
else:
|
||||||
|
print(f"\n1. Image not found: {image_path}")
|
||||||
|
|
||||||
|
if loaded_img is None:
|
||||||
|
print("\nNo example images found. Creating a test image...")
|
||||||
|
create_test_image()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Example 2: Access image properties
|
||||||
|
print(f"\n2. Image Properties:")
|
||||||
|
print(f" Width: {loaded_img.width} pixels")
|
||||||
|
print(f" Height: {loaded_img.height} pixels")
|
||||||
|
print(f" Channels: {loaded_img.channels}")
|
||||||
|
print(f" Format: {loaded_img.format.upper()}")
|
||||||
|
print(f" Shape: {loaded_img.shape}")
|
||||||
|
print(f" File size: {loaded_img.size_mb:.2f} MB")
|
||||||
|
print(f" Is color: {loaded_img.is_color()}")
|
||||||
|
print(f" Is grayscale: {loaded_img.is_grayscale()}")
|
||||||
|
|
||||||
|
# Example 3: Get different formats
|
||||||
|
print(f"\n3. Accessing Image Data:")
|
||||||
|
print(f" BGR data shape: {loaded_img.data.shape}")
|
||||||
|
print(f" RGB data shape: {loaded_img.get_rgb().shape}")
|
||||||
|
print(f" Grayscale shape: {loaded_img.get_grayscale().shape}")
|
||||||
|
print(f" PIL image mode: {loaded_img.pil_image.mode}")
|
||||||
|
|
||||||
|
# Example 4: Resizing
|
||||||
|
print(f"\n4. Resizing Image:")
|
||||||
|
resized = loaded_img.resize(320, 320)
|
||||||
|
print(f" Original size: {loaded_img.width}x{loaded_img.height}")
|
||||||
|
print(f" Resized to: {resized.shape[1]}x{resized.shape[0]}")
|
||||||
|
|
||||||
|
# Example 5: Working with copies
|
||||||
|
print(f"\n5. Creating Copies:")
|
||||||
|
copy = loaded_img.copy()
|
||||||
|
print(f" Created copy with shape: {copy.shape}")
|
||||||
|
print(f" Original data unchanged: {(loaded_img.data == copy).all()}")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Demonstration Complete!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_image():
|
||||||
|
"""Create a test image for demonstration purposes."""
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
print("\nCreating a test image...")
|
||||||
|
|
||||||
|
# Create a colorful test image
|
||||||
|
width, height = 400, 300
|
||||||
|
test_img = np.zeros((height, width, 3), dtype=np.uint8)
|
||||||
|
|
||||||
|
# Add some colors
|
||||||
|
test_img[:100, :] = [255, 0, 0] # Blue section
|
||||||
|
test_img[100:200, :] = [0, 255, 0] # Green section
|
||||||
|
test_img[200:, :] = [0, 0, 255] # Red section
|
||||||
|
|
||||||
|
# Save the test image
|
||||||
|
test_path = Path("test_demo_image.png")
|
||||||
|
cv2.imwrite(str(test_path), test_img)
|
||||||
|
print(f"Test image created: {test_path}")
|
||||||
|
|
||||||
|
# Now load and demonstrate with it
|
||||||
|
try:
|
||||||
|
img = Image(test_path)
|
||||||
|
print(f"\nLoaded test image: {img}")
|
||||||
|
print(f"Dimensions: {img.width}x{img.height}")
|
||||||
|
print(f"Channels: {img.channels}")
|
||||||
|
print(f"Format: {img.format}")
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
test_path.unlink()
|
||||||
|
print(f"\nTest image cleaned up.")
|
||||||
|
|
||||||
|
except ImageLoadError as e:
|
||||||
|
print(f"Error loading test image: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def demonstrate_error_handling():
|
||||||
|
"""Demonstrate error handling."""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Error Handling Demonstration")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Try to load non-existent file
|
||||||
|
print("\n1. Loading non-existent file:")
|
||||||
|
try:
|
||||||
|
img = Image("nonexistent.jpg")
|
||||||
|
except ImageLoadError as e:
|
||||||
|
print(f" ✓ Caught error: {e}")
|
||||||
|
|
||||||
|
# Try unsupported format
|
||||||
|
print("\n2. Loading unsupported format:")
|
||||||
|
try:
|
||||||
|
# Create a text file
|
||||||
|
test_file = Path("test.txt")
|
||||||
|
test_file.write_text("not an image")
|
||||||
|
img = Image(test_file)
|
||||||
|
except ImageLoadError as e:
|
||||||
|
print(f" ✓ Caught error: {e}")
|
||||||
|
test_file.unlink() # Clean up
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("\n")
|
||||||
|
demonstrate_image_loading()
|
||||||
|
print("\n")
|
||||||
|
demonstrate_error_handling()
|
||||||
|
print("\n")
|
||||||
5
main.py
5
main.py
@@ -6,12 +6,13 @@ Main entry point for the application.
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# Add src directory to path
|
# Add src directory to path for development mode
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
from PySide6.QtWidgets import QApplication
|
from PySide6.QtWidgets import QApplication
|
||||||
from PySide6.QtCore import Qt
|
from PySide6.QtCore import Qt
|
||||||
|
|
||||||
|
from src import __version__
|
||||||
from src.gui.main_window import MainWindow
|
from src.gui.main_window import MainWindow
|
||||||
from src.utils.logger import setup_logging
|
from src.utils.logger import setup_logging
|
||||||
from src.utils.config_manager import ConfigManager
|
from src.utils.config_manager import ConfigManager
|
||||||
@@ -37,7 +38,7 @@ def main():
|
|||||||
app = QApplication(sys.argv)
|
app = QApplication(sys.argv)
|
||||||
app.setApplicationName("Microscopy Object Detection")
|
app.setApplicationName("Microscopy Object Detection")
|
||||||
app.setOrganizationName("MicroscopyLab")
|
app.setOrganizationName("MicroscopyLab")
|
||||||
app.setApplicationVersion("1.0.0")
|
app.setApplicationVersion(__version__)
|
||||||
|
|
||||||
# Set application style
|
# Set application style
|
||||||
app.setStyle("Fusion")
|
app.setStyle("Fusion")
|
||||||
|
|||||||
102
pyproject.toml
Normal file
102
pyproject.toml
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=45", "wheel"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "microscopy-object-detection"
|
||||||
|
version = "1.0.0"
|
||||||
|
description = "Desktop application for detecting and segmenting organelles in microscopy images using YOLOv8-seg"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.8"
|
||||||
|
license = { text = "MIT" }
|
||||||
|
authors = [{ name = "Your Name", email = "your.email@example.com" }]
|
||||||
|
keywords = [
|
||||||
|
"microscopy",
|
||||||
|
"yolov8",
|
||||||
|
"object-detection",
|
||||||
|
"segmentation",
|
||||||
|
"computer-vision",
|
||||||
|
"deep-learning",
|
||||||
|
]
|
||||||
|
classifiers = [
|
||||||
|
"Development Status :: 4 - Beta",
|
||||||
|
"Intended Audience :: Science/Research",
|
||||||
|
"Topic :: Scientific/Engineering :: Image Recognition",
|
||||||
|
"Topic :: Scientific/Engineering :: Bio-Informatics",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.8",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
|
"Programming Language :: Python :: 3.10",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
|
"Operating System :: OS Independent",
|
||||||
|
]
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
"ultralytics>=8.0.0",
|
||||||
|
"PySide6>=6.5.0",
|
||||||
|
"pyqtgraph>=0.13.0",
|
||||||
|
"numpy>=1.24.0",
|
||||||
|
"opencv-python>=4.8.0",
|
||||||
|
"Pillow>=10.0.0",
|
||||||
|
"PyYAML>=6.0",
|
||||||
|
"pandas>=2.0.0",
|
||||||
|
"openpyxl>=3.1.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=7.0.0",
|
||||||
|
"pytest-cov>=4.0.0",
|
||||||
|
"black>=23.0.0",
|
||||||
|
"pylint>=2.17.0",
|
||||||
|
"mypy>=1.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.urls]
|
||||||
|
Homepage = "https://github.com/yourusername/object_detection"
|
||||||
|
Documentation = "https://github.com/yourusername/object_detection/blob/main/README.md"
|
||||||
|
Repository = "https://github.com/yourusername/object_detection"
|
||||||
|
"Bug Tracker" = "https://github.com/yourusername/object_detection/issues"
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
microscopy-detect = "src.cli:main"
|
||||||
|
|
||||||
|
[project.gui-scripts]
|
||||||
|
microscopy-detect-gui = "src.gui_launcher:main"
|
||||||
|
|
||||||
|
[tool.setuptools]
|
||||||
|
packages = [
|
||||||
|
"src",
|
||||||
|
"src.database",
|
||||||
|
"src.model",
|
||||||
|
"src.gui",
|
||||||
|
"src.gui.tabs",
|
||||||
|
"src.gui.dialogs",
|
||||||
|
"src.gui.widgets",
|
||||||
|
"src.utils",
|
||||||
|
]
|
||||||
|
include-package-data = true
|
||||||
|
|
||||||
|
[tool.setuptools.package-data]
|
||||||
|
"src.database" = ["*.sql"]
|
||||||
|
|
||||||
|
[tool.black]
|
||||||
|
line-length = 120
|
||||||
|
target-version = ['py38', 'py39', 'py310', 'py311']
|
||||||
|
include = '\.pyi?$'
|
||||||
|
|
||||||
|
[tool.pylint.messages_control]
|
||||||
|
max-line-length = 120
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
python_version = "3.8"
|
||||||
|
warn_return_any = true
|
||||||
|
warn_unused_configs = true
|
||||||
|
disallow_untyped_defs = false
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
python_files = ["test_*.py"]
|
||||||
|
python_functions = ["test_*"]
|
||||||
|
addopts = "-v --cov=src --cov-report=term-missing"
|
||||||
56
setup.py
Normal file
56
setup.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""Setup script for Microscopy Object Detection Application."""
|
||||||
|
|
||||||
|
from setuptools import setup, find_packages
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Read the contents of README file
|
||||||
|
this_directory = Path(__file__).parent
|
||||||
|
long_description = (this_directory / "README.md").read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
# Read requirements
|
||||||
|
requirements = (this_directory / "requirements.txt").read_text().splitlines()
|
||||||
|
requirements = [
|
||||||
|
req.strip() for req in requirements if req.strip() and not req.startswith("#")
|
||||||
|
]
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name="microscopy-object-detection",
|
||||||
|
version="1.0.0",
|
||||||
|
author="Your Name",
|
||||||
|
author_email="your.email@example.com",
|
||||||
|
description="Desktop application for detecting and segmenting organelles in microscopy images using YOLOv8-seg",
|
||||||
|
long_description=long_description,
|
||||||
|
long_description_content_type="text/markdown",
|
||||||
|
url="https://github.com/yourusername/object_detection",
|
||||||
|
packages=find_packages(exclude=["tests", "tests.*", "docs"]),
|
||||||
|
include_package_data=True,
|
||||||
|
install_requires=requirements,
|
||||||
|
python_requires=">=3.8",
|
||||||
|
classifiers=[
|
||||||
|
"Development Status :: 4 - Beta",
|
||||||
|
"Intended Audience :: Science/Research",
|
||||||
|
"Topic :: Scientific/Engineering :: Image Recognition",
|
||||||
|
"Topic :: Scientific/Engineering :: Bio-Informatics",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.8",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
|
"Programming Language :: Python :: 3.10",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
|
"Operating System :: OS Independent",
|
||||||
|
],
|
||||||
|
entry_points={
|
||||||
|
"console_scripts": [
|
||||||
|
"microscopy-detect=src.cli:main",
|
||||||
|
],
|
||||||
|
"gui_scripts": [
|
||||||
|
"microscopy-detect-gui=src.gui_launcher:main",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
keywords="microscopy yolov8 object-detection segmentation computer-vision deep-learning",
|
||||||
|
project_urls={
|
||||||
|
"Bug Reports": "https://github.com/yourusername/object_detection/issues",
|
||||||
|
"Source": "https://github.com/yourusername/object_detection",
|
||||||
|
"Documentation": "https://github.com/yourusername/object_detection/blob/main/README.md",
|
||||||
|
},
|
||||||
|
)
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
"""
|
||||||
|
Microscopy Object Detection Application
|
||||||
|
|
||||||
|
A desktop application for detecting and segmenting organelles and membrane
|
||||||
|
branching structures in microscopy images using YOLOv8-seg.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__version__ = "1.0.0"
|
||||||
|
__author__ = "Your Name"
|
||||||
|
__email__ = "your.email@example.com"
|
||||||
|
__license__ = "MIT"
|
||||||
|
|
||||||
|
# Package metadata
|
||||||
|
__all__ = [
|
||||||
|
"__version__",
|
||||||
|
"__author__",
|
||||||
|
"__email__",
|
||||||
|
"__license__",
|
||||||
|
]
|
||||||
|
|||||||
61
src/cli.py
Normal file
61
src/cli.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
"""
|
||||||
|
Command-line interface for microscopy object detection application.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from src import __version__
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main CLI entry point."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Microscopy Object Detection Application - CLI Interface",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
epilog="""
|
||||||
|
Examples:
|
||||||
|
# Launch GUI
|
||||||
|
microscopy-detect-gui
|
||||||
|
|
||||||
|
# Show version
|
||||||
|
microscopy-detect --version
|
||||||
|
|
||||||
|
# Get help
|
||||||
|
microscopy-detect --help
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--version",
|
||||||
|
action="version",
|
||||||
|
version=f"microscopy-object-detection {__version__}",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--gui",
|
||||||
|
action="store_true",
|
||||||
|
help="Launch the GUI application (same as microscopy-detect-gui)",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.gui:
|
||||||
|
# Launch GUI
|
||||||
|
try:
|
||||||
|
from src.gui_launcher import main as gui_main
|
||||||
|
|
||||||
|
gui_main()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error launching GUI: {e}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
# Show help if no arguments provided
|
||||||
|
parser.print_help()
|
||||||
|
print("\nTo launch the GUI, use: microscopy-detect-gui")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
@@ -6,10 +6,18 @@ Handles all database operations including CRUD operations, queries, and exports.
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Dict, Optional, Tuple, Any
|
from typing import List, Dict, Optional, Tuple, Any, Union
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import csv
|
import csv
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from src.utils.logger import get_logger
|
||||||
|
from src.utils.image import Image
|
||||||
|
|
||||||
|
IMAGE_EXTENSIONS = tuple(Image.SUPPORTED_EXTENSIONS)
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DatabaseManager:
|
class DatabaseManager:
|
||||||
@@ -30,18 +38,46 @@ class DatabaseManager:
|
|||||||
# Create directory if it doesn't exist
|
# Create directory if it doesn't exist
|
||||||
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Read schema file and execute
|
|
||||||
schema_path = Path(__file__).parent / "schema.sql"
|
|
||||||
with open(schema_path, "r") as f:
|
|
||||||
schema_sql = f.read()
|
|
||||||
|
|
||||||
conn = self.get_connection()
|
conn = self.get_connection()
|
||||||
try:
|
try:
|
||||||
|
# Check if annotations table needs migration
|
||||||
|
self._migrate_annotations_table(conn)
|
||||||
|
|
||||||
|
# Read schema file and execute
|
||||||
|
schema_path = Path(__file__).parent / "schema.sql"
|
||||||
|
with open(schema_path, "r") as f:
|
||||||
|
schema_sql = f.read()
|
||||||
|
|
||||||
conn.executescript(schema_sql)
|
conn.executescript(schema_sql)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
def _migrate_annotations_table(self, conn: sqlite3.Connection) -> None:
|
||||||
|
"""
|
||||||
|
Migrate annotations table from old schema (class_name) to new schema (class_id).
|
||||||
|
"""
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# Check if annotations table exists
|
||||||
|
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'")
|
||||||
|
if not cursor.fetchone():
|
||||||
|
# Table doesn't exist yet, no migration needed
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if table has old schema (class_name column)
|
||||||
|
cursor.execute("PRAGMA table_info(annotations)")
|
||||||
|
columns = {row[1]: row for row in cursor.fetchall()}
|
||||||
|
|
||||||
|
if "class_name" in columns and "class_id" not in columns:
|
||||||
|
# Old schema detected, need to migrate
|
||||||
|
print("Migrating annotations table to new schema with class_id...")
|
||||||
|
|
||||||
|
# Drop old annotations table (assuming no critical data since this is a new feature)
|
||||||
|
cursor.execute("DROP TABLE IF EXISTS annotations")
|
||||||
|
conn.commit()
|
||||||
|
print("Old annotations table dropped, will be recreated with new schema")
|
||||||
|
|
||||||
def get_connection(self) -> sqlite3.Connection:
|
def get_connection(self) -> sqlite3.Connection:
|
||||||
"""Get database connection with proper settings."""
|
"""Get database connection with proper settings."""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
@@ -56,7 +92,7 @@ class DatabaseManager:
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
model_version: str,
|
model_version: str,
|
||||||
model_path: str,
|
model_path: str,
|
||||||
base_model: str = "yolov8s.pt",
|
base_model: str = "yolov8s-seg.pt",
|
||||||
training_params: Optional[Dict] = None,
|
training_params: Optional[Dict] = None,
|
||||||
metrics: Optional[Dict] = None,
|
metrics: Optional[Dict] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
@@ -165,6 +201,28 @@ class DatabaseManager:
|
|||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
def delete_model(self, model_id: int) -> bool:
|
||||||
|
"""Delete a model from the database.
|
||||||
|
|
||||||
|
Note: detections referencing this model are deleted automatically via
|
||||||
|
the `detections.model_id` foreign key (ON DELETE CASCADE).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: ID of the model to delete.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if a model row was deleted, False otherwise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("DELETE FROM models WHERE id = ?", (model_id,))
|
||||||
|
conn.commit()
|
||||||
|
return cursor.rowcount > 0
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
# ==================== Image Operations ====================
|
# ==================== Image Operations ====================
|
||||||
|
|
||||||
def add_image(
|
def add_image(
|
||||||
@@ -204,9 +262,7 @@ class DatabaseManager:
|
|||||||
return cursor.lastrowid
|
return cursor.lastrowid
|
||||||
except sqlite3.IntegrityError:
|
except sqlite3.IntegrityError:
|
||||||
# Image already exists, return its ID
|
# Image already exists, return its ID
|
||||||
cursor.execute(
|
cursor.execute("SELECT id FROM images WHERE relative_path = ?", (relative_path,))
|
||||||
"SELECT id FROM images WHERE relative_path = ?", (relative_path,)
|
|
||||||
)
|
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
return row["id"] if row else None
|
return row["id"] if row else None
|
||||||
finally:
|
finally:
|
||||||
@@ -217,17 +273,13 @@ class DatabaseManager:
|
|||||||
conn = self.get_connection()
|
conn = self.get_connection()
|
||||||
try:
|
try:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute(
|
cursor.execute("SELECT * FROM images WHERE relative_path = ?", (relative_path,))
|
||||||
"SELECT * FROM images WHERE relative_path = ?", (relative_path,)
|
|
||||||
)
|
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
return dict(row) if row else None
|
return dict(row) if row else None
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def get_or_create_image(
|
def get_or_create_image(self, relative_path: str, filename: str, width: int, height: int) -> int:
|
||||||
self, relative_path: str, filename: str, width: int, height: int
|
|
||||||
) -> int:
|
|
||||||
"""Get existing image or create new one."""
|
"""Get existing image or create new one."""
|
||||||
existing = self.get_image_by_path(relative_path)
|
existing = self.get_image_by_path(relative_path)
|
||||||
if existing:
|
if existing:
|
||||||
@@ -243,6 +295,7 @@ class DatabaseManager:
|
|||||||
class_name: str,
|
class_name: str,
|
||||||
bbox: Tuple[float, float, float, float], # (x_min, y_min, x_max, y_max)
|
bbox: Tuple[float, float, float, float], # (x_min, y_min, x_max, y_max)
|
||||||
confidence: float,
|
confidence: float,
|
||||||
|
segmentation_mask: Optional[List[List[float]]] = None,
|
||||||
metadata: Optional[Dict] = None,
|
metadata: Optional[Dict] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
@@ -254,6 +307,7 @@ class DatabaseManager:
|
|||||||
class_name: Detected object class
|
class_name: Detected object class
|
||||||
bbox: Bounding box coordinates (normalized 0-1)
|
bbox: Bounding box coordinates (normalized 0-1)
|
||||||
confidence: Detection confidence score
|
confidence: Detection confidence score
|
||||||
|
segmentation_mask: Polygon coordinates for segmentation [[x1,y1], [x2,y2], ...]
|
||||||
metadata: Additional metadata
|
metadata: Additional metadata
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -265,8 +319,8 @@ class DatabaseManager:
|
|||||||
x_min, y_min, x_max, y_max = bbox
|
x_min, y_min, x_max, y_max = bbox
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, metadata)
|
INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, segmentation_mask, metadata)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
image_id,
|
image_id,
|
||||||
@@ -277,6 +331,7 @@ class DatabaseManager:
|
|||||||
x_max,
|
x_max,
|
||||||
y_max,
|
y_max,
|
||||||
confidence,
|
confidence,
|
||||||
|
json.dumps(segmentation_mask) if segmentation_mask else None,
|
||||||
json.dumps(metadata) if metadata else None,
|
json.dumps(metadata) if metadata else None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -302,8 +357,8 @@ class DatabaseManager:
|
|||||||
bbox = det["bbox"]
|
bbox = det["bbox"]
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, metadata)
|
INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, segmentation_mask, metadata)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
det["image_id"],
|
det["image_id"],
|
||||||
@@ -314,11 +369,8 @@ class DatabaseManager:
|
|||||||
bbox[2],
|
bbox[2],
|
||||||
bbox[3],
|
bbox[3],
|
||||||
det["confidence"],
|
det["confidence"],
|
||||||
(
|
(json.dumps(det.get("segmentation_mask")) if det.get("segmentation_mask") else None),
|
||||||
json.dumps(det.get("metadata"))
|
(json.dumps(det.get("metadata")) if det.get("metadata") else None),
|
||||||
if det.get("metadata")
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
@@ -363,15 +415,16 @@ class DatabaseManager:
|
|||||||
if filters:
|
if filters:
|
||||||
conditions = []
|
conditions = []
|
||||||
for key, value in filters.items():
|
for key, value in filters.items():
|
||||||
if (
|
if key.startswith("d.") or key.startswith("i.") or key.startswith("m."):
|
||||||
key.startswith("d.")
|
if "like" in value.lower():
|
||||||
or key.startswith("i.")
|
conditions.append(f"{key} LIKE ?")
|
||||||
or key.startswith("m.")
|
params.append(value.split(" ")[1])
|
||||||
):
|
else:
|
||||||
conditions.append(f"{key} = ?")
|
conditions.append(f"{key} = ?")
|
||||||
|
params.append(value)
|
||||||
else:
|
else:
|
||||||
conditions.append(f"d.{key} = ?")
|
conditions.append(f"d.{key} = ?")
|
||||||
params.append(value)
|
params.append(value)
|
||||||
query += " WHERE " + " AND ".join(conditions)
|
query += " WHERE " + " AND ".join(conditions)
|
||||||
|
|
||||||
query += " ORDER BY d.detected_at DESC"
|
query += " ORDER BY d.detected_at DESC"
|
||||||
@@ -385,24 +438,41 @@ class DatabaseManager:
|
|||||||
detections = []
|
detections = []
|
||||||
for row in cursor.fetchall():
|
for row in cursor.fetchall():
|
||||||
det = dict(row)
|
det = dict(row)
|
||||||
# Parse JSON metadata
|
# Parse JSON fields
|
||||||
if det.get("metadata"):
|
if det.get("metadata"):
|
||||||
det["metadata"] = json.loads(det["metadata"])
|
det["metadata"] = json.loads(det["metadata"])
|
||||||
|
if det.get("segmentation_mask"):
|
||||||
|
det["segmentation_mask"] = json.loads(det["segmentation_mask"])
|
||||||
detections.append(det)
|
detections.append(det)
|
||||||
|
|
||||||
return detections
|
return detections
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def get_detections_for_image(
|
def get_detections_for_image(self, image_id: int, model_id: Optional[int] = None) -> List[Dict]:
|
||||||
self, image_id: int, model_id: Optional[int] = None
|
|
||||||
) -> List[Dict]:
|
|
||||||
"""Get all detections for a specific image."""
|
"""Get all detections for a specific image."""
|
||||||
filters = {"image_id": image_id}
|
filters = {"image_id": image_id}
|
||||||
if model_id:
|
if model_id:
|
||||||
filters["model_id"] = model_id
|
filters["model_id"] = model_id
|
||||||
return self.get_detections(filters)
|
return self.get_detections(filters)
|
||||||
|
|
||||||
|
def delete_detections_for_image(self, image_id: int, model_id: Optional[int] = None) -> int:
|
||||||
|
"""Delete detections tied to a specific image and optional model."""
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
if model_id is not None:
|
||||||
|
cursor.execute(
|
||||||
|
"DELETE FROM detections WHERE image_id = ? AND model_id = ?",
|
||||||
|
(image_id, model_id),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cursor.execute("DELETE FROM detections WHERE image_id = ?", (image_id,))
|
||||||
|
conn.commit()
|
||||||
|
return cursor.rowcount
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
def delete_detections_for_model(self, model_id: int) -> int:
|
def delete_detections_for_model(self, model_id: int) -> int:
|
||||||
"""Delete all detections for a specific model."""
|
"""Delete all detections for a specific model."""
|
||||||
conn = self.get_connection()
|
conn = self.get_connection()
|
||||||
@@ -414,6 +484,22 @@ class DatabaseManager:
|
|||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
def delete_all_detections(self) -> int:
|
||||||
|
"""Delete all detections from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of rows deleted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("DELETE FROM detections")
|
||||||
|
conn.commit()
|
||||||
|
return cursor.rowcount
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
# ==================== Statistics Operations ====================
|
# ==================== Statistics Operations ====================
|
||||||
|
|
||||||
def get_detection_statistics(
|
def get_detection_statistics(
|
||||||
@@ -457,9 +543,7 @@ class DatabaseManager:
|
|||||||
""",
|
""",
|
||||||
params,
|
params,
|
||||||
)
|
)
|
||||||
class_counts = {
|
class_counts = {row["class_name"]: row["count"] for row in cursor.fetchall()}
|
||||||
row["class_name"]: row["count"] for row in cursor.fetchall()
|
|
||||||
}
|
|
||||||
|
|
||||||
# Average confidence
|
# Average confidence
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
@@ -516,9 +600,7 @@ class DatabaseManager:
|
|||||||
|
|
||||||
# ==================== Export Operations ====================
|
# ==================== Export Operations ====================
|
||||||
|
|
||||||
def export_detections_to_csv(
|
def export_detections_to_csv(self, output_path: str, filters: Optional[Dict] = None) -> bool:
|
||||||
self, output_path: str, filters: Optional[Dict] = None
|
|
||||||
) -> bool:
|
|
||||||
"""Export detections to CSV file."""
|
"""Export detections to CSV file."""
|
||||||
try:
|
try:
|
||||||
detections = self.get_detections(filters)
|
detections = self.get_detections(filters)
|
||||||
@@ -538,6 +620,7 @@ class DatabaseManager:
|
|||||||
"x_max",
|
"x_max",
|
||||||
"y_max",
|
"y_max",
|
||||||
"confidence",
|
"confidence",
|
||||||
|
"segmentation_mask",
|
||||||
"detected_at",
|
"detected_at",
|
||||||
]
|
]
|
||||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||||
@@ -545,6 +628,9 @@ class DatabaseManager:
|
|||||||
|
|
||||||
for det in detections:
|
for det in detections:
|
||||||
row = {k: det[k] for k in fieldnames if k in det}
|
row = {k: det[k] for k in fieldnames if k in det}
|
||||||
|
# Convert segmentation mask list to JSON string for CSV
|
||||||
|
if row.get("segmentation_mask") and isinstance(row["segmentation_mask"], list):
|
||||||
|
row["segmentation_mask"] = json.dumps(row["segmentation_mask"])
|
||||||
writer.writerow(row)
|
writer.writerow(row)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@@ -552,9 +638,7 @@ class DatabaseManager:
|
|||||||
print(f"Error exporting to CSV: {e}")
|
print(f"Error exporting to CSV: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def export_detections_to_json(
|
def export_detections_to_json(self, output_path: str, filters: Optional[Dict] = None) -> bool:
|
||||||
self, output_path: str, filters: Optional[Dict] = None
|
|
||||||
) -> bool:
|
|
||||||
"""Export detections to JSON file."""
|
"""Export detections to JSON file."""
|
||||||
try:
|
try:
|
||||||
detections = self.get_detections(filters)
|
detections = self.get_detections(filters)
|
||||||
@@ -574,25 +658,118 @@ class DatabaseManager:
|
|||||||
|
|
||||||
# ==================== Annotation Operations ====================
|
# ==================== Annotation Operations ====================
|
||||||
|
|
||||||
|
def get_annotated_images_summary(
|
||||||
|
self,
|
||||||
|
name_filter: Optional[str] = None,
|
||||||
|
order_by: str = "filename",
|
||||||
|
order_dir: str = "ASC",
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""Return images that have at least one manual annotation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name_filter: Optional substring filter applied to filename/relative_path.
|
||||||
|
order_by: One of: 'filename', 'relative_path', 'annotation_count', 'added_at'.
|
||||||
|
order_dir: 'ASC' or 'DESC'.
|
||||||
|
limit: Optional max number of rows.
|
||||||
|
offset: Pagination offset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dicts: {id, relative_path, filename, added_at, annotation_count}
|
||||||
|
"""
|
||||||
|
|
||||||
|
allowed_order_by = {
|
||||||
|
"filename": "i.filename",
|
||||||
|
"relative_path": "i.relative_path",
|
||||||
|
"annotation_count": "annotation_count",
|
||||||
|
"added_at": "i.added_at",
|
||||||
|
}
|
||||||
|
order_expr = allowed_order_by.get(order_by, "i.filename")
|
||||||
|
dir_norm = str(order_dir).upper().strip()
|
||||||
|
if dir_norm not in {"ASC", "DESC"}:
|
||||||
|
dir_norm = "ASC"
|
||||||
|
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
params: List[Any] = []
|
||||||
|
where_sql = ""
|
||||||
|
if name_filter:
|
||||||
|
# Case-insensitive substring search.
|
||||||
|
token = f"%{name_filter}%"
|
||||||
|
where_sql = "WHERE (i.filename LIKE ? OR i.relative_path LIKE ?)"
|
||||||
|
params.extend([token, token])
|
||||||
|
|
||||||
|
limit_sql = ""
|
||||||
|
if limit is not None:
|
||||||
|
limit_sql = " LIMIT ? OFFSET ?"
|
||||||
|
params.extend([int(limit), int(offset)])
|
||||||
|
|
||||||
|
query = f"""
|
||||||
|
SELECT
|
||||||
|
i.id,
|
||||||
|
i.relative_path,
|
||||||
|
i.filename,
|
||||||
|
i.added_at,
|
||||||
|
COUNT(a.id) AS annotation_count
|
||||||
|
FROM images i
|
||||||
|
JOIN annotations a ON a.image_id = i.id
|
||||||
|
{where_sql}
|
||||||
|
GROUP BY i.id
|
||||||
|
HAVING annotation_count > 0
|
||||||
|
ORDER BY {order_expr} {dir_norm}
|
||||||
|
{limit_sql}
|
||||||
|
"""
|
||||||
|
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(query, params)
|
||||||
|
return [dict(row) for row in cursor.fetchall()]
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
def add_annotation(
|
def add_annotation(
|
||||||
self,
|
self,
|
||||||
image_id: int,
|
image_id: int,
|
||||||
class_name: str,
|
class_id: int,
|
||||||
bbox: Tuple[float, float, float, float],
|
bbox: Tuple[float, float, float, float],
|
||||||
annotator: str,
|
annotator: str,
|
||||||
|
segmentation_mask: Optional[List[List[float]]] = None,
|
||||||
verified: bool = False,
|
verified: bool = False,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Add manual annotation."""
|
"""
|
||||||
|
Add manual annotation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_id: ID of the image
|
||||||
|
class_id: ID of the object class (foreign key to object_classes)
|
||||||
|
bbox: Bounding box coordinates (normalized 0-1)
|
||||||
|
annotator: Name of person/tool creating annotation
|
||||||
|
segmentation_mask: Polygon coordinates for segmentation
|
||||||
|
verified: Whether annotation has been verified
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ID of the inserted annotation
|
||||||
|
"""
|
||||||
conn = self.get_connection()
|
conn = self.get_connection()
|
||||||
try:
|
try:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
x_min, y_min, x_max, y_max = bbox
|
x_min, y_min, x_max, y_max = bbox
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO annotations (image_id, class_name, x_min, y_min, x_max, y_max, annotator, verified)
|
INSERT INTO annotations (image_id, class_id, x_min, y_min, x_max, y_max, segmentation_mask, annotator, verified)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
""",
|
""",
|
||||||
(image_id, class_name, x_min, y_min, x_max, y_max, annotator, verified),
|
(
|
||||||
|
image_id,
|
||||||
|
class_id,
|
||||||
|
x_min,
|
||||||
|
y_min,
|
||||||
|
x_max,
|
||||||
|
y_max,
|
||||||
|
json.dumps(segmentation_mask) if segmentation_mask else None,
|
||||||
|
annotator,
|
||||||
|
verified,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return cursor.lastrowid
|
return cursor.lastrowid
|
||||||
@@ -600,15 +777,363 @@ class DatabaseManager:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def get_annotations_for_image(self, image_id: int) -> List[Dict]:
|
def get_annotations_for_image(self, image_id: int) -> List[Dict]:
|
||||||
"""Get all annotations for an image."""
|
"""
|
||||||
|
Get all annotations for an image with class information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_id: ID of the image
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of annotation dictionaries with joined class information
|
||||||
|
"""
|
||||||
conn = self.get_connection()
|
conn = self.get_connection()
|
||||||
try:
|
try:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("SELECT * FROM annotations WHERE image_id = ?", (image_id,))
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT
|
||||||
|
a.*,
|
||||||
|
c.class_name,
|
||||||
|
c.color as class_color,
|
||||||
|
c.description as class_description
|
||||||
|
FROM annotations a
|
||||||
|
JOIN object_classes c ON a.class_id = c.id
|
||||||
|
WHERE a.image_id = ?
|
||||||
|
ORDER BY a.created_at DESC
|
||||||
|
""",
|
||||||
|
(image_id,),
|
||||||
|
)
|
||||||
|
annotations = []
|
||||||
|
for row in cursor.fetchall():
|
||||||
|
ann = dict(row)
|
||||||
|
if ann.get("segmentation_mask"):
|
||||||
|
ann["segmentation_mask"] = json.loads(ann["segmentation_mask"])
|
||||||
|
annotations.append(ann)
|
||||||
|
return annotations
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def delete_annotation(self, annotation_id: int) -> bool:
|
||||||
|
"""
|
||||||
|
Delete a manual annotation by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
annotation_id: ID of the annotation to delete
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if an annotation was deleted, False otherwise.
|
||||||
|
"""
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("DELETE FROM annotations WHERE id = ?", (annotation_id,))
|
||||||
|
conn.commit()
|
||||||
|
return cursor.rowcount > 0
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
# ==================== Object Class Operations ====================
|
||||||
|
|
||||||
|
def get_object_classes(self) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Get all object classes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of object class dictionaries
|
||||||
|
"""
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT * FROM object_classes ORDER BY class_name")
|
||||||
return [dict(row) for row in cursor.fetchall()]
|
return [dict(row) for row in cursor.fetchall()]
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
def get_object_class_by_id(self, class_id: int) -> Optional[Dict]:
|
||||||
|
"""Get object class by ID."""
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT * FROM object_classes WHERE id = ?", (class_id,))
|
||||||
|
row = cursor.fetchone()
|
||||||
|
return dict(row) if row else None
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def get_object_class_by_name(self, class_name: str) -> Optional[Dict]:
|
||||||
|
"""Get object class by name."""
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT * FROM object_classes WHERE class_name = ?", (class_name,))
|
||||||
|
row = cursor.fetchone()
|
||||||
|
return dict(row) if row else None
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def add_object_class(self, class_name: str, color: str, description: Optional[str] = None) -> int:
|
||||||
|
"""
|
||||||
|
Add a new object class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_name: Name of the object class
|
||||||
|
color: Hex color code (e.g., '#FF0000')
|
||||||
|
description: Optional description
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ID of the inserted object class
|
||||||
|
"""
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO object_classes (class_name, color, description)
|
||||||
|
VALUES (?, ?, ?)
|
||||||
|
""",
|
||||||
|
(class_name, color, description),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
return cursor.lastrowid
|
||||||
|
except sqlite3.IntegrityError:
|
||||||
|
# Class already exists
|
||||||
|
existing = self.get_object_class_by_name(class_name)
|
||||||
|
return existing["id"] if existing else None
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def update_object_class(
|
||||||
|
self,
|
||||||
|
class_id: int,
|
||||||
|
class_name: Optional[str] = None,
|
||||||
|
color: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Update an object class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_id: ID of the class to update
|
||||||
|
class_name: New class name (optional)
|
||||||
|
color: New color (optional)
|
||||||
|
description: New description (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if updated, False otherwise
|
||||||
|
"""
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
updates = {}
|
||||||
|
if class_name is not None:
|
||||||
|
updates["class_name"] = class_name
|
||||||
|
if color is not None:
|
||||||
|
updates["color"] = color
|
||||||
|
if description is not None:
|
||||||
|
updates["description"] = description
|
||||||
|
|
||||||
|
if not updates:
|
||||||
|
return False
|
||||||
|
|
||||||
|
set_clauses = [f"{key} = ?" for key in updates.keys()]
|
||||||
|
params = list(updates.values()) + [class_id]
|
||||||
|
|
||||||
|
query = f"UPDATE object_classes SET {', '.join(set_clauses)} WHERE id = ?"
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(query, params)
|
||||||
|
conn.commit()
|
||||||
|
return cursor.rowcount > 0
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def delete_object_class(self, class_id: int) -> bool:
|
||||||
|
"""
|
||||||
|
Delete an object class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_id: ID of the class to delete
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deleted, False otherwise
|
||||||
|
"""
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("DELETE FROM object_classes WHERE id = ?", (class_id,))
|
||||||
|
conn.commit()
|
||||||
|
return cursor.rowcount > 0
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
# ==================== Dataset Utilities ====================
|
||||||
|
|
||||||
|
def compose_data_yaml(
|
||||||
|
self,
|
||||||
|
dataset_root: str,
|
||||||
|
output_path: Optional[str] = None,
|
||||||
|
splits: Optional[Dict[str, str]] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Compose a YOLO data.yaml file based on dataset folders and database metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_root: Base directory containing the dataset structure.
|
||||||
|
output_path: Optional output path; defaults to <dataset_root>/data.yaml.
|
||||||
|
splits: Optional mapping overriding train/val/test image directories (relative
|
||||||
|
to dataset_root or absolute paths).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to the generated YAML file.
|
||||||
|
"""
|
||||||
|
dataset_root_path = Path(dataset_root).expanduser()
|
||||||
|
if not dataset_root_path.exists():
|
||||||
|
raise ValueError(f"Dataset root does not exist: {dataset_root_path}")
|
||||||
|
dataset_root_path = dataset_root_path.resolve()
|
||||||
|
|
||||||
|
split_map: Dict[str, str] = {key: "" for key in ("train", "val", "test")}
|
||||||
|
if splits:
|
||||||
|
for key, value in splits.items():
|
||||||
|
if key in split_map and value:
|
||||||
|
split_map[key] = value
|
||||||
|
|
||||||
|
inferred = self._infer_split_dirs(dataset_root_path)
|
||||||
|
for key in split_map:
|
||||||
|
if not split_map[key]:
|
||||||
|
split_map[key] = inferred.get(key, "")
|
||||||
|
|
||||||
|
for required in ("train", "val"):
|
||||||
|
if not split_map[required]:
|
||||||
|
raise ValueError(
|
||||||
|
"Unable to determine %s image directory under %s. Provide it "
|
||||||
|
"explicitly via the 'splits' argument." % (required, dataset_root_path)
|
||||||
|
)
|
||||||
|
|
||||||
|
yaml_splits: Dict[str, str] = {}
|
||||||
|
for key, value in split_map.items():
|
||||||
|
if not value:
|
||||||
|
continue
|
||||||
|
yaml_splits[key] = self._normalize_split_value(value, dataset_root_path)
|
||||||
|
|
||||||
|
class_names = self._fetch_annotation_class_names()
|
||||||
|
if not class_names:
|
||||||
|
class_names = [cls["class_name"] for cls in self.get_object_classes()]
|
||||||
|
if not class_names:
|
||||||
|
raise ValueError("No object classes available to populate data.yaml")
|
||||||
|
|
||||||
|
names_map = {idx: name for idx, name in enumerate(class_names)}
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"path": dataset_root_path.as_posix(),
|
||||||
|
"train": yaml_splits["train"],
|
||||||
|
"val": yaml_splits["val"],
|
||||||
|
"names": names_map,
|
||||||
|
"nc": len(class_names),
|
||||||
|
}
|
||||||
|
if yaml_splits.get("test"):
|
||||||
|
payload["test"] = yaml_splits["test"]
|
||||||
|
|
||||||
|
output_path_obj = Path(output_path).expanduser() if output_path else dataset_root_path / "data.yaml"
|
||||||
|
output_path_obj.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
with open(output_path_obj, "w", encoding="utf-8") as handle:
|
||||||
|
yaml.safe_dump(payload, handle, sort_keys=False)
|
||||||
|
|
||||||
|
logger.info(f"Generated data.yaml at {output_path_obj}")
|
||||||
|
return output_path_obj.as_posix()
|
||||||
|
|
||||||
|
def _fetch_annotation_class_names(self) -> List[str]:
|
||||||
|
"""Return class names referenced by annotations (ordered by class ID)."""
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT DISTINCT c.id, c.class_name
|
||||||
|
FROM annotations a
|
||||||
|
JOIN object_classes c ON a.class_id = c.id
|
||||||
|
ORDER BY c.id
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
return [row["class_name"] for row in rows]
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def _infer_split_dirs(self, dataset_root: Path) -> Dict[str, str]:
|
||||||
|
"""Infer train/val/test image directories relative to dataset_root."""
|
||||||
|
patterns = {
|
||||||
|
"train": [
|
||||||
|
"train/images",
|
||||||
|
"training/images",
|
||||||
|
"images/train",
|
||||||
|
"images/training",
|
||||||
|
"train",
|
||||||
|
"training",
|
||||||
|
],
|
||||||
|
"val": [
|
||||||
|
"val/images",
|
||||||
|
"validation/images",
|
||||||
|
"images/val",
|
||||||
|
"images/validation",
|
||||||
|
"val",
|
||||||
|
"validation",
|
||||||
|
],
|
||||||
|
"test": [
|
||||||
|
"test/images",
|
||||||
|
"testing/images",
|
||||||
|
"images/test",
|
||||||
|
"images/testing",
|
||||||
|
"test",
|
||||||
|
"testing",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
inferred: Dict[str, str] = {key: "" for key in patterns}
|
||||||
|
for split_name, options in patterns.items():
|
||||||
|
for relative in options:
|
||||||
|
candidate = (dataset_root / relative).resolve()
|
||||||
|
if candidate.exists() and candidate.is_dir() and self._directory_has_images(candidate):
|
||||||
|
try:
|
||||||
|
inferred[split_name] = candidate.relative_to(dataset_root).as_posix()
|
||||||
|
except ValueError:
|
||||||
|
inferred[split_name] = candidate.as_posix()
|
||||||
|
break
|
||||||
|
return inferred
|
||||||
|
|
||||||
|
def _normalize_split_value(self, split_value: str, dataset_root: Path) -> str:
|
||||||
|
"""Validate and normalize a split directory to a YAML-friendly string."""
|
||||||
|
split_path = Path(split_value).expanduser()
|
||||||
|
if not split_path.is_absolute():
|
||||||
|
split_path = (dataset_root / split_path).resolve()
|
||||||
|
else:
|
||||||
|
split_path = split_path.resolve()
|
||||||
|
|
||||||
|
if not split_path.exists() or not split_path.is_dir():
|
||||||
|
raise ValueError(f"Split directory not found: {split_path}")
|
||||||
|
|
||||||
|
if not self._directory_has_images(split_path):
|
||||||
|
raise ValueError(f"No images found under {split_path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
return split_path.relative_to(dataset_root).as_posix()
|
||||||
|
except ValueError:
|
||||||
|
return split_path.as_posix()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _directory_has_images(directory: Path, max_checks: int = 2000) -> bool:
|
||||||
|
"""Return True if directory tree contains at least one image file."""
|
||||||
|
checked = 0
|
||||||
|
try:
|
||||||
|
for file_path in directory.rglob("*"):
|
||||||
|
if not file_path.is_file():
|
||||||
|
continue
|
||||||
|
if file_path.suffix.lower() in IMAGE_EXTENSIONS:
|
||||||
|
return True
|
||||||
|
checked += 1
|
||||||
|
if checked >= max_checks:
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def calculate_checksum(file_path: str) -> str:
|
def calculate_checksum(file_path: str) -> str:
|
||||||
"""Calculate MD5 checksum of a file."""
|
"""Calculate MD5 checksum of a file."""
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ These dataclasses represent the database entities.
|
|||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, Dict, Tuple
|
from typing import Optional, Dict, Tuple, List
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -46,6 +46,9 @@ class Detection:
|
|||||||
class_name: str
|
class_name: str
|
||||||
bbox: Tuple[float, float, float, float] # (x_min, y_min, x_max, y_max)
|
bbox: Tuple[float, float, float, float] # (x_min, y_min, x_max, y_max)
|
||||||
confidence: float
|
confidence: float
|
||||||
|
segmentation_mask: Optional[
|
||||||
|
List[List[float]]
|
||||||
|
] # List of polygon coordinates [[x1,y1], [x2,y2], ...]
|
||||||
detected_at: datetime
|
detected_at: datetime
|
||||||
metadata: Optional[Dict]
|
metadata: Optional[Dict]
|
||||||
|
|
||||||
@@ -58,6 +61,9 @@ class Annotation:
|
|||||||
image_id: int
|
image_id: int
|
||||||
class_name: str
|
class_name: str
|
||||||
bbox: Tuple[float, float, float, float] # (x_min, y_min, x_max, y_max)
|
bbox: Tuple[float, float, float, float] # (x_min, y_min, x_max, y_max)
|
||||||
|
segmentation_mask: Optional[
|
||||||
|
List[List[float]]
|
||||||
|
] # List of polygon coordinates [[x1,y1], [x2,y2], ...]
|
||||||
annotator: str
|
annotator: str
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
verified: bool
|
verified: bool
|
||||||
|
|||||||
@@ -37,25 +37,41 @@ CREATE TABLE IF NOT EXISTS detections (
|
|||||||
x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1),
|
x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1),
|
||||||
y_max REAL NOT NULL CHECK(y_max >= 0 AND y_max <= 1),
|
y_max REAL NOT NULL CHECK(y_max >= 0 AND y_max <= 1),
|
||||||
confidence REAL NOT NULL CHECK(confidence >= 0 AND confidence <= 1),
|
confidence REAL NOT NULL CHECK(confidence >= 0 AND confidence <= 1),
|
||||||
|
segmentation_mask TEXT, -- JSON string of polygon coordinates [[x1,y1], [x2,y2], ...]
|
||||||
detected_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
detected_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
metadata TEXT, -- JSON string for additional metadata
|
metadata TEXT, -- JSON string for additional metadata
|
||||||
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
|
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
|
||||||
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
|
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
|
||||||
);
|
);
|
||||||
|
|
||||||
-- Annotations table: stores manual annotations (future feature)
|
-- Object classes table: stores annotation class definitions with colors
|
||||||
|
CREATE TABLE IF NOT EXISTS object_classes (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
class_name TEXT NOT NULL UNIQUE,
|
||||||
|
color TEXT NOT NULL, -- Hex color code (e.g., '#FF0000')
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
description TEXT
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Insert default object classes
|
||||||
|
INSERT OR IGNORE INTO object_classes (class_name, color, description) VALUES
|
||||||
|
('terminal', '#FFFF00', 'Axion terminal');
|
||||||
|
|
||||||
|
-- Annotations table: stores manual annotations
|
||||||
CREATE TABLE IF NOT EXISTS annotations (
|
CREATE TABLE IF NOT EXISTS annotations (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
image_id INTEGER NOT NULL,
|
image_id INTEGER NOT NULL,
|
||||||
class_name TEXT NOT NULL,
|
class_id INTEGER NOT NULL,
|
||||||
x_min REAL NOT NULL CHECK(x_min >= 0 AND x_min <= 1),
|
x_min REAL NOT NULL CHECK(x_min >= 0 AND x_min <= 1),
|
||||||
y_min REAL NOT NULL CHECK(y_min >= 0 AND y_min <= 1),
|
y_min REAL NOT NULL CHECK(y_min >= 0 AND y_min <= 1),
|
||||||
x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1),
|
x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1),
|
||||||
y_max REAL NOT NULL CHECK(y_max >= 0 AND y_max <= 1),
|
y_max REAL NOT NULL CHECK(y_max >= 0 AND y_max <= 1),
|
||||||
|
segmentation_mask TEXT, -- JSON string of polygon coordinates [[x1,y1], [x2,y2], ...]
|
||||||
annotator TEXT,
|
annotator TEXT,
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
verified BOOLEAN DEFAULT 0,
|
verified BOOLEAN DEFAULT 0,
|
||||||
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE
|
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
|
||||||
|
FOREIGN KEY (class_id) REFERENCES object_classes (id) ON DELETE CASCADE
|
||||||
);
|
);
|
||||||
|
|
||||||
-- Create indexes for performance optimization
|
-- Create indexes for performance optimization
|
||||||
@@ -67,4 +83,6 @@ CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence);
|
|||||||
CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path);
|
CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path);
|
||||||
CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at);
|
CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at);
|
||||||
CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id);
|
CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at);
|
CREATE INDEX IF NOT EXISTS idx_annotations_class_id ON annotations(class_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_object_classes_class_name ON object_classes(class_name);
|
||||||
@@ -121,7 +121,7 @@ class ConfigDialog(QDialog):
|
|||||||
models_layout.addRow("Models Directory:", self.models_dir_edit)
|
models_layout.addRow("Models Directory:", self.models_dir_edit)
|
||||||
|
|
||||||
self.base_model_edit = QLineEdit()
|
self.base_model_edit = QLineEdit()
|
||||||
self.base_model_edit.setPlaceholderText("yolov8s.pt")
|
self.base_model_edit.setPlaceholderText("yolov8s-seg.pt")
|
||||||
models_layout.addRow("Default Base Model:", self.base_model_edit)
|
models_layout.addRow("Default Base Model:", self.base_model_edit)
|
||||||
|
|
||||||
models_group.setLayout(models_layout)
|
models_group.setLayout(models_layout)
|
||||||
@@ -232,7 +232,7 @@ class ConfigDialog(QDialog):
|
|||||||
self.config_manager.get("models.models_directory", "data/models")
|
self.config_manager.get("models.models_directory", "data/models")
|
||||||
)
|
)
|
||||||
self.base_model_edit.setText(
|
self.base_model_edit.setText(
|
||||||
self.config_manager.get("models.default_base_model", "yolov8s.pt")
|
self.config_manager.get("models.default_base_model", "yolov8s-seg.pt")
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training settings
|
# Training settings
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""Main window for the microscopy object detection application."""
|
||||||
Main window for the microscopy object detection application.
|
|
||||||
"""
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from PySide6.QtWidgets import (
|
from PySide6.QtWidgets import (
|
||||||
QMainWindow,
|
QMainWindow,
|
||||||
@@ -13,13 +14,14 @@ from PySide6.QtWidgets import (
|
|||||||
QVBoxLayout,
|
QVBoxLayout,
|
||||||
QLabel,
|
QLabel,
|
||||||
)
|
)
|
||||||
from PySide6.QtCore import Qt, QTimer
|
from PySide6.QtCore import Qt, QTimer, QSettings
|
||||||
from PySide6.QtGui import QAction, QKeySequence
|
from PySide6.QtGui import QAction, QKeySequence
|
||||||
|
|
||||||
from src.database.db_manager import DatabaseManager
|
from src.database.db_manager import DatabaseManager
|
||||||
from src.utils.config_manager import ConfigManager
|
from src.utils.config_manager import ConfigManager
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
from src.gui.dialogs.config_dialog import ConfigDialog
|
from src.gui.dialogs.config_dialog import ConfigDialog
|
||||||
|
from src.gui.dialogs.delete_model_dialog import DeleteModelDialog
|
||||||
from src.gui.tabs.detection_tab import DetectionTab
|
from src.gui.tabs.detection_tab import DetectionTab
|
||||||
from src.gui.tabs.training_tab import TrainingTab
|
from src.gui.tabs.training_tab import TrainingTab
|
||||||
from src.gui.tabs.validation_tab import ValidationTab
|
from src.gui.tabs.validation_tab import ValidationTab
|
||||||
@@ -52,8 +54,8 @@ class MainWindow(QMainWindow):
|
|||||||
self._create_tab_widget()
|
self._create_tab_widget()
|
||||||
self._create_status_bar()
|
self._create_status_bar()
|
||||||
|
|
||||||
# Center window on screen
|
# Restore window geometry or center window on screen
|
||||||
self._center_window()
|
self._restore_window_state()
|
||||||
|
|
||||||
logger.info("Main window initialized")
|
logger.info("Main window initialized")
|
||||||
|
|
||||||
@@ -91,6 +93,12 @@ class MainWindow(QMainWindow):
|
|||||||
db_stats_action.triggered.connect(self._show_database_stats)
|
db_stats_action.triggered.connect(self._show_database_stats)
|
||||||
tools_menu.addAction(db_stats_action)
|
tools_menu.addAction(db_stats_action)
|
||||||
|
|
||||||
|
tools_menu.addSeparator()
|
||||||
|
|
||||||
|
delete_model_action = QAction("Delete &Model…", self)
|
||||||
|
delete_model_action.triggered.connect(self._show_delete_model_dialog)
|
||||||
|
tools_menu.addAction(delete_model_action)
|
||||||
|
|
||||||
# Help menu
|
# Help menu
|
||||||
help_menu = menubar.addMenu("&Help")
|
help_menu = menubar.addMenu("&Help")
|
||||||
|
|
||||||
@@ -117,10 +125,10 @@ class MainWindow(QMainWindow):
|
|||||||
|
|
||||||
# Add tabs to widget
|
# Add tabs to widget
|
||||||
self.tab_widget.addTab(self.detection_tab, "Detection")
|
self.tab_widget.addTab(self.detection_tab, "Detection")
|
||||||
|
self.tab_widget.addTab(self.results_tab, "Results")
|
||||||
|
self.tab_widget.addTab(self.annotation_tab, "Annotation")
|
||||||
self.tab_widget.addTab(self.training_tab, "Training")
|
self.tab_widget.addTab(self.training_tab, "Training")
|
||||||
self.tab_widget.addTab(self.validation_tab, "Validation")
|
self.tab_widget.addTab(self.validation_tab, "Validation")
|
||||||
self.tab_widget.addTab(self.results_tab, "Results")
|
|
||||||
self.tab_widget.addTab(self.annotation_tab, "Annotation (Future)")
|
|
||||||
|
|
||||||
# Connect tab change signal
|
# Connect tab change signal
|
||||||
self.tab_widget.currentChanged.connect(self._on_tab_changed)
|
self.tab_widget.currentChanged.connect(self._on_tab_changed)
|
||||||
@@ -152,9 +160,25 @@ class MainWindow(QMainWindow):
|
|||||||
"""Center window on screen."""
|
"""Center window on screen."""
|
||||||
screen = self.screen().geometry()
|
screen = self.screen().geometry()
|
||||||
size = self.geometry()
|
size = self.geometry()
|
||||||
self.move(
|
self.move((screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2)
|
||||||
(screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2
|
|
||||||
)
|
def _restore_window_state(self):
|
||||||
|
"""Restore window geometry from settings or center window."""
|
||||||
|
settings = QSettings("microscopy_app", "object_detection")
|
||||||
|
geometry = settings.value("main_window/geometry")
|
||||||
|
|
||||||
|
if geometry:
|
||||||
|
self.restoreGeometry(geometry)
|
||||||
|
logger.debug("Restored window geometry from settings")
|
||||||
|
else:
|
||||||
|
self._center_window()
|
||||||
|
logger.debug("Centered window on screen")
|
||||||
|
|
||||||
|
def _save_window_state(self):
|
||||||
|
"""Save window geometry to settings."""
|
||||||
|
settings = QSettings("microscopy_app", "object_detection")
|
||||||
|
settings.setValue("main_window/geometry", self.saveGeometry())
|
||||||
|
logger.debug("Saved window geometry to settings")
|
||||||
|
|
||||||
def _show_settings(self):
|
def _show_settings(self):
|
||||||
"""Show settings dialog."""
|
"""Show settings dialog."""
|
||||||
@@ -175,6 +199,10 @@ class MainWindow(QMainWindow):
|
|||||||
self.training_tab.refresh()
|
self.training_tab.refresh()
|
||||||
if hasattr(self, "results_tab"):
|
if hasattr(self, "results_tab"):
|
||||||
self.results_tab.refresh()
|
self.results_tab.refresh()
|
||||||
|
if hasattr(self, "annotation_tab"):
|
||||||
|
self.annotation_tab.refresh()
|
||||||
|
if hasattr(self, "validation_tab"):
|
||||||
|
self.validation_tab.refresh()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error applying settings: {e}")
|
logger.error(f"Error applying settings: {e}")
|
||||||
|
|
||||||
@@ -191,6 +219,14 @@ class MainWindow(QMainWindow):
|
|||||||
logger.debug(f"Switched to tab: {tab_name}")
|
logger.debug(f"Switched to tab: {tab_name}")
|
||||||
self._update_status(f"Viewing: {tab_name}")
|
self._update_status(f"Viewing: {tab_name}")
|
||||||
|
|
||||||
|
# Ensure the Annotation tab always shows up-to-date DB-backed lists.
|
||||||
|
try:
|
||||||
|
current_widget = self.tab_widget.widget(index)
|
||||||
|
if hasattr(self, "annotation_tab") and current_widget is self.annotation_tab:
|
||||||
|
self.annotation_tab.refresh()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug(f"Failed to refresh annotation tab on selection: {exc}")
|
||||||
|
|
||||||
def _show_database_stats(self):
|
def _show_database_stats(self):
|
||||||
"""Show database statistics dialog."""
|
"""Show database statistics dialog."""
|
||||||
try:
|
try:
|
||||||
@@ -213,9 +249,229 @@ class MainWindow(QMainWindow):
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting database stats: {e}")
|
logger.error(f"Error getting database stats: {e}")
|
||||||
QMessageBox.warning(
|
QMessageBox.warning(self, "Error", f"Failed to get database statistics:\n{str(e)}")
|
||||||
self, "Error", f"Failed to get database statistics:\n{str(e)}"
|
|
||||||
)
|
def _show_delete_model_dialog(self) -> None:
|
||||||
|
"""Open the model deletion dialog."""
|
||||||
|
dialog = DeleteModelDialog(self.db_manager, self)
|
||||||
|
if not dialog.exec():
|
||||||
|
return
|
||||||
|
|
||||||
|
model_ids = dialog.selected_model_ids
|
||||||
|
if not model_ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._delete_models(model_ids)
|
||||||
|
|
||||||
|
def _delete_models(self, model_ids: list[int]) -> None:
|
||||||
|
"""Delete one or more models from the database and remove artifacts from disk."""
|
||||||
|
|
||||||
|
deleted_count = 0
|
||||||
|
removed_paths: list[str] = []
|
||||||
|
remove_errors: list[str] = []
|
||||||
|
|
||||||
|
for model_id in model_ids:
|
||||||
|
model = None
|
||||||
|
try:
|
||||||
|
model = self.db_manager.get_model_by_id(int(model_id))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Failed to load model {model_id} before deletion: {exc}")
|
||||||
|
|
||||||
|
if not model:
|
||||||
|
remove_errors.append(f"Model id {model_id} not found in database.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
deleted = self.db_manager.delete_model(int(model_id))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Failed to delete model {model_id}: {exc}")
|
||||||
|
remove_errors.append(f"Failed to delete model id {model_id} from DB: {exc}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not deleted:
|
||||||
|
remove_errors.append(f"Model id {model_id} was not deleted (already removed?).")
|
||||||
|
continue
|
||||||
|
|
||||||
|
deleted_count += 1
|
||||||
|
removed, errors = self._delete_model_artifacts_from_disk(model)
|
||||||
|
removed_paths.extend(removed)
|
||||||
|
remove_errors.extend(errors)
|
||||||
|
|
||||||
|
# Refresh tabs to reflect the deletion(s).
|
||||||
|
try:
|
||||||
|
if hasattr(self, "detection_tab"):
|
||||||
|
self.detection_tab.refresh()
|
||||||
|
if hasattr(self, "results_tab"):
|
||||||
|
self.results_tab.refresh()
|
||||||
|
if hasattr(self, "validation_tab"):
|
||||||
|
self.validation_tab.refresh()
|
||||||
|
if hasattr(self, "training_tab"):
|
||||||
|
self.training_tab.refresh()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"Failed to refresh tabs after model deletion: {exc}")
|
||||||
|
|
||||||
|
details: list[str] = []
|
||||||
|
if removed_paths:
|
||||||
|
details.append("Removed from disk:\n" + "\n".join(removed_paths))
|
||||||
|
if remove_errors:
|
||||||
|
details.append("\nDisk cleanup warnings:\n" + "\n".join(remove_errors))
|
||||||
|
|
||||||
|
QMessageBox.information(
|
||||||
|
self,
|
||||||
|
"Delete Model",
|
||||||
|
f"Deleted {deleted_count} model(s) from database." + ("\n\n" + "\n".join(details) if details else ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _delete_model(self, model_id: int) -> None:
|
||||||
|
"""Delete a model from the database and remove its artifacts from disk."""
|
||||||
|
|
||||||
|
model = None
|
||||||
|
try:
|
||||||
|
model = self.db_manager.get_model_by_id(model_id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Failed to load model {model_id} before deletion: {exc}")
|
||||||
|
|
||||||
|
if not model:
|
||||||
|
QMessageBox.warning(self, "Delete Model", "Selected model was not found in the database.")
|
||||||
|
return
|
||||||
|
|
||||||
|
model_path = str(model.get("model_path") or "")
|
||||||
|
|
||||||
|
try:
|
||||||
|
deleted = self.db_manager.delete_model(model_id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Failed to delete model {model_id}: {exc}")
|
||||||
|
QMessageBox.critical(self, "Delete Model", f"Failed to delete model from database:\n{exc}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not deleted:
|
||||||
|
QMessageBox.warning(self, "Delete Model", "No model was deleted (it may have already been removed).")
|
||||||
|
return
|
||||||
|
|
||||||
|
removed_paths, remove_errors = self._delete_model_artifacts_from_disk(model)
|
||||||
|
|
||||||
|
# Refresh tabs to reflect the deletion.
|
||||||
|
try:
|
||||||
|
if hasattr(self, "detection_tab"):
|
||||||
|
self.detection_tab.refresh()
|
||||||
|
if hasattr(self, "results_tab"):
|
||||||
|
self.results_tab.refresh()
|
||||||
|
if hasattr(self, "validation_tab"):
|
||||||
|
self.validation_tab.refresh()
|
||||||
|
if hasattr(self, "training_tab"):
|
||||||
|
self.training_tab.refresh()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"Failed to refresh tabs after model deletion: {exc}")
|
||||||
|
|
||||||
|
details = []
|
||||||
|
if model_path:
|
||||||
|
details.append(f"Deleted model record for: {model_path}")
|
||||||
|
if removed_paths:
|
||||||
|
details.append("\nRemoved from disk:\n" + "\n".join(removed_paths))
|
||||||
|
if remove_errors:
|
||||||
|
details.append("\nDisk cleanup warnings:\n" + "\n".join(remove_errors))
|
||||||
|
|
||||||
|
QMessageBox.information(
|
||||||
|
self,
|
||||||
|
"Delete Model",
|
||||||
|
"Model deleted from database." + ("\n\n" + "\n".join(details) if details else ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _delete_model_artifacts_from_disk(self, model: dict) -> tuple[list[str], list[str]]:
|
||||||
|
"""Best-effort removal of model artifacts on disk.
|
||||||
|
|
||||||
|
Strategy:
|
||||||
|
- Remove run directories inferred from:
|
||||||
|
- model.model_path (…/<run>/weights/*.pt => <run>)
|
||||||
|
- training_params.stage_results[].results.save_dir
|
||||||
|
but only if they are under the configured models directory.
|
||||||
|
- If the weights file itself exists and is outside the models directory, delete only the file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(removed_paths, errors)
|
||||||
|
"""
|
||||||
|
|
||||||
|
removed: list[str] = []
|
||||||
|
errors: list[str] = []
|
||||||
|
|
||||||
|
models_root = Path(self.config_manager.get_models_directory() or "data/models").expanduser()
|
||||||
|
try:
|
||||||
|
models_root_resolved = models_root.resolve()
|
||||||
|
except Exception:
|
||||||
|
models_root_resolved = models_root
|
||||||
|
|
||||||
|
inferred_dirs: list[Path] = []
|
||||||
|
|
||||||
|
# 1) From model_path
|
||||||
|
model_path_value = model.get("model_path")
|
||||||
|
if model_path_value:
|
||||||
|
try:
|
||||||
|
p = Path(str(model_path_value)).expanduser()
|
||||||
|
p_resolved = p.resolve() if p.exists() else p
|
||||||
|
if p_resolved.is_file():
|
||||||
|
if p_resolved.parent.name == "weights" and p_resolved.parent.parent.exists():
|
||||||
|
inferred_dirs.append(p_resolved.parent.parent)
|
||||||
|
elif p_resolved.parent.exists():
|
||||||
|
inferred_dirs.append(p_resolved.parent)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 2) From training_params.stage_results[].results.save_dir
|
||||||
|
training_params = model.get("training_params") or {}
|
||||||
|
if isinstance(training_params, dict):
|
||||||
|
stage_results = training_params.get("stage_results")
|
||||||
|
if isinstance(stage_results, list):
|
||||||
|
for stage in stage_results:
|
||||||
|
results = (stage or {}).get("results")
|
||||||
|
save_dir = (results or {}).get("save_dir") if isinstance(results, dict) else None
|
||||||
|
if not save_dir:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
d = Path(str(save_dir)).expanduser()
|
||||||
|
if d.exists() and d.is_dir():
|
||||||
|
inferred_dirs.append(d)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Deduplicate inferred_dirs
|
||||||
|
unique_dirs: list[Path] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for d in inferred_dirs:
|
||||||
|
try:
|
||||||
|
key = str(d.resolve())
|
||||||
|
except Exception:
|
||||||
|
key = str(d)
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
unique_dirs.append(d)
|
||||||
|
|
||||||
|
# Delete directories under models_root
|
||||||
|
for d in unique_dirs:
|
||||||
|
try:
|
||||||
|
d_resolved = d.resolve()
|
||||||
|
except Exception:
|
||||||
|
d_resolved = d
|
||||||
|
try:
|
||||||
|
if d_resolved.exists() and d_resolved.is_dir() and d_resolved.is_relative_to(models_root_resolved):
|
||||||
|
shutil.rmtree(d_resolved)
|
||||||
|
removed.append(str(d_resolved))
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Failed to remove directory {d_resolved}: {exc}")
|
||||||
|
|
||||||
|
# If nothing matched (e.g., model_path outside models_root), delete just the file.
|
||||||
|
if model_path_value:
|
||||||
|
try:
|
||||||
|
p = Path(str(model_path_value)).expanduser()
|
||||||
|
if p.exists() and p.is_file():
|
||||||
|
p_resolved = p.resolve()
|
||||||
|
if not p_resolved.is_relative_to(models_root_resolved):
|
||||||
|
p_resolved.unlink()
|
||||||
|
removed.append(str(p_resolved))
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Failed to remove model file {model_path_value}: {exc}")
|
||||||
|
|
||||||
|
return removed, errors
|
||||||
|
|
||||||
def _show_about(self):
|
def _show_about(self):
|
||||||
"""Show about dialog."""
|
"""Show about dialog."""
|
||||||
@@ -276,6 +532,20 @@ class MainWindow(QMainWindow):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if reply == QMessageBox.Yes:
|
if reply == QMessageBox.Yes:
|
||||||
|
# Save window state before closing
|
||||||
|
self._save_window_state()
|
||||||
|
|
||||||
|
# Persist tab state and stop background work before exit
|
||||||
|
if hasattr(self, "training_tab"):
|
||||||
|
self.training_tab.shutdown()
|
||||||
|
if hasattr(self, "annotation_tab"):
|
||||||
|
# Best-effort refresh so DB-backed UI state is consistent at shutdown.
|
||||||
|
try:
|
||||||
|
self.annotation_tab.refresh()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self.annotation_tab.save_state()
|
||||||
|
|
||||||
logger.info("Application closing")
|
logger.info("Application closing")
|
||||||
event.accept()
|
event.accept()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,23 +1,49 @@
|
|||||||
"""
|
"""
|
||||||
Annotation tab for the microscopy object detection application.
|
Annotation tab for the microscopy object detection application.
|
||||||
Future feature for manual annotation.
|
Manual annotation with pen tool and object class management.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
|
from PySide6.QtWidgets import (
|
||||||
|
QWidget,
|
||||||
|
QVBoxLayout,
|
||||||
|
QHBoxLayout,
|
||||||
|
QLabel,
|
||||||
|
QGroupBox,
|
||||||
|
QPushButton,
|
||||||
|
QFileDialog,
|
||||||
|
QMessageBox,
|
||||||
|
QSplitter,
|
||||||
|
QLineEdit,
|
||||||
|
QTableWidget,
|
||||||
|
QTableWidgetItem,
|
||||||
|
QHeaderView,
|
||||||
|
QAbstractItemView,
|
||||||
|
)
|
||||||
|
from PySide6.QtCore import Qt, QSettings
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from src.database.db_manager import DatabaseManager
|
from src.database.db_manager import DatabaseManager
|
||||||
from src.utils.config_manager import ConfigManager
|
from src.utils.config_manager import ConfigManager
|
||||||
|
from src.utils.image import Image, ImageLoadError
|
||||||
|
from src.utils.logger import get_logger
|
||||||
|
from src.gui.widgets import AnnotationCanvasWidget, AnnotationToolsWidget
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AnnotationTab(QWidget):
|
class AnnotationTab(QWidget):
|
||||||
"""Annotation tab placeholder (future feature)."""
|
"""Annotation tab for manual image annotation."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
|
||||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
|
||||||
):
|
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
self.db_manager = db_manager
|
self.db_manager = db_manager
|
||||||
self.config_manager = config_manager
|
self.config_manager = config_manager
|
||||||
|
self.current_image = None
|
||||||
|
self.current_image_path = None
|
||||||
|
self.current_image_id = None
|
||||||
|
self.current_annotations = []
|
||||||
|
# IDs of annotations currently selected on the canvas (multi-select)
|
||||||
|
self.selected_annotation_ids = []
|
||||||
|
|
||||||
self._setup_ui()
|
self._setup_ui()
|
||||||
|
|
||||||
@@ -25,24 +51,692 @@ class AnnotationTab(QWidget):
|
|||||||
"""Setup user interface."""
|
"""Setup user interface."""
|
||||||
layout = QVBoxLayout()
|
layout = QVBoxLayout()
|
||||||
|
|
||||||
group = QGroupBox("Annotation Tool (Future Feature)")
|
# Main horizontal splitter to divide left (image) and right (controls)
|
||||||
group_layout = QVBoxLayout()
|
self.main_splitter = QSplitter(Qt.Horizontal)
|
||||||
label = QLabel(
|
self.main_splitter.setHandleWidth(10)
|
||||||
"Annotation functionality will be implemented in future version.\n\n"
|
|
||||||
"Planned Features:\n"
|
|
||||||
"- Image browser\n"
|
|
||||||
"- Drawing tools for bounding boxes\n"
|
|
||||||
"- Class label assignment\n"
|
|
||||||
"- Export annotations to YOLO format\n"
|
|
||||||
"- Annotation verification"
|
|
||||||
)
|
|
||||||
group_layout.addWidget(label)
|
|
||||||
group.setLayout(group_layout)
|
|
||||||
|
|
||||||
layout.addWidget(group)
|
# { Left-most pane: annotated images list
|
||||||
layout.addStretch()
|
annotated_group = QGroupBox("Annotated Images")
|
||||||
|
annotated_layout = QVBoxLayout()
|
||||||
|
|
||||||
|
filter_row = QHBoxLayout()
|
||||||
|
filter_row.addWidget(QLabel("Filter:"))
|
||||||
|
self.annotated_filter_edit = QLineEdit()
|
||||||
|
self.annotated_filter_edit.setPlaceholderText("Type to filter by image name…")
|
||||||
|
self.annotated_filter_edit.textChanged.connect(self._refresh_annotated_images_list)
|
||||||
|
filter_row.addWidget(self.annotated_filter_edit, 1)
|
||||||
|
annotated_layout.addLayout(filter_row)
|
||||||
|
|
||||||
|
self.annotated_images_table = QTableWidget(0, 2)
|
||||||
|
self.annotated_images_table.setHorizontalHeaderLabels(["Image", "Annotations"])
|
||||||
|
self.annotated_images_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
|
||||||
|
self.annotated_images_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeToContents)
|
||||||
|
self.annotated_images_table.setSelectionBehavior(QAbstractItemView.SelectRows)
|
||||||
|
self.annotated_images_table.setSelectionMode(QAbstractItemView.SingleSelection)
|
||||||
|
self.annotated_images_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
|
||||||
|
self.annotated_images_table.setSortingEnabled(True)
|
||||||
|
self.annotated_images_table.itemSelectionChanged.connect(self._on_annotated_image_selected)
|
||||||
|
annotated_layout.addWidget(self.annotated_images_table, 1)
|
||||||
|
|
||||||
|
annotated_group.setLayout(annotated_layout)
|
||||||
|
# }
|
||||||
|
|
||||||
|
# { Left splitter for image display and zoom info
|
||||||
|
self.left_splitter = QSplitter(Qt.Vertical)
|
||||||
|
self.left_splitter.setHandleWidth(10)
|
||||||
|
|
||||||
|
# Annotation canvas section
|
||||||
|
canvas_group = QGroupBox("Annotation Canvas")
|
||||||
|
canvas_layout = QVBoxLayout()
|
||||||
|
|
||||||
|
# Use the AnnotationCanvasWidget
|
||||||
|
self.annotation_canvas = AnnotationCanvasWidget()
|
||||||
|
# Auto-zoom so newly loaded images fill the available canvas viewport.
|
||||||
|
# (Matches the behavior used in ResultsTab.)
|
||||||
|
self.annotation_canvas.set_auto_fit_to_view(True)
|
||||||
|
self.annotation_canvas.zoom_changed.connect(self._on_zoom_changed)
|
||||||
|
self.annotation_canvas.annotation_drawn.connect(self._on_annotation_drawn)
|
||||||
|
# Selection of existing polylines (when tool is not in drawing mode)
|
||||||
|
self.annotation_canvas.annotation_selected.connect(self._on_annotation_selected)
|
||||||
|
canvas_layout.addWidget(self.annotation_canvas)
|
||||||
|
|
||||||
|
canvas_group.setLayout(canvas_layout)
|
||||||
|
self.left_splitter.addWidget(canvas_group)
|
||||||
|
|
||||||
|
# Controls info
|
||||||
|
controls_info = QLabel("Zoom: Mouse wheel or +/- keys | Drawing: Enable pen and drag mouse")
|
||||||
|
controls_info.setStyleSheet("QLabel { color: #888; font-style: italic; }")
|
||||||
|
self.left_splitter.addWidget(controls_info)
|
||||||
|
# }
|
||||||
|
|
||||||
|
# { Right splitter for annotation tools and controls
|
||||||
|
self.right_splitter = QSplitter(Qt.Vertical)
|
||||||
|
self.right_splitter.setHandleWidth(10)
|
||||||
|
|
||||||
|
# Annotation tools section
|
||||||
|
self.annotation_tools = AnnotationToolsWidget(self.db_manager)
|
||||||
|
self.annotation_tools.polyline_enabled_changed.connect(self.annotation_canvas.set_polyline_enabled)
|
||||||
|
self.annotation_tools.polyline_pen_color_changed.connect(self.annotation_canvas.set_polyline_pen_color)
|
||||||
|
self.annotation_tools.polyline_pen_width_changed.connect(self.annotation_canvas.set_polyline_pen_width)
|
||||||
|
# Show / hide bounding boxes
|
||||||
|
self.annotation_tools.show_bboxes_changed.connect(self.annotation_canvas.set_show_bboxes)
|
||||||
|
# RDP simplification controls
|
||||||
|
self.annotation_tools.simplify_on_finish_changed.connect(self._on_simplify_on_finish_changed)
|
||||||
|
self.annotation_tools.simplify_epsilon_changed.connect(self._on_simplify_epsilon_changed)
|
||||||
|
# Class selection and class-color changes
|
||||||
|
self.annotation_tools.class_selected.connect(self._on_class_selected)
|
||||||
|
self.annotation_tools.class_color_changed.connect(self._on_class_color_changed)
|
||||||
|
self.annotation_tools.clear_annotations_requested.connect(self._on_clear_annotations)
|
||||||
|
# Delete selected annotation on canvas
|
||||||
|
self.annotation_tools.delete_selected_annotation_requested.connect(self._on_delete_selected_annotation)
|
||||||
|
self.right_splitter.addWidget(self.annotation_tools)
|
||||||
|
|
||||||
|
# Image loading section
|
||||||
|
load_group = QGroupBox("Image Loading")
|
||||||
|
load_layout = QVBoxLayout()
|
||||||
|
|
||||||
|
# Load image button
|
||||||
|
button_layout = QHBoxLayout()
|
||||||
|
self.load_image_btn = QPushButton("Load Image")
|
||||||
|
self.load_image_btn.clicked.connect(self._load_image)
|
||||||
|
button_layout.addWidget(self.load_image_btn)
|
||||||
|
button_layout.addStretch()
|
||||||
|
load_layout.addLayout(button_layout)
|
||||||
|
|
||||||
|
# Image info label
|
||||||
|
self.image_info_label = QLabel("No image loaded")
|
||||||
|
load_layout.addWidget(self.image_info_label)
|
||||||
|
|
||||||
|
load_group.setLayout(load_layout)
|
||||||
|
self.right_splitter.addWidget(load_group)
|
||||||
|
# }
|
||||||
|
|
||||||
|
# Add list + both splitters to the main horizontal splitter
|
||||||
|
self.main_splitter.addWidget(annotated_group)
|
||||||
|
self.main_splitter.addWidget(self.left_splitter)
|
||||||
|
self.main_splitter.addWidget(self.right_splitter)
|
||||||
|
|
||||||
|
# Set initial sizes: list (left), canvas (middle), controls (right)
|
||||||
|
self.main_splitter.setSizes([320, 650, 280])
|
||||||
|
|
||||||
|
layout.addWidget(self.main_splitter)
|
||||||
self.setLayout(layout)
|
self.setLayout(layout)
|
||||||
|
|
||||||
|
# Restore splitter positions from settings
|
||||||
|
self._restore_state()
|
||||||
|
|
||||||
|
# Populate list on startup.
|
||||||
|
self._refresh_annotated_images_list()
|
||||||
|
|
||||||
|
def _load_image(self):
|
||||||
|
"""Load and display an image file."""
|
||||||
|
# Get last opened directory from QSettings
|
||||||
|
settings = QSettings("microscopy_app", "object_detection")
|
||||||
|
last_dir = settings.value("annotation_tab/last_directory", None)
|
||||||
|
|
||||||
|
# Fallback to image repository path or home directory
|
||||||
|
if last_dir and Path(last_dir).exists():
|
||||||
|
start_dir = last_dir
|
||||||
|
else:
|
||||||
|
repo_path = self.config_manager.get_image_repository_path()
|
||||||
|
start_dir = repo_path if repo_path else str(Path.home())
|
||||||
|
|
||||||
|
# Open file dialog
|
||||||
|
file_path, _ = QFileDialog.getOpenFileName(
|
||||||
|
self,
|
||||||
|
"Select Image",
|
||||||
|
start_dir,
|
||||||
|
"Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not file_path:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load image using Image class
|
||||||
|
self.current_image = Image(file_path)
|
||||||
|
self.current_image_path = file_path
|
||||||
|
|
||||||
|
# Store the directory for next time
|
||||||
|
settings.setValue("annotation_tab/last_directory", str(Path(file_path).parent))
|
||||||
|
|
||||||
|
# Get or create image in database
|
||||||
|
repo_root = self.config_manager.get_image_repository_path()
|
||||||
|
relative_path: str
|
||||||
|
try:
|
||||||
|
if repo_root:
|
||||||
|
repo_root_path = Path(repo_root).expanduser().resolve()
|
||||||
|
file_resolved = Path(file_path).expanduser().resolve()
|
||||||
|
if file_resolved.is_relative_to(repo_root_path):
|
||||||
|
relative_path = file_resolved.relative_to(repo_root_path).as_posix()
|
||||||
|
else:
|
||||||
|
# Fallback: store filename only to avoid leaking absolute paths.
|
||||||
|
relative_path = file_resolved.name
|
||||||
|
else:
|
||||||
|
relative_path = str(Path(file_path).name)
|
||||||
|
except Exception:
|
||||||
|
relative_path = str(Path(file_path).name)
|
||||||
|
self.current_image_id = self.db_manager.get_or_create_image(
|
||||||
|
relative_path,
|
||||||
|
Path(file_path).name,
|
||||||
|
self.current_image.width,
|
||||||
|
self.current_image.height,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Display image using the AnnotationCanvasWidget
|
||||||
|
self.annotation_canvas.load_image(self.current_image)
|
||||||
|
|
||||||
|
# Load and display any existing annotations for this image
|
||||||
|
self._load_annotations_for_current_image()
|
||||||
|
|
||||||
|
# Update annotated images list (newly annotated image added/selected).
|
||||||
|
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
|
||||||
|
|
||||||
|
# Update info label
|
||||||
|
self._update_image_info()
|
||||||
|
|
||||||
|
logger.info(f"Loaded image: {file_path} (DB ID: {self.current_image_id})")
|
||||||
|
|
||||||
|
except ImageLoadError as e:
|
||||||
|
logger.error(f"Failed to load image: {e}")
|
||||||
|
QMessageBox.critical(self, "Error Loading Image", f"Failed to load image:\n{str(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error loading image: {e}")
|
||||||
|
QMessageBox.critical(self, "Error", f"Unexpected error:\n{str(e)}")
|
||||||
|
|
||||||
|
def _update_image_info(self):
|
||||||
|
"""Update the image info label with current image details."""
|
||||||
|
if self.current_image is None:
|
||||||
|
self.image_info_label.setText("No image loaded")
|
||||||
|
return
|
||||||
|
|
||||||
|
zoom_percentage = self.annotation_canvas.get_zoom_percentage()
|
||||||
|
info_text = (
|
||||||
|
f"File: {Path(self.current_image_path).name}\n"
|
||||||
|
f"Size: {self.current_image.width}x{self.current_image.height} pixels\n"
|
||||||
|
f"Channels: {self.current_image.channels}\n"
|
||||||
|
f"Data type: {self.current_image.dtype}\n"
|
||||||
|
f"Format: {self.current_image.format.upper()}\n"
|
||||||
|
f"File size: {self.current_image.size_mb:.2f} MB\n"
|
||||||
|
f"Zoom: {zoom_percentage}%"
|
||||||
|
)
|
||||||
|
self.image_info_label.setText(info_text)
|
||||||
|
|
||||||
|
def _on_zoom_changed(self, zoom_scale: float):
|
||||||
|
"""Handle zoom level changes from the annotation canvas."""
|
||||||
|
self._update_image_info()
|
||||||
|
|
||||||
|
def _on_annotation_drawn(self, points: list):
|
||||||
|
"""
|
||||||
|
Handle when an annotation stroke is drawn.
|
||||||
|
|
||||||
|
Saves the new annotation directly to the database and refreshes the
|
||||||
|
on-canvas display of annotations for the current image.
|
||||||
|
"""
|
||||||
|
# Ensure we have an image loaded and in the DB
|
||||||
|
if not self.current_image or not self.current_image_id:
|
||||||
|
logger.warning("Annotation drawn but no image loaded")
|
||||||
|
QMessageBox.warning(
|
||||||
|
self,
|
||||||
|
"No Image",
|
||||||
|
"Please load an image before drawing annotations.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
current_class = self.annotation_tools.get_current_class()
|
||||||
|
|
||||||
|
if not current_class:
|
||||||
|
logger.warning("Annotation drawn but no object class selected")
|
||||||
|
QMessageBox.warning(
|
||||||
|
self,
|
||||||
|
"No Class Selected",
|
||||||
|
"Please select an object class before drawing annotations.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not points:
|
||||||
|
logger.warning("Annotation drawn with no points, ignoring")
|
||||||
|
return
|
||||||
|
|
||||||
|
# points are [(x_norm, y_norm), ...]
|
||||||
|
xs = [p[0] for p in points]
|
||||||
|
ys = [p[1] for p in points]
|
||||||
|
x_min, x_max = min(xs), max(xs)
|
||||||
|
y_min, y_max = min(ys), max(ys)
|
||||||
|
|
||||||
|
# Store segmentation mask in [y_norm, x_norm] format to match DB
|
||||||
|
db_polyline = [[float(y), float(x)] for (x, y) in points]
|
||||||
|
|
||||||
|
try:
|
||||||
|
annotation_id = self.db_manager.add_annotation(
|
||||||
|
image_id=self.current_image_id,
|
||||||
|
class_id=current_class["id"],
|
||||||
|
bbox=(x_min, y_min, x_max, y_max),
|
||||||
|
annotator="manual",
|
||||||
|
segmentation_mask=db_polyline,
|
||||||
|
verified=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Saved annotation (ID: {annotation_id}) for class "
|
||||||
|
f"'{current_class['class_name']}' "
|
||||||
|
f"Bounding box: ({x_min:.3f}, {y_min:.3f}) to ({x_max:.3f}, {y_max:.3f})\n"
|
||||||
|
f"with {len(points)} polyline points"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reload annotations from DB and redraw (respecting current class filter)
|
||||||
|
self._load_annotations_for_current_image()
|
||||||
|
|
||||||
|
# Update list counts.
|
||||||
|
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to save annotation: {e}")
|
||||||
|
QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}")
|
||||||
|
|
||||||
|
def _on_annotation_selected(self, annotation_ids):
|
||||||
|
"""
|
||||||
|
Handle selection of existing annotations on the canvas.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
annotation_ids: List of selected annotation IDs, or None/empty if cleared.
|
||||||
|
"""
|
||||||
|
if not annotation_ids:
|
||||||
|
self.selected_annotation_ids = []
|
||||||
|
self.annotation_tools.set_has_selected_annotation(False)
|
||||||
|
logger.debug("Annotation selection cleared on canvas")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Normalize to a unique, sorted list of integer IDs
|
||||||
|
ids = sorted({int(aid) for aid in annotation_ids if isinstance(aid, int)})
|
||||||
|
self.selected_annotation_ids = ids
|
||||||
|
self.annotation_tools.set_has_selected_annotation(bool(ids))
|
||||||
|
logger.debug(f"Annotations selected on canvas: IDs={ids}")
|
||||||
|
|
||||||
|
def _on_simplify_on_finish_changed(self, enabled: bool):
|
||||||
|
"""Update canvas simplify-on-finish flag from tools widget."""
|
||||||
|
self.annotation_canvas.simplify_on_finish = enabled
|
||||||
|
logger.debug(f"Annotation simplification on finish set to {enabled}")
|
||||||
|
|
||||||
|
def _on_simplify_epsilon_changed(self, epsilon: float):
|
||||||
|
"""Update canvas RDP epsilon from tools widget."""
|
||||||
|
self.annotation_canvas.simplify_epsilon = float(epsilon)
|
||||||
|
logger.debug(f"Annotation simplification epsilon set to {epsilon}")
|
||||||
|
|
||||||
|
def _on_class_color_changed(self):
|
||||||
|
"""
|
||||||
|
Handle changes to the selected object's class color.
|
||||||
|
|
||||||
|
When the user updates a class color in the tools widget, reload the
|
||||||
|
annotations for the current image so that all polylines are redrawn
|
||||||
|
using the updated per-class colors.
|
||||||
|
"""
|
||||||
|
if not self.current_image_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug(f"Class color changed; reloading annotations for image ID {self.current_image_id}")
|
||||||
|
self._load_annotations_for_current_image()
|
||||||
|
|
||||||
|
def _on_class_selected(self, class_data):
|
||||||
|
"""
|
||||||
|
Handle when an object class is selected or cleared.
|
||||||
|
|
||||||
|
When a specific class is selected, only annotations of that class are drawn.
|
||||||
|
When the selection is cleared ("-- Select Class --"), all annotations are shown.
|
||||||
|
"""
|
||||||
|
if class_data:
|
||||||
|
logger.debug(f"Object class selected: {class_data['class_name']}")
|
||||||
|
else:
|
||||||
|
logger.debug('No class selected ("-- Select Class --"), showing all annotations')
|
||||||
|
|
||||||
|
# Changing the class filter invalidates any previous selection
|
||||||
|
self.selected_annotation_ids = []
|
||||||
|
self.annotation_tools.set_has_selected_annotation(False)
|
||||||
|
|
||||||
|
# Whenever the selection changes, update which annotations are visible
|
||||||
|
self._redraw_annotations_for_current_filter()
|
||||||
|
|
||||||
|
def _on_clear_annotations(self):
|
||||||
|
"""Handle clearing all annotations."""
|
||||||
|
self.annotation_canvas.clear_annotations()
|
||||||
|
# Clear in-memory state and selection, but keep DB entries unchanged
|
||||||
|
self.current_annotations = []
|
||||||
|
self.selected_annotation_ids = []
|
||||||
|
self.annotation_tools.set_has_selected_annotation(False)
|
||||||
|
logger.info("Cleared all annotations")
|
||||||
|
|
||||||
|
def _on_delete_selected_annotation(self):
|
||||||
|
"""Handle deleting the currently selected annotation(s) (if any)."""
|
||||||
|
if not self.selected_annotation_ids:
|
||||||
|
QMessageBox.information(
|
||||||
|
self,
|
||||||
|
"No Selection",
|
||||||
|
"No annotation is currently selected.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
count = len(self.selected_annotation_ids)
|
||||||
|
if count == 1:
|
||||||
|
question = "Are you sure you want to delete the selected annotation?"
|
||||||
|
title = "Delete Annotation"
|
||||||
|
else:
|
||||||
|
question = f"Are you sure you want to delete the {count} selected annotations?"
|
||||||
|
title = "Delete Annotations"
|
||||||
|
|
||||||
|
reply = QMessageBox.question(
|
||||||
|
self,
|
||||||
|
title,
|
||||||
|
question,
|
||||||
|
QMessageBox.Yes | QMessageBox.No,
|
||||||
|
QMessageBox.No,
|
||||||
|
)
|
||||||
|
if reply != QMessageBox.Yes:
|
||||||
|
return
|
||||||
|
|
||||||
|
failed_ids = []
|
||||||
|
try:
|
||||||
|
for ann_id in self.selected_annotation_ids:
|
||||||
|
try:
|
||||||
|
deleted = self.db_manager.delete_annotation(ann_id)
|
||||||
|
if not deleted:
|
||||||
|
failed_ids.append(ann_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to delete annotation ID {ann_id}: {e}")
|
||||||
|
failed_ids.append(ann_id)
|
||||||
|
|
||||||
|
if failed_ids:
|
||||||
|
QMessageBox.warning(
|
||||||
|
self,
|
||||||
|
"Partial Failure",
|
||||||
|
"Some annotations could not be deleted:\n" + ", ".join(str(a) for a in failed_ids),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"Deleted {count} annotation(s): " + ", ".join(str(a) for a in self.selected_annotation_ids)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clear selection and reload annotations for the current image from DB
|
||||||
|
self.selected_annotation_ids = []
|
||||||
|
self.annotation_tools.set_has_selected_annotation(False)
|
||||||
|
self._load_annotations_for_current_image()
|
||||||
|
|
||||||
|
# Update list counts.
|
||||||
|
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to delete annotations: {e}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Error",
|
||||||
|
f"Failed to delete annotations:\n{str(e)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_annotations_for_current_image(self):
|
||||||
|
"""
|
||||||
|
Load all annotations for the current image from the database and
|
||||||
|
redraw them on the canvas, honoring the currently selected class
|
||||||
|
filter (if any).
|
||||||
|
"""
|
||||||
|
if not self.current_image_id:
|
||||||
|
self.current_annotations = []
|
||||||
|
self.annotation_canvas.clear_annotations()
|
||||||
|
self.selected_annotation_ids = []
|
||||||
|
self.annotation_tools.set_has_selected_annotation(False)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.current_annotations = self.db_manager.get_annotations_for_image(self.current_image_id)
|
||||||
|
# New annotations loaded; reset any selection
|
||||||
|
self.selected_annotation_ids = []
|
||||||
|
self.annotation_tools.set_has_selected_annotation(False)
|
||||||
|
self._redraw_annotations_for_current_filter()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load annotations for image {self.current_image_id}: {e}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Error",
|
||||||
|
f"Failed to load annotations for this image:\n{str(e)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _redraw_annotations_for_current_filter(self):
|
||||||
|
"""
|
||||||
|
Redraw annotations for the current image, optionally filtered by the
|
||||||
|
currently selected object class.
|
||||||
|
"""
|
||||||
|
# Clear current on-canvas annotations but keep the image
|
||||||
|
self.annotation_canvas.clear_annotations()
|
||||||
|
|
||||||
|
if not self.current_annotations:
|
||||||
|
return
|
||||||
|
|
||||||
|
current_class = self.annotation_tools.get_current_class()
|
||||||
|
selected_class_id = current_class["id"] if current_class else None
|
||||||
|
|
||||||
|
drawn_count = 0
|
||||||
|
for ann in self.current_annotations:
|
||||||
|
# Filter by class if one is selected
|
||||||
|
if selected_class_id is not None and ann.get("class_id") != selected_class_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ann.get("segmentation_mask"):
|
||||||
|
polyline = ann["segmentation_mask"]
|
||||||
|
color = ann.get("class_color", "#FF0000")
|
||||||
|
|
||||||
|
self.annotation_canvas.draw_saved_polyline(
|
||||||
|
polyline,
|
||||||
|
color,
|
||||||
|
width=3,
|
||||||
|
annotation_id=ann["id"],
|
||||||
|
)
|
||||||
|
self.annotation_canvas.draw_saved_bbox(
|
||||||
|
[ann["x_min"], ann["y_min"], ann["x_max"], ann["y_max"]],
|
||||||
|
color,
|
||||||
|
width=3,
|
||||||
|
)
|
||||||
|
drawn_count += 1
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Displayed {drawn_count} annotation(s) for current image with "
|
||||||
|
f"{'no class filter' if selected_class_id is None else f'class_id={selected_class_id}'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _restore_state(self):
|
||||||
|
"""Restore splitter positions from settings."""
|
||||||
|
settings = QSettings("microscopy_app", "object_detection")
|
||||||
|
|
||||||
|
# Restore main splitter state
|
||||||
|
main_state = settings.value("annotation_tab/main_splitter_state")
|
||||||
|
if main_state:
|
||||||
|
self.main_splitter.restoreState(main_state)
|
||||||
|
logger.debug("Restored main splitter state")
|
||||||
|
|
||||||
|
# Restore left splitter state
|
||||||
|
left_state = settings.value("annotation_tab/left_splitter_state")
|
||||||
|
if left_state:
|
||||||
|
self.left_splitter.restoreState(left_state)
|
||||||
|
logger.debug("Restored left splitter state")
|
||||||
|
|
||||||
|
# Restore right splitter state
|
||||||
|
right_state = settings.value("annotation_tab/right_splitter_state")
|
||||||
|
if right_state:
|
||||||
|
self.right_splitter.restoreState(right_state)
|
||||||
|
logger.debug("Restored right splitter state")
|
||||||
|
|
||||||
|
def save_state(self):
|
||||||
|
"""Save splitter positions to settings."""
|
||||||
|
settings = QSettings("microscopy_app", "object_detection")
|
||||||
|
|
||||||
|
# Save main splitter state
|
||||||
|
settings.setValue("annotation_tab/main_splitter_state", self.main_splitter.saveState())
|
||||||
|
|
||||||
|
# Save left splitter state
|
||||||
|
settings.setValue("annotation_tab/left_splitter_state", self.left_splitter.saveState())
|
||||||
|
|
||||||
|
# Save right splitter state
|
||||||
|
settings.setValue("annotation_tab/right_splitter_state", self.right_splitter.saveState())
|
||||||
|
|
||||||
|
logger.debug("Saved annotation tab splitter states")
|
||||||
|
|
||||||
def refresh(self):
|
def refresh(self):
|
||||||
"""Refresh the tab."""
|
"""Refresh the tab."""
|
||||||
pass
|
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
|
||||||
|
|
||||||
|
# ==================== Annotated images list ====================
|
||||||
|
|
||||||
|
def _refresh_annotated_images_list(self, select_image_id: int | None = None) -> None:
|
||||||
|
"""Reload annotated-images list from the database."""
|
||||||
|
if not hasattr(self, "annotated_images_table"):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Preserve selection if possible
|
||||||
|
desired_id = select_image_id if select_image_id is not None else self.current_image_id
|
||||||
|
|
||||||
|
name_filter = ""
|
||||||
|
if hasattr(self, "annotated_filter_edit"):
|
||||||
|
name_filter = self.annotated_filter_edit.text().strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
rows = self.db_manager.get_annotated_images_summary(name_filter=name_filter)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Failed to load annotated images summary: {exc}")
|
||||||
|
rows = []
|
||||||
|
|
||||||
|
sorting_enabled = self.annotated_images_table.isSortingEnabled()
|
||||||
|
self.annotated_images_table.setSortingEnabled(False)
|
||||||
|
self.annotated_images_table.blockSignals(True)
|
||||||
|
try:
|
||||||
|
self.annotated_images_table.setRowCount(len(rows))
|
||||||
|
for r, entry in enumerate(rows):
|
||||||
|
image_name = str(entry.get("filename") or "")
|
||||||
|
count = int(entry.get("annotation_count") or 0)
|
||||||
|
rel_path = str(entry.get("relative_path") or "")
|
||||||
|
|
||||||
|
name_item = QTableWidgetItem(image_name)
|
||||||
|
# Tooltip shows full path of the image (best-effort: repository_root + relative_path)
|
||||||
|
full_path = rel_path
|
||||||
|
repo_root = self.config_manager.get_image_repository_path()
|
||||||
|
if repo_root and rel_path and not Path(rel_path).is_absolute():
|
||||||
|
try:
|
||||||
|
full_path = str((Path(repo_root) / rel_path).resolve())
|
||||||
|
except Exception:
|
||||||
|
full_path = str(Path(repo_root) / rel_path)
|
||||||
|
name_item.setToolTip(full_path)
|
||||||
|
name_item.setData(Qt.UserRole, int(entry.get("id")))
|
||||||
|
name_item.setData(Qt.UserRole + 1, rel_path)
|
||||||
|
|
||||||
|
count_item = QTableWidgetItem()
|
||||||
|
# Use EditRole to ensure numeric sorting.
|
||||||
|
count_item.setData(Qt.EditRole, count)
|
||||||
|
count_item.setData(Qt.UserRole, int(entry.get("id")))
|
||||||
|
count_item.setData(Qt.UserRole + 1, rel_path)
|
||||||
|
|
||||||
|
self.annotated_images_table.setItem(r, 0, name_item)
|
||||||
|
self.annotated_images_table.setItem(r, 1, count_item)
|
||||||
|
|
||||||
|
# Re-select desired row
|
||||||
|
if desired_id is not None:
|
||||||
|
for r in range(self.annotated_images_table.rowCount()):
|
||||||
|
item = self.annotated_images_table.item(r, 0)
|
||||||
|
if item and item.data(Qt.UserRole) == desired_id:
|
||||||
|
self.annotated_images_table.selectRow(r)
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
self.annotated_images_table.blockSignals(False)
|
||||||
|
self.annotated_images_table.setSortingEnabled(sorting_enabled)
|
||||||
|
|
||||||
|
def _on_annotated_image_selected(self) -> None:
|
||||||
|
"""When user clicks an item in the list, load that image in the annotation canvas."""
|
||||||
|
selected = self.annotated_images_table.selectedItems()
|
||||||
|
if not selected:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Row selection -> take the first column item
|
||||||
|
row = self.annotated_images_table.currentRow()
|
||||||
|
item = self.annotated_images_table.item(row, 0)
|
||||||
|
if not item:
|
||||||
|
return
|
||||||
|
|
||||||
|
image_id = item.data(Qt.UserRole)
|
||||||
|
rel_path = item.data(Qt.UserRole + 1) or ""
|
||||||
|
if not image_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
image_path = self._resolve_image_path_for_relative_path(rel_path)
|
||||||
|
if not image_path:
|
||||||
|
QMessageBox.warning(
|
||||||
|
self,
|
||||||
|
"Image Not Found",
|
||||||
|
"Unable to locate image on disk for:\n"
|
||||||
|
f"{rel_path}\n\n"
|
||||||
|
"Tip: set Settings → Image repository path to the folder containing your images.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.current_image = Image(image_path)
|
||||||
|
self.current_image_path = image_path
|
||||||
|
self.current_image_id = int(image_id)
|
||||||
|
self.annotation_canvas.load_image(self.current_image)
|
||||||
|
self._load_annotations_for_current_image()
|
||||||
|
self._update_image_info()
|
||||||
|
except ImageLoadError as exc:
|
||||||
|
logger.error(f"Failed to load image '{image_path}': {exc}")
|
||||||
|
QMessageBox.critical(self, "Error Loading Image", f"Failed to load image:\n{exc}")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Unexpected error loading image '{image_path}': {exc}")
|
||||||
|
QMessageBox.critical(self, "Error", f"Unexpected error:\n{exc}")
|
||||||
|
|
||||||
|
def _resolve_image_path_for_relative_path(self, relative_path: str) -> str | None:
|
||||||
|
"""Best-effort conversion from a DB relative_path to an on-disk file path."""
|
||||||
|
|
||||||
|
rel = (relative_path or "").strip()
|
||||||
|
if not rel:
|
||||||
|
return None
|
||||||
|
|
||||||
|
candidates: list[Path] = []
|
||||||
|
|
||||||
|
# 1) Repository root + relative
|
||||||
|
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
|
||||||
|
if repo_root:
|
||||||
|
candidates.append(Path(repo_root) / rel)
|
||||||
|
|
||||||
|
# 2) If the DB path is absolute, try it directly.
|
||||||
|
candidates.append(Path(rel))
|
||||||
|
|
||||||
|
# 3) Try the directory of the currently loaded image (helps when DB stores only filenames)
|
||||||
|
if self.current_image_path:
|
||||||
|
try:
|
||||||
|
candidates.append(Path(self.current_image_path).expanduser().resolve().parent / Path(rel).name)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 4) Try the last directory used by the annotation file picker
|
||||||
|
try:
|
||||||
|
settings = QSettings("microscopy_app", "object_detection")
|
||||||
|
last_dir = settings.value("annotation_tab/last_directory", None)
|
||||||
|
if last_dir:
|
||||||
|
candidates.append(Path(str(last_dir)) / Path(rel).name)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
for p in candidates:
|
||||||
|
try:
|
||||||
|
expanded = p.expanduser()
|
||||||
|
if expanded.exists() and expanded.is_file():
|
||||||
|
return str(expanded.resolve())
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 5) Fallback: search by filename within repository root.
|
||||||
|
filename = Path(rel).name
|
||||||
|
if repo_root and filename:
|
||||||
|
root = Path(repo_root).expanduser()
|
||||||
|
try:
|
||||||
|
if root.exists():
|
||||||
|
for match in root.rglob(filename):
|
||||||
|
if match.is_file():
|
||||||
|
return str(match.resolve())
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug(f"Search for {filename} under {root} failed: {exc}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|||||||
@@ -20,12 +20,14 @@ from PySide6.QtWidgets import (
|
|||||||
)
|
)
|
||||||
from PySide6.QtCore import Qt, QThread, Signal
|
from PySide6.QtCore import Qt, QThread, Signal
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from src.database.db_manager import DatabaseManager
|
from src.database.db_manager import DatabaseManager
|
||||||
from src.utils.config_manager import ConfigManager
|
from src.utils.config_manager import ConfigManager
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
from src.utils.file_utils import get_image_files
|
from src.utils.file_utils import get_image_files
|
||||||
from src.model.inference import InferenceEngine
|
from src.model.inference import InferenceEngine
|
||||||
|
from src.utils.image import Image
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -147,30 +149,66 @@ class DetectionTab(QWidget):
|
|||||||
self.model_combo.currentIndexChanged.connect(self._on_model_changed)
|
self.model_combo.currentIndexChanged.connect(self._on_model_changed)
|
||||||
|
|
||||||
def _load_models(self):
|
def _load_models(self):
|
||||||
"""Load available models from database."""
|
"""Load available models from database and local storage."""
|
||||||
try:
|
try:
|
||||||
models = self.db_manager.get_models()
|
|
||||||
self.model_combo.clear()
|
self.model_combo.clear()
|
||||||
|
models = self.db_manager.get_models()
|
||||||
|
has_models = False
|
||||||
|
|
||||||
if not models:
|
known_paths = set()
|
||||||
self.model_combo.addItem("No models available", None)
|
|
||||||
self._set_buttons_enabled(False)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Add base model option
|
# Add base model option first (always available)
|
||||||
base_model = self.config_manager.get(
|
base_model = self.config_manager.get(
|
||||||
"models.default_base_model", "yolov8s.pt"
|
"models.default_base_model", "yolov8s-seg.pt"
|
||||||
)
|
|
||||||
self.model_combo.addItem(
|
|
||||||
f"Base Model ({base_model})", {"id": 0, "path": base_model}
|
|
||||||
)
|
)
|
||||||
|
if base_model:
|
||||||
|
base_data = {
|
||||||
|
"id": 0,
|
||||||
|
"path": base_model,
|
||||||
|
"model_name": Path(base_model).stem or "Base Model",
|
||||||
|
"model_version": "pretrained",
|
||||||
|
"base_model": base_model,
|
||||||
|
"source": "base",
|
||||||
|
}
|
||||||
|
self.model_combo.addItem(f"Base Model ({base_model})", base_data)
|
||||||
|
known_paths.add(self._normalize_model_path(base_model))
|
||||||
|
has_models = True
|
||||||
|
|
||||||
# Add trained models
|
# Add trained models from database
|
||||||
for model in models:
|
for model in models:
|
||||||
display_name = f"{model['model_name']} v{model['model_version']}"
|
display_name = f"{model['model_name']} v{model['model_version']}"
|
||||||
self.model_combo.addItem(display_name, model)
|
model_data = {**model, "path": model.get("model_path")}
|
||||||
|
normalized = self._normalize_model_path(model_data.get("path"))
|
||||||
|
if normalized:
|
||||||
|
known_paths.add(normalized)
|
||||||
|
self.model_combo.addItem(display_name, model_data)
|
||||||
|
has_models = True
|
||||||
|
|
||||||
self._set_buttons_enabled(True)
|
# Discover local model files not yet in the database
|
||||||
|
local_models = self._discover_local_models()
|
||||||
|
for model_path in local_models:
|
||||||
|
normalized = self._normalize_model_path(model_path)
|
||||||
|
if normalized in known_paths:
|
||||||
|
continue
|
||||||
|
|
||||||
|
display_name = f"Local Model ({Path(model_path).stem})"
|
||||||
|
model_data = {
|
||||||
|
"id": None,
|
||||||
|
"path": str(model_path),
|
||||||
|
"model_name": Path(model_path).stem,
|
||||||
|
"model_version": "local",
|
||||||
|
"base_model": Path(model_path).stem,
|
||||||
|
"source": "local",
|
||||||
|
}
|
||||||
|
self.model_combo.addItem(display_name, model_data)
|
||||||
|
known_paths.add(normalized)
|
||||||
|
has_models = True
|
||||||
|
|
||||||
|
if not has_models:
|
||||||
|
self.model_combo.addItem("No models available", None)
|
||||||
|
self._set_buttons_enabled(False)
|
||||||
|
else:
|
||||||
|
self._set_buttons_enabled(True)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error loading models: {e}")
|
logger.error(f"Error loading models: {e}")
|
||||||
@@ -199,7 +237,7 @@ class DetectionTab(QWidget):
|
|||||||
self,
|
self,
|
||||||
"Select Image",
|
"Select Image",
|
||||||
start_dir,
|
start_dir,
|
||||||
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
|
"Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not file_path:
|
if not file_path:
|
||||||
@@ -249,25 +287,39 @@ class DetectionTab(QWidget):
|
|||||||
QMessageBox.warning(self, "No Model", "Please select a model first.")
|
QMessageBox.warning(self, "No Model", "Please select a model first.")
|
||||||
return
|
return
|
||||||
|
|
||||||
model_path = model_data["path"]
|
model_path = model_data.get("path")
|
||||||
model_id = model_data["id"]
|
if not model_path:
|
||||||
|
QMessageBox.warning(
|
||||||
|
self, "Invalid Model", "Selected model is missing a file path."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# Ensure we have a valid model ID (create entry for base model if needed)
|
if not Path(model_path).exists():
|
||||||
if model_id == 0:
|
QMessageBox.critical(
|
||||||
# Create database entry for base model
|
self,
|
||||||
base_model = self.config_manager.get(
|
"Model Not Found",
|
||||||
"models.default_base_model", "yolov8s.pt"
|
f"The selected model file could not be found:\n{model_path}",
|
||||||
)
|
|
||||||
model_id = self.db_manager.add_model(
|
|
||||||
model_name="Base Model",
|
|
||||||
model_version="pretrained",
|
|
||||||
model_path=base_model,
|
|
||||||
base_model=base_model,
|
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
model_id = model_data.get("id")
|
||||||
|
|
||||||
|
# Ensure we have a database entry for the selected model
|
||||||
|
if model_id in (None, 0):
|
||||||
|
model_id = self._ensure_model_record(model_data)
|
||||||
|
if not model_id:
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Model Registration Failed",
|
||||||
|
"Unable to register the selected model in the database.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
normalized_model_path = self._normalize_model_path(model_path) or model_path
|
||||||
|
|
||||||
# Create inference engine
|
# Create inference engine
|
||||||
self.inference_engine = InferenceEngine(
|
self.inference_engine = InferenceEngine(
|
||||||
model_path, self.db_manager, model_id
|
normalized_model_path, self.db_manager, model_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get confidence threshold
|
# Get confidence threshold
|
||||||
@@ -338,6 +390,76 @@ class DetectionTab(QWidget):
|
|||||||
self.batch_btn.setEnabled(enabled)
|
self.batch_btn.setEnabled(enabled)
|
||||||
self.model_combo.setEnabled(enabled)
|
self.model_combo.setEnabled(enabled)
|
||||||
|
|
||||||
|
def _discover_local_models(self) -> list:
|
||||||
|
"""Scan the models directory for standalone .pt files."""
|
||||||
|
models_dir = self.config_manager.get_models_directory()
|
||||||
|
if not models_dir:
|
||||||
|
return []
|
||||||
|
|
||||||
|
models_path = Path(models_dir)
|
||||||
|
if not models_path.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
return sorted(
|
||||||
|
[p for p in models_path.rglob("*.pt") if p.is_file()],
|
||||||
|
key=lambda p: str(p).lower(),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error discovering local models: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _normalize_model_path(self, path_value) -> str:
|
||||||
|
"""Return a normalized absolute path string for comparison."""
|
||||||
|
if not path_value:
|
||||||
|
return ""
|
||||||
|
try:
|
||||||
|
return str(Path(path_value).resolve())
|
||||||
|
except Exception:
|
||||||
|
return str(path_value)
|
||||||
|
|
||||||
|
def _ensure_model_record(self, model_data: dict) -> Optional[int]:
|
||||||
|
"""Ensure a database record exists for the selected model."""
|
||||||
|
model_path = model_data.get("path")
|
||||||
|
if not model_path:
|
||||||
|
return None
|
||||||
|
|
||||||
|
normalized_target = self._normalize_model_path(model_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
existing_models = self.db_manager.get_models()
|
||||||
|
for model in existing_models:
|
||||||
|
existing_path = model.get("model_path")
|
||||||
|
if not existing_path:
|
||||||
|
continue
|
||||||
|
normalized_existing = self._normalize_model_path(existing_path)
|
||||||
|
if (
|
||||||
|
normalized_existing == normalized_target
|
||||||
|
or existing_path == model_path
|
||||||
|
):
|
||||||
|
return model["id"]
|
||||||
|
|
||||||
|
model_name = (
|
||||||
|
model_data.get("model_name") or Path(model_path).stem or "Custom Model"
|
||||||
|
)
|
||||||
|
model_version = (
|
||||||
|
model_data.get("model_version") or model_data.get("source") or "local"
|
||||||
|
)
|
||||||
|
base_model = model_data.get(
|
||||||
|
"base_model",
|
||||||
|
self.config_manager.get("models.default_base_model", "yolov8s-seg.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.db_manager.add_model(
|
||||||
|
model_name=model_name,
|
||||||
|
model_version=model_version,
|
||||||
|
model_path=normalized_target,
|
||||||
|
base_model=base_model,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to ensure model record for {model_path}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
def refresh(self):
|
def refresh(self):
|
||||||
"""Refresh the tab."""
|
"""Refresh the tab."""
|
||||||
self._load_models()
|
self._load_models()
|
||||||
|
|||||||
@@ -1,46 +1,699 @@
|
|||||||
"""
|
"""
|
||||||
Results tab for the microscopy object detection application.
|
Results tab for browsing stored detections and visualizing overlays.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from PySide6.QtWidgets import (
|
||||||
|
QWidget,
|
||||||
|
QVBoxLayout,
|
||||||
|
QHBoxLayout,
|
||||||
|
QLabel,
|
||||||
|
QGroupBox,
|
||||||
|
QPushButton,
|
||||||
|
QSplitter,
|
||||||
|
QTableWidget,
|
||||||
|
QTableWidgetItem,
|
||||||
|
QHeaderView,
|
||||||
|
QAbstractItemView,
|
||||||
|
QMessageBox,
|
||||||
|
QCheckBox,
|
||||||
|
)
|
||||||
|
from PySide6.QtCore import Qt
|
||||||
|
|
||||||
from src.database.db_manager import DatabaseManager
|
from src.database.db_manager import DatabaseManager
|
||||||
from src.utils.config_manager import ConfigManager
|
from src.utils.config_manager import ConfigManager
|
||||||
|
from src.utils.logger import get_logger
|
||||||
|
from src.utils.image import Image, ImageLoadError
|
||||||
|
from src.gui.widgets import AnnotationCanvasWidget
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ResultsTab(QWidget):
|
class ResultsTab(QWidget):
|
||||||
"""Results tab placeholder."""
|
"""Results tab showing detection history and preview overlays."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
|
||||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
|
||||||
):
|
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
self.db_manager = db_manager
|
self.db_manager = db_manager
|
||||||
self.config_manager = config_manager
|
self.config_manager = config_manager
|
||||||
|
|
||||||
|
self.detection_summary: List[Dict] = []
|
||||||
|
self.current_selection: Optional[Dict] = None
|
||||||
|
self.current_image: Optional[Image] = None
|
||||||
|
self.current_detections: List[Dict] = []
|
||||||
|
self._image_path_cache: Dict[str, str] = {}
|
||||||
|
|
||||||
self._setup_ui()
|
self._setup_ui()
|
||||||
|
self.refresh()
|
||||||
|
|
||||||
def _setup_ui(self):
|
def _setup_ui(self):
|
||||||
"""Setup user interface."""
|
"""Setup user interface."""
|
||||||
layout = QVBoxLayout()
|
layout = QVBoxLayout()
|
||||||
|
|
||||||
group = QGroupBox("Results")
|
# Splitter for list + preview
|
||||||
group_layout = QVBoxLayout()
|
splitter = QSplitter(Qt.Horizontal)
|
||||||
label = QLabel(
|
|
||||||
"Results viewer will be implemented here.\n\n"
|
|
||||||
"Features:\n"
|
|
||||||
"- Detection history browser\n"
|
|
||||||
"- Advanced filtering\n"
|
|
||||||
"- Statistics dashboard\n"
|
|
||||||
"- Export functionality"
|
|
||||||
)
|
|
||||||
group_layout.addWidget(label)
|
|
||||||
group.setLayout(group_layout)
|
|
||||||
|
|
||||||
layout.addWidget(group)
|
# Left pane: detection list
|
||||||
layout.addStretch()
|
left_container = QWidget()
|
||||||
|
left_layout = QVBoxLayout()
|
||||||
|
left_layout.setContentsMargins(0, 0, 0, 0)
|
||||||
|
|
||||||
|
controls_layout = QHBoxLayout()
|
||||||
|
self.refresh_btn = QPushButton("Refresh")
|
||||||
|
self.refresh_btn.clicked.connect(self.refresh)
|
||||||
|
controls_layout.addWidget(self.refresh_btn)
|
||||||
|
|
||||||
|
self.delete_all_btn = QPushButton("Delete All Detections")
|
||||||
|
self.delete_all_btn.setToolTip(
|
||||||
|
"Permanently delete ALL detections from the database.\n" "This cannot be undone."
|
||||||
|
)
|
||||||
|
self.delete_all_btn.clicked.connect(self._delete_all_detections)
|
||||||
|
controls_layout.addWidget(self.delete_all_btn)
|
||||||
|
|
||||||
|
self.export_labels_btn = QPushButton("Export Labels")
|
||||||
|
self.export_labels_btn.setToolTip(
|
||||||
|
"Export YOLO .txt labels for the selected image/model run.\n"
|
||||||
|
"Output path is inferred from the image path (images/ -> labels/)."
|
||||||
|
)
|
||||||
|
self.export_labels_btn.clicked.connect(self._export_labels_for_current_selection)
|
||||||
|
controls_layout.addWidget(self.export_labels_btn)
|
||||||
|
|
||||||
|
controls_layout.addStretch()
|
||||||
|
left_layout.addLayout(controls_layout)
|
||||||
|
|
||||||
|
self.results_table = QTableWidget(0, 5)
|
||||||
|
self.results_table.setHorizontalHeaderLabels(["Image", "Model", "Detections", "Classes", "Last Updated"])
|
||||||
|
self.results_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
|
||||||
|
self.results_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.Stretch)
|
||||||
|
self.results_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeToContents)
|
||||||
|
self.results_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.Stretch)
|
||||||
|
self.results_table.horizontalHeader().setSectionResizeMode(4, QHeaderView.ResizeToContents)
|
||||||
|
self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows)
|
||||||
|
self.results_table.setSelectionMode(QAbstractItemView.SingleSelection)
|
||||||
|
self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
|
||||||
|
self.results_table.itemSelectionChanged.connect(self._on_result_selected)
|
||||||
|
|
||||||
|
left_layout.addWidget(self.results_table)
|
||||||
|
left_container.setLayout(left_layout)
|
||||||
|
|
||||||
|
# Right pane: preview canvas and controls
|
||||||
|
right_container = QWidget()
|
||||||
|
right_layout = QVBoxLayout()
|
||||||
|
right_layout.setContentsMargins(0, 0, 0, 0)
|
||||||
|
|
||||||
|
preview_group = QGroupBox("Detection Preview")
|
||||||
|
preview_layout = QVBoxLayout()
|
||||||
|
|
||||||
|
self.preview_canvas = AnnotationCanvasWidget()
|
||||||
|
# Auto-zoom so newly loaded images fill the available preview viewport.
|
||||||
|
self.preview_canvas.set_auto_fit_to_view(True)
|
||||||
|
self.preview_canvas.set_polyline_enabled(False)
|
||||||
|
self.preview_canvas.set_show_bboxes(True)
|
||||||
|
preview_layout.addWidget(self.preview_canvas)
|
||||||
|
|
||||||
|
toggles_layout = QHBoxLayout()
|
||||||
|
self.show_masks_checkbox = QCheckBox("Show Masks")
|
||||||
|
self.show_masks_checkbox.setChecked(False)
|
||||||
|
self.show_masks_checkbox.stateChanged.connect(self._apply_detection_overlays)
|
||||||
|
self.show_bboxes_checkbox = QCheckBox("Show Bounding Boxes")
|
||||||
|
self.show_bboxes_checkbox.setChecked(True)
|
||||||
|
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
|
||||||
|
self.show_confidence_checkbox = QCheckBox("Show Confidence")
|
||||||
|
self.show_confidence_checkbox.setChecked(False)
|
||||||
|
self.show_confidence_checkbox.stateChanged.connect(self._apply_detection_overlays)
|
||||||
|
toggles_layout.addWidget(self.show_masks_checkbox)
|
||||||
|
toggles_layout.addWidget(self.show_bboxes_checkbox)
|
||||||
|
toggles_layout.addWidget(self.show_confidence_checkbox)
|
||||||
|
toggles_layout.addStretch()
|
||||||
|
preview_layout.addLayout(toggles_layout)
|
||||||
|
|
||||||
|
self.summary_label = QLabel("Select a detection result to preview.")
|
||||||
|
self.summary_label.setWordWrap(True)
|
||||||
|
preview_layout.addWidget(self.summary_label)
|
||||||
|
|
||||||
|
preview_group.setLayout(preview_layout)
|
||||||
|
right_layout.addWidget(preview_group)
|
||||||
|
right_container.setLayout(right_layout)
|
||||||
|
|
||||||
|
splitter.addWidget(left_container)
|
||||||
|
splitter.addWidget(right_container)
|
||||||
|
splitter.setStretchFactor(0, 1)
|
||||||
|
splitter.setStretchFactor(1, 2)
|
||||||
|
|
||||||
|
layout.addWidget(splitter)
|
||||||
self.setLayout(layout)
|
self.setLayout(layout)
|
||||||
|
|
||||||
|
def _delete_all_detections(self):
|
||||||
|
"""Delete all detections from the database after user confirmation."""
|
||||||
|
confirm = QMessageBox.warning(
|
||||||
|
self,
|
||||||
|
"Delete All Detections",
|
||||||
|
"This will permanently delete ALL detections from the database.\n\n"
|
||||||
|
"This action cannot be undone.\n\n"
|
||||||
|
"Do you want to continue?",
|
||||||
|
QMessageBox.Yes | QMessageBox.No,
|
||||||
|
QMessageBox.No,
|
||||||
|
)
|
||||||
|
|
||||||
|
if confirm != QMessageBox.Yes:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
deleted = self.db_manager.delete_all_detections()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Failed to delete all detections: {exc}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Error",
|
||||||
|
f"Failed to delete detections:\n{exc}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
QMessageBox.information(
|
||||||
|
self,
|
||||||
|
"Delete All Detections",
|
||||||
|
f"Deleted {deleted} detection(s) from the database.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reset UI state.
|
||||||
|
self.refresh()
|
||||||
|
|
||||||
def refresh(self):
|
def refresh(self):
|
||||||
"""Refresh the tab."""
|
"""Refresh the detection list and preview."""
|
||||||
pass
|
self._load_detection_summary()
|
||||||
|
self._populate_results_table()
|
||||||
|
self.current_selection = None
|
||||||
|
self.current_image = None
|
||||||
|
self.current_detections = []
|
||||||
|
self.preview_canvas.clear()
|
||||||
|
self.summary_label.setText("Select a detection result to preview.")
|
||||||
|
if hasattr(self, "export_labels_btn"):
|
||||||
|
self.export_labels_btn.setEnabled(False)
|
||||||
|
|
||||||
|
def _load_detection_summary(self):
|
||||||
|
"""Load latest detection summaries grouped by image + model."""
|
||||||
|
try:
|
||||||
|
detections = self.db_manager.get_detections(limit=500)
|
||||||
|
summary_map: Dict[tuple, Dict] = {}
|
||||||
|
|
||||||
|
for det in detections:
|
||||||
|
key = (det["image_id"], det["model_id"])
|
||||||
|
metadata = det.get("metadata") or {}
|
||||||
|
entry = summary_map.setdefault(
|
||||||
|
key,
|
||||||
|
{
|
||||||
|
"image_id": det["image_id"],
|
||||||
|
"model_id": det["model_id"],
|
||||||
|
"image_path": det.get("image_path"),
|
||||||
|
"image_filename": det.get("image_filename") or det.get("image_path"),
|
||||||
|
"model_name": det.get("model_name", ""),
|
||||||
|
"model_version": det.get("model_version", ""),
|
||||||
|
"last_detected": det.get("detected_at"),
|
||||||
|
"count": 0,
|
||||||
|
"classes": set(),
|
||||||
|
"source_path": metadata.get("source_path"),
|
||||||
|
"repository_root": metadata.get("repository_root"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
entry["count"] += 1
|
||||||
|
if det.get("detected_at") and (
|
||||||
|
not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
|
||||||
|
):
|
||||||
|
entry["last_detected"] = det.get("detected_at")
|
||||||
|
if det.get("class_name"):
|
||||||
|
entry["classes"].add(det["class_name"])
|
||||||
|
if metadata.get("source_path") and not entry.get("source_path"):
|
||||||
|
entry["source_path"] = metadata.get("source_path")
|
||||||
|
if metadata.get("repository_root") and not entry.get("repository_root"):
|
||||||
|
entry["repository_root"] = metadata.get("repository_root")
|
||||||
|
|
||||||
|
self.detection_summary = sorted(
|
||||||
|
summary_map.values(),
|
||||||
|
key=lambda x: str(x.get("last_detected") or ""),
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load detection summary: {e}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Error",
|
||||||
|
f"Failed to load detection results:\n{str(e)}",
|
||||||
|
)
|
||||||
|
self.detection_summary = []
|
||||||
|
|
||||||
|
def _populate_results_table(self):
|
||||||
|
"""Populate the table widget with detection summaries."""
|
||||||
|
self.results_table.setRowCount(len(self.detection_summary))
|
||||||
|
|
||||||
|
for row, entry in enumerate(self.detection_summary):
|
||||||
|
model_label = f"{entry['model_name']} {entry['model_version']}".strip()
|
||||||
|
class_list = ", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
|
||||||
|
|
||||||
|
items = [
|
||||||
|
QTableWidgetItem(entry.get("image_filename", "")),
|
||||||
|
QTableWidgetItem(model_label),
|
||||||
|
QTableWidgetItem(str(entry.get("count", 0))),
|
||||||
|
QTableWidgetItem(class_list),
|
||||||
|
QTableWidgetItem(str(entry.get("last_detected") or "")),
|
||||||
|
]
|
||||||
|
|
||||||
|
for col, item in enumerate(items):
|
||||||
|
item.setData(Qt.UserRole, row)
|
||||||
|
self.results_table.setItem(row, col, item)
|
||||||
|
|
||||||
|
self.results_table.clearSelection()
|
||||||
|
|
||||||
|
def _on_result_selected(self):
|
||||||
|
"""Handle selection changes in the detection table."""
|
||||||
|
selected_items = self.results_table.selectedItems()
|
||||||
|
if not selected_items:
|
||||||
|
return
|
||||||
|
|
||||||
|
row = selected_items[0].data(Qt.UserRole)
|
||||||
|
if row is None or row >= len(self.detection_summary):
|
||||||
|
return
|
||||||
|
|
||||||
|
entry = self.detection_summary[row]
|
||||||
|
if (
|
||||||
|
self.current_selection
|
||||||
|
and self.current_selection.get("image_id") == entry["image_id"]
|
||||||
|
and self.current_selection.get("model_id") == entry["model_id"]
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
self.current_selection = entry
|
||||||
|
|
||||||
|
image_path = self._resolve_image_path(entry)
|
||||||
|
if not image_path:
|
||||||
|
QMessageBox.warning(
|
||||||
|
self,
|
||||||
|
"Image Not Found",
|
||||||
|
"Unable to locate the image file for this detection.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.current_image = Image(image_path)
|
||||||
|
self.preview_canvas.load_image(self.current_image)
|
||||||
|
except ImageLoadError as e:
|
||||||
|
logger.error(f"Failed to load image '{image_path}': {e}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Image Error",
|
||||||
|
f"Failed to load image for preview:\n{str(e)}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
self._load_detections_for_selection(entry)
|
||||||
|
self._apply_detection_overlays()
|
||||||
|
self._update_summary_label(entry)
|
||||||
|
if hasattr(self, "export_labels_btn"):
|
||||||
|
self.export_labels_btn.setEnabled(True)
|
||||||
|
|
||||||
|
def _export_labels_for_current_selection(self):
|
||||||
|
"""Export YOLO label file(s) for the currently selected image/model."""
|
||||||
|
if not self.current_selection:
|
||||||
|
QMessageBox.information(self, "Export Labels", "Select a detection result first.")
|
||||||
|
return
|
||||||
|
|
||||||
|
entry = self.current_selection
|
||||||
|
|
||||||
|
image_path_str = self._resolve_image_path(entry)
|
||||||
|
if not image_path_str:
|
||||||
|
QMessageBox.warning(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
"Unable to locate the image file for this detection; cannot infer labels path.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Ensure we have the detections for the selection.
|
||||||
|
if not self.current_detections:
|
||||||
|
self._load_detections_for_selection(entry)
|
||||||
|
|
||||||
|
if not self.current_detections:
|
||||||
|
QMessageBox.information(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
"No detections found for this image/model selection.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
image_path = Path(image_path_str)
|
||||||
|
try:
|
||||||
|
label_path = self._infer_yolo_label_path(image_path)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Failed to infer label path for {image_path}: {exc}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
f"Failed to infer export path for labels:\n{exc}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
class_map = self._build_detection_class_index_map(self.current_detections)
|
||||||
|
if not class_map:
|
||||||
|
QMessageBox.warning(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
"Unable to build class->index mapping (missing class names).",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
lines_written = 0
|
||||||
|
skipped = 0
|
||||||
|
label_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
try:
|
||||||
|
with open(label_path, "w", encoding="utf-8") as handle:
|
||||||
|
print("writing to", label_path)
|
||||||
|
for det in self.current_detections:
|
||||||
|
yolo_line = self._format_detection_as_yolo_line(det, class_map)
|
||||||
|
if not yolo_line:
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
handle.write(yolo_line + "\n")
|
||||||
|
lines_written += 1
|
||||||
|
except OSError as exc:
|
||||||
|
logger.error(f"Failed to write labels file {label_path}: {exc}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
f"Failed to write label file:\n{label_path}\n\n{exc}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
return
|
||||||
|
# Optional: write a classes.txt next to the labels root to make the mapping discoverable.
|
||||||
|
# This is not required by Ultralytics (data.yaml usually holds class names), but helps reuse.
|
||||||
|
try:
|
||||||
|
classes_txt = label_path.parent.parent / "classes.txt"
|
||||||
|
classes_txt.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
inv = {idx: name for name, idx in class_map.items()}
|
||||||
|
with open(classes_txt, "w", encoding="utf-8") as handle:
|
||||||
|
for idx in range(len(inv)):
|
||||||
|
handle.write(f"{inv[idx]}\n")
|
||||||
|
except Exception:
|
||||||
|
# Non-fatal
|
||||||
|
pass
|
||||||
|
|
||||||
|
QMessageBox.information(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
f"Exported {lines_written} label line(s) to:\n{label_path}\n\nSkipped {skipped} invalid detection(s).",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _infer_yolo_label_path(self, image_path: Path) -> Path:
|
||||||
|
"""Infer a YOLO label path from an image path.
|
||||||
|
|
||||||
|
If the image lives under an `images/` directory (anywhere in the path), we mirror the
|
||||||
|
subpath under a sibling `labels/` directory at the same level.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
/dataset/train/images/sub/img.jpg -> /dataset/train/labels/sub/img.txt
|
||||||
|
"""
|
||||||
|
|
||||||
|
resolved = image_path.expanduser().resolve()
|
||||||
|
|
||||||
|
# Find the nearest ancestor directory named 'images'
|
||||||
|
images_dir: Optional[Path] = None
|
||||||
|
for parent in [resolved.parent, *resolved.parents]:
|
||||||
|
if parent.name.lower() == "images":
|
||||||
|
images_dir = parent
|
||||||
|
break
|
||||||
|
|
||||||
|
if images_dir is not None:
|
||||||
|
rel = resolved.relative_to(images_dir)
|
||||||
|
labels_dir = images_dir.parent / "labels"
|
||||||
|
return (labels_dir / rel).with_suffix(".txt")
|
||||||
|
|
||||||
|
# Fallback: create a local sibling labels folder next to the image.
|
||||||
|
return (resolved.parent / "labels" / resolved.name).with_suffix(".txt")
|
||||||
|
|
||||||
|
def _build_detection_class_index_map(self, detections: List[Dict]) -> Dict[str, int]:
|
||||||
|
"""Build a stable class_name -> YOLO class index mapping.
|
||||||
|
|
||||||
|
Preference order:
|
||||||
|
1) Database object_classes table (alphabetical class_name order)
|
||||||
|
2) Fallback to class_name values present in the detections (alphabetical)
|
||||||
|
"""
|
||||||
|
|
||||||
|
names: List[str] = []
|
||||||
|
try:
|
||||||
|
db_classes = self.db_manager.get_object_classes() or []
|
||||||
|
names = [str(row.get("class_name")) for row in db_classes if row.get("class_name")]
|
||||||
|
except Exception:
|
||||||
|
names = []
|
||||||
|
|
||||||
|
if not names:
|
||||||
|
observed = sorted({str(det.get("class_name")) for det in detections if det.get("class_name")})
|
||||||
|
names = list(observed)
|
||||||
|
|
||||||
|
return {name: idx for idx, name in enumerate(names)}
|
||||||
|
|
||||||
|
def _format_detection_as_yolo_line(self, det: Dict, class_map: Dict[str, int]) -> Optional[str]:
|
||||||
|
"""Convert a detection row to a YOLO label line.
|
||||||
|
|
||||||
|
- If segmentation_mask is present, exports segmentation polygon format:
|
||||||
|
class x1 y1 x2 y2 ...
|
||||||
|
(normalized coordinates)
|
||||||
|
- Otherwise exports bbox format:
|
||||||
|
class x_center y_center width height
|
||||||
|
(normalized coordinates)
|
||||||
|
"""
|
||||||
|
|
||||||
|
class_name = det.get("class_name")
|
||||||
|
if not class_name or class_name not in class_map:
|
||||||
|
return None
|
||||||
|
class_idx = class_map[class_name]
|
||||||
|
|
||||||
|
mask = det.get("segmentation_mask")
|
||||||
|
polygon = self._convert_segmentation_mask_to_polygon(mask)
|
||||||
|
if polygon:
|
||||||
|
coords = " ".join(f"{value:.6f}" for value in polygon)
|
||||||
|
return f"{class_idx} {coords}".strip()
|
||||||
|
|
||||||
|
bbox = self._convert_bbox_to_yolo_xywh(det)
|
||||||
|
if bbox is None:
|
||||||
|
return None
|
||||||
|
x_center, y_center, width, height = bbox
|
||||||
|
return f"{class_idx} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
|
||||||
|
|
||||||
|
def _convert_bbox_to_yolo_xywh(self, det: Dict) -> Optional[Tuple[float, float, float, float]]:
|
||||||
|
"""Convert stored xyxy (normalized) bbox to YOLO xywh (normalized)."""
|
||||||
|
|
||||||
|
x_min = det.get("x_min")
|
||||||
|
y_min = det.get("y_min")
|
||||||
|
x_max = det.get("x_max")
|
||||||
|
y_max = det.get("y_max")
|
||||||
|
if any(v is None for v in (x_min, y_min, x_max, y_max)):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
x_min_f = self._clamp01(float(x_min))
|
||||||
|
y_min_f = self._clamp01(float(y_min))
|
||||||
|
x_max_f = self._clamp01(float(x_max))
|
||||||
|
y_max_f = self._clamp01(float(y_max))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
width = max(0.0, x_max_f - x_min_f)
|
||||||
|
height = max(0.0, y_max_f - y_min_f)
|
||||||
|
if width <= 0.0 or height <= 0.0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
x_center = x_min_f + width / 2.0
|
||||||
|
y_center = y_min_f + height / 2.0
|
||||||
|
return x_center, y_center, width, height
|
||||||
|
|
||||||
|
def _convert_segmentation_mask_to_polygon(self, mask_data) -> List[float]:
|
||||||
|
"""Convert stored segmentation_mask [[x,y], ...] to YOLO polygon coords [x1,y1,...]."""
|
||||||
|
|
||||||
|
if not isinstance(mask_data, list):
|
||||||
|
return []
|
||||||
|
|
||||||
|
coords: List[float] = []
|
||||||
|
for point in mask_data:
|
||||||
|
if not isinstance(point, (list, tuple)) or len(point) < 2:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
x = self._clamp01(float(point[0]))
|
||||||
|
y = self._clamp01(float(point[1]))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
coords.extend([x, y])
|
||||||
|
|
||||||
|
# Need at least 3 points => 6 values.
|
||||||
|
return coords if len(coords) >= 6 else []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _clamp01(value: float) -> float:
|
||||||
|
if value < 0.0:
|
||||||
|
return 0.0
|
||||||
|
if value > 1.0:
|
||||||
|
return 1.0
|
||||||
|
return value
|
||||||
|
|
||||||
|
def _load_detections_for_selection(self, entry: Dict):
|
||||||
|
"""Load detection records for the selected image/model pair."""
|
||||||
|
self.current_detections = []
|
||||||
|
if not entry:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
filters = {"image_id": entry["image_id"], "model_id": entry["model_id"]}
|
||||||
|
self.current_detections = self.db_manager.get_detections(filters)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load detections for preview: {e}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Error",
|
||||||
|
f"Failed to load detections for this image:\n{str(e)}",
|
||||||
|
)
|
||||||
|
self.current_detections = []
|
||||||
|
|
||||||
|
def _apply_detection_overlays(self):
|
||||||
|
"""Draw detections onto the preview canvas based on current toggles."""
|
||||||
|
self.preview_canvas.clear_annotations()
|
||||||
|
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
|
||||||
|
|
||||||
|
if not self.current_detections or not self.current_image:
|
||||||
|
return
|
||||||
|
|
||||||
|
for det in self.current_detections:
|
||||||
|
color = self._get_class_color(det.get("class_name"))
|
||||||
|
|
||||||
|
if self.show_masks_checkbox.isChecked() and det.get("segmentation_mask"):
|
||||||
|
mask_points = self._convert_mask(det["segmentation_mask"])
|
||||||
|
if mask_points:
|
||||||
|
self.preview_canvas.draw_saved_polyline(mask_points, color)
|
||||||
|
|
||||||
|
bbox = [
|
||||||
|
det.get("x_min"),
|
||||||
|
det.get("y_min"),
|
||||||
|
det.get("x_max"),
|
||||||
|
det.get("y_max"),
|
||||||
|
]
|
||||||
|
if all(v is not None for v in bbox):
|
||||||
|
label = None
|
||||||
|
if self.show_confidence_checkbox.isChecked():
|
||||||
|
confidence = det.get("confidence")
|
||||||
|
if confidence is not None:
|
||||||
|
label = f"{confidence:.2f}"
|
||||||
|
self.preview_canvas.draw_saved_bbox(bbox, color, label=label)
|
||||||
|
|
||||||
|
def _convert_mask(self, mask_points: List[List[float]]) -> List[List[float]]:
|
||||||
|
"""Convert stored [x, y] masks to [y, x] format for the canvas."""
|
||||||
|
converted = []
|
||||||
|
for point in mask_points:
|
||||||
|
if len(point) >= 2:
|
||||||
|
x, y = point[0], point[1]
|
||||||
|
converted.append([y, x])
|
||||||
|
return converted
|
||||||
|
|
||||||
|
def _toggle_bboxes(self):
|
||||||
|
"""Update bounding box visibility on the canvas."""
|
||||||
|
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
|
||||||
|
# Re-render to respect show/hide when toggled
|
||||||
|
self._apply_detection_overlays()
|
||||||
|
|
||||||
|
def _update_summary_label(self, entry: Dict):
|
||||||
|
"""Display textual summary for the selected detection run."""
|
||||||
|
classes = ", ".join(sorted(entry.get("classes", []))) or "-"
|
||||||
|
summary_text = (
|
||||||
|
f"Image: {entry.get('image_filename', 'unknown')}\n"
|
||||||
|
f"Model: {entry.get('model_name', '')} {entry.get('model_version', '')}\n"
|
||||||
|
f"Detections: {entry.get('count', 0)}\n"
|
||||||
|
f"Classes: {classes}\n"
|
||||||
|
f"Last Updated: {entry.get('last_detected', 'n/a')}"
|
||||||
|
)
|
||||||
|
self.summary_label.setText(summary_text)
|
||||||
|
|
||||||
|
def _resolve_image_path(self, entry: Dict) -> Optional[str]:
|
||||||
|
"""Resolve an image path using metadata, cache, and repository hints."""
|
||||||
|
relative_path = entry.get("image_path") if entry else None
|
||||||
|
cache_key = relative_path or entry.get("source_path")
|
||||||
|
if cache_key and cache_key in self._image_path_cache:
|
||||||
|
cached = Path(self._image_path_cache[cache_key])
|
||||||
|
if cached.exists():
|
||||||
|
return self._image_path_cache[cache_key]
|
||||||
|
del self._image_path_cache[cache_key]
|
||||||
|
|
||||||
|
candidates = []
|
||||||
|
source_path = entry.get("source_path") if entry else None
|
||||||
|
if source_path:
|
||||||
|
candidates.append(Path(source_path))
|
||||||
|
|
||||||
|
repo_roots = []
|
||||||
|
if entry.get("repository_root"):
|
||||||
|
repo_roots.append(entry["repository_root"])
|
||||||
|
config_repo = self.config_manager.get_image_repository_path()
|
||||||
|
if config_repo:
|
||||||
|
repo_roots.append(config_repo)
|
||||||
|
|
||||||
|
for root in repo_roots:
|
||||||
|
if relative_path:
|
||||||
|
candidates.append(Path(root) / relative_path)
|
||||||
|
|
||||||
|
if relative_path:
|
||||||
|
candidates.append(Path(relative_path))
|
||||||
|
|
||||||
|
for candidate in candidates:
|
||||||
|
try:
|
||||||
|
if candidate and candidate.exists():
|
||||||
|
resolved = str(candidate.resolve())
|
||||||
|
if cache_key:
|
||||||
|
self._image_path_cache[cache_key] = resolved
|
||||||
|
return resolved
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Fallback: search by filename in known roots
|
||||||
|
filename = Path(relative_path).name if relative_path else None
|
||||||
|
if filename:
|
||||||
|
search_roots = [Path(root) for root in repo_roots if root]
|
||||||
|
if not search_roots:
|
||||||
|
search_roots = [Path("data")]
|
||||||
|
match = self._search_in_roots(filename, search_roots)
|
||||||
|
if match:
|
||||||
|
resolved = str(match.resolve())
|
||||||
|
if cache_key:
|
||||||
|
self._image_path_cache[cache_key] = resolved
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _search_in_roots(self, filename: str, roots: List[Path]) -> Optional[Path]:
|
||||||
|
"""Search for a file name within a list of root directories."""
|
||||||
|
for root in roots:
|
||||||
|
try:
|
||||||
|
if not root.exists():
|
||||||
|
continue
|
||||||
|
for candidate in root.rglob(filename):
|
||||||
|
return candidate
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error searching for {filename} in {root}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_class_color(self, class_name: Optional[str]) -> str:
|
||||||
|
"""Return consistent color hex for a class name."""
|
||||||
|
if not class_name:
|
||||||
|
return "#FF6B6B"
|
||||||
|
|
||||||
|
color_map = self.config_manager.get_bbox_colors()
|
||||||
|
if class_name in color_map:
|
||||||
|
return color_map[class_name]
|
||||||
|
|
||||||
|
# Deterministic fallback color based on hash
|
||||||
|
palette = [
|
||||||
|
"#FF6B6B",
|
||||||
|
"#4ECDC4",
|
||||||
|
"#FFD166",
|
||||||
|
"#1D3557",
|
||||||
|
"#F4A261",
|
||||||
|
"#E76F51",
|
||||||
|
]
|
||||||
|
return palette[hash(class_name) % len(palette)]
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -2,45 +2,554 @@
|
|||||||
Validation tab for the microscopy object detection application.
|
Validation tab for the microscopy object detection application.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
from PySide6.QtCore import Qt, QSize
|
||||||
|
from PySide6.QtGui import QPainter, QPixmap
|
||||||
|
from PySide6.QtWidgets import (
|
||||||
|
QWidget,
|
||||||
|
QVBoxLayout,
|
||||||
|
QLabel,
|
||||||
|
QGroupBox,
|
||||||
|
QHBoxLayout,
|
||||||
|
QPushButton,
|
||||||
|
QComboBox,
|
||||||
|
QFormLayout,
|
||||||
|
QScrollArea,
|
||||||
|
QGridLayout,
|
||||||
|
QFrame,
|
||||||
|
QTableWidget,
|
||||||
|
QTableWidgetItem,
|
||||||
|
QHeaderView,
|
||||||
|
QSplitter,
|
||||||
|
QListWidget,
|
||||||
|
QListWidgetItem,
|
||||||
|
QAbstractItemView,
|
||||||
|
QGraphicsView,
|
||||||
|
QGraphicsScene,
|
||||||
|
QGraphicsPixmapItem,
|
||||||
|
)
|
||||||
|
|
||||||
from src.database.db_manager import DatabaseManager
|
from src.database.db_manager import DatabaseManager
|
||||||
from src.utils.config_manager import ConfigManager
|
from src.utils.config_manager import ConfigManager
|
||||||
|
from src.utils.logger import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class _PlotItem:
|
||||||
|
label: str
|
||||||
|
path: Path
|
||||||
|
|
||||||
|
|
||||||
|
class _ZoomableImageView(QGraphicsView):
|
||||||
|
"""Zoomable image viewer.
|
||||||
|
|
||||||
|
- Mouse wheel: zoom in/out
|
||||||
|
- Left mouse drag: pan (ScrollHandDrag)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, parent: Optional[QWidget] = None):
|
||||||
|
super().__init__(parent)
|
||||||
|
self._scene = QGraphicsScene(self)
|
||||||
|
self.setScene(self._scene)
|
||||||
|
self._pixmap_item = QGraphicsPixmapItem()
|
||||||
|
self._scene.addItem(self._pixmap_item)
|
||||||
|
|
||||||
|
# QGraphicsView render hints are QPainter.RenderHints.
|
||||||
|
self.setRenderHints(self.renderHints() | QPainter.RenderHint.SmoothPixmapTransform)
|
||||||
|
self.setDragMode(QGraphicsView.DragMode.ScrollHandDrag)
|
||||||
|
self.setTransformationAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
|
||||||
|
self.setResizeAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
|
||||||
|
|
||||||
|
self._has_pixmap = False
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
self._pixmap_item.setPixmap(QPixmap())
|
||||||
|
self._scene.setSceneRect(0, 0, 1, 1)
|
||||||
|
self.resetTransform()
|
||||||
|
self._has_pixmap = False
|
||||||
|
|
||||||
|
def set_pixmap(self, pixmap: QPixmap, *, fit: bool = True) -> None:
|
||||||
|
self._pixmap_item.setPixmap(pixmap)
|
||||||
|
self._scene.setSceneRect(pixmap.rect())
|
||||||
|
self._has_pixmap = not pixmap.isNull()
|
||||||
|
self.resetTransform()
|
||||||
|
if fit and self._has_pixmap:
|
||||||
|
self.fitInView(self._pixmap_item, Qt.AspectRatioMode.KeepAspectRatio)
|
||||||
|
|
||||||
|
def wheelEvent(self, event) -> None: # type: ignore[override]
|
||||||
|
if not self._has_pixmap:
|
||||||
|
return
|
||||||
|
zoom_in_factor = 1.25
|
||||||
|
zoom_out_factor = 1.0 / zoom_in_factor
|
||||||
|
factor = zoom_in_factor if event.angleDelta().y() > 0 else zoom_out_factor
|
||||||
|
self.scale(factor, factor)
|
||||||
|
|
||||||
|
|
||||||
class ValidationTab(QWidget):
|
class ValidationTab(QWidget):
|
||||||
"""Validation tab placeholder."""
|
"""Validation tab that shows stored validation metrics + plots for a selected model."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
|
||||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
|
||||||
):
|
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
self.db_manager = db_manager
|
self.db_manager = db_manager
|
||||||
self.config_manager = config_manager
|
self.config_manager = config_manager
|
||||||
|
|
||||||
|
self._models: List[Dict[str, Any]] = []
|
||||||
|
self._selected_model_id: Optional[int] = None
|
||||||
|
self._plot_widgets: List[QWidget] = []
|
||||||
|
self._plot_items: List[_PlotItem] = []
|
||||||
|
|
||||||
self._setup_ui()
|
self._setup_ui()
|
||||||
|
self.refresh()
|
||||||
|
|
||||||
def _setup_ui(self):
|
def _setup_ui(self):
|
||||||
"""Setup user interface."""
|
"""Setup user interface."""
|
||||||
layout = QVBoxLayout()
|
layout = QVBoxLayout(self)
|
||||||
|
|
||||||
group = QGroupBox("Validation")
|
# ===== Header controls =====
|
||||||
group_layout = QVBoxLayout()
|
header = QGroupBox("Validation")
|
||||||
label = QLabel(
|
header_layout = QVBoxLayout()
|
||||||
"Validation functionality will be implemented here.\n\n"
|
header_row = QHBoxLayout()
|
||||||
"Features:\n"
|
|
||||||
"- Model validation\n"
|
|
||||||
"- Metrics visualization\n"
|
|
||||||
"- Confusion matrix\n"
|
|
||||||
"- Precision-Recall curves"
|
|
||||||
)
|
|
||||||
group_layout.addWidget(label)
|
|
||||||
group.setLayout(group_layout)
|
|
||||||
|
|
||||||
layout.addWidget(group)
|
header_row.addWidget(QLabel("Select model:"))
|
||||||
layout.addStretch()
|
|
||||||
self.setLayout(layout)
|
self.model_combo = QComboBox()
|
||||||
|
self.model_combo.setMinimumWidth(420)
|
||||||
|
self.model_combo.currentIndexChanged.connect(self._on_model_selected)
|
||||||
|
header_row.addWidget(self.model_combo, 1)
|
||||||
|
|
||||||
|
self.refresh_btn = QPushButton("Refresh")
|
||||||
|
self.refresh_btn.clicked.connect(self.refresh)
|
||||||
|
header_row.addWidget(self.refresh_btn)
|
||||||
|
header_row.addStretch()
|
||||||
|
|
||||||
|
header_layout.addLayout(header_row)
|
||||||
|
self.header_status = QLabel("No models loaded.")
|
||||||
|
self.header_status.setWordWrap(True)
|
||||||
|
header_layout.addWidget(self.header_status)
|
||||||
|
header.setLayout(header_layout)
|
||||||
|
layout.addWidget(header)
|
||||||
|
|
||||||
|
# ===== Metrics =====
|
||||||
|
metrics_group = QGroupBox("Validation Metrics")
|
||||||
|
metrics_layout = QVBoxLayout()
|
||||||
|
|
||||||
|
self.metrics_form = QFormLayout()
|
||||||
|
self.metric_labels: Dict[str, QLabel] = {}
|
||||||
|
for key in ("mAP50", "mAP50-95", "precision", "recall", "fitness"):
|
||||||
|
value_label = QLabel("–")
|
||||||
|
value_label.setTextInteractionFlags(Qt.TextSelectableByMouse)
|
||||||
|
self.metric_labels[key] = value_label
|
||||||
|
self.metrics_form.addRow(f"{key}:", value_label)
|
||||||
|
metrics_layout.addLayout(self.metrics_form)
|
||||||
|
|
||||||
|
self.per_class_table = QTableWidget(0, 3)
|
||||||
|
self.per_class_table.setHorizontalHeaderLabels(["Class", "AP", "AP50"])
|
||||||
|
self.per_class_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
|
||||||
|
self.per_class_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeToContents)
|
||||||
|
self.per_class_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeToContents)
|
||||||
|
self.per_class_table.setEditTriggers(QTableWidget.NoEditTriggers)
|
||||||
|
self.per_class_table.setMinimumHeight(160)
|
||||||
|
metrics_layout.addWidget(QLabel("Per-class metrics (if available):"))
|
||||||
|
metrics_layout.addWidget(self.per_class_table)
|
||||||
|
|
||||||
|
metrics_group.setLayout(metrics_layout)
|
||||||
|
layout.addWidget(metrics_group)
|
||||||
|
|
||||||
|
# ===== Plots =====
|
||||||
|
plots_group = QGroupBox("Validation Plots")
|
||||||
|
plots_layout = QVBoxLayout()
|
||||||
|
|
||||||
|
self.plots_status = QLabel("Select a model to see validation plots.")
|
||||||
|
self.plots_status.setWordWrap(True)
|
||||||
|
plots_layout.addWidget(self.plots_status)
|
||||||
|
|
||||||
|
self.plots_splitter = QSplitter(Qt.Orientation.Horizontal)
|
||||||
|
|
||||||
|
# Left: selected image viewer
|
||||||
|
left_widget = QWidget()
|
||||||
|
left_layout = QVBoxLayout(left_widget)
|
||||||
|
left_layout.setContentsMargins(0, 0, 0, 0)
|
||||||
|
|
||||||
|
self.selected_plot_title = QLabel("No image selected.")
|
||||||
|
self.selected_plot_title.setWordWrap(True)
|
||||||
|
self.selected_plot_title.setTextInteractionFlags(Qt.TextSelectableByMouse)
|
||||||
|
left_layout.addWidget(self.selected_plot_title)
|
||||||
|
|
||||||
|
self.plot_view = _ZoomableImageView()
|
||||||
|
self.plot_view.setMinimumHeight(360)
|
||||||
|
left_layout.addWidget(self.plot_view, 1)
|
||||||
|
|
||||||
|
self.selected_plot_path = QLabel("")
|
||||||
|
self.selected_plot_path.setWordWrap(True)
|
||||||
|
self.selected_plot_path.setStyleSheet("color: #888;")
|
||||||
|
self.selected_plot_path.setTextInteractionFlags(Qt.TextSelectableByMouse)
|
||||||
|
left_layout.addWidget(self.selected_plot_path)
|
||||||
|
|
||||||
|
# Right: scrollable list
|
||||||
|
right_widget = QWidget()
|
||||||
|
right_layout = QVBoxLayout(right_widget)
|
||||||
|
right_layout.setContentsMargins(0, 0, 0, 0)
|
||||||
|
right_layout.addWidget(QLabel("Images:"))
|
||||||
|
|
||||||
|
self.plots_list = QListWidget()
|
||||||
|
self.plots_list.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection)
|
||||||
|
self.plots_list.setIconSize(QSize(160, 160))
|
||||||
|
self.plots_list.itemSelectionChanged.connect(self._on_plot_item_selected)
|
||||||
|
right_layout.addWidget(self.plots_list, 1)
|
||||||
|
|
||||||
|
self.plots_splitter.addWidget(left_widget)
|
||||||
|
self.plots_splitter.addWidget(right_widget)
|
||||||
|
self.plots_splitter.setStretchFactor(0, 3)
|
||||||
|
self.plots_splitter.setStretchFactor(1, 1)
|
||||||
|
plots_layout.addWidget(self.plots_splitter, 1)
|
||||||
|
|
||||||
|
plots_group.setLayout(plots_layout)
|
||||||
|
layout.addWidget(plots_group, 1)
|
||||||
|
|
||||||
|
layout.addStretch(0)
|
||||||
|
|
||||||
|
self._clear_metrics()
|
||||||
|
self._clear_plots()
|
||||||
|
|
||||||
|
# ==================== Public API ====================
|
||||||
|
|
||||||
def refresh(self):
|
def refresh(self):
|
||||||
"""Refresh the tab."""
|
"""Refresh the tab."""
|
||||||
pass
|
self._load_models()
|
||||||
|
self._populate_model_combo()
|
||||||
|
self._restore_or_select_default_model()
|
||||||
|
|
||||||
|
# ==================== Internal: models ====================
|
||||||
|
|
||||||
|
def _load_models(self) -> None:
|
||||||
|
try:
|
||||||
|
self._models = self.db_manager.get_models() or []
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to load models: %s", exc)
|
||||||
|
self._models = []
|
||||||
|
|
||||||
|
def _populate_model_combo(self) -> None:
|
||||||
|
self.model_combo.blockSignals(True)
|
||||||
|
self.model_combo.clear()
|
||||||
|
self.model_combo.addItem("Select a model…", None)
|
||||||
|
|
||||||
|
for model in self._models:
|
||||||
|
model_id = model.get("id")
|
||||||
|
name = (model.get("model_name") or "").strip()
|
||||||
|
version = (model.get("model_version") or "").strip()
|
||||||
|
created_at = model.get("created_at")
|
||||||
|
label = f"{name} {version}".strip()
|
||||||
|
if created_at:
|
||||||
|
label = f"{label} ({created_at})"
|
||||||
|
self.model_combo.addItem(label, model_id)
|
||||||
|
|
||||||
|
self.model_combo.blockSignals(False)
|
||||||
|
|
||||||
|
if self._models:
|
||||||
|
self.header_status.setText(f"Loaded {len(self._models)} model(s).")
|
||||||
|
else:
|
||||||
|
self.header_status.setText("No models found. Train a model first.")
|
||||||
|
|
||||||
|
def _restore_or_select_default_model(self) -> None:
|
||||||
|
if not self._models:
|
||||||
|
self._selected_model_id = None
|
||||||
|
self._clear_metrics()
|
||||||
|
self._clear_plots()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Keep selection if still present.
|
||||||
|
if self._selected_model_id is not None:
|
||||||
|
for idx in range(1, self.model_combo.count()):
|
||||||
|
if self.model_combo.itemData(idx) == self._selected_model_id:
|
||||||
|
self.model_combo.setCurrentIndex(idx)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Otherwise select the newest model (top of get_models ORDER BY created_at DESC).
|
||||||
|
first_model_id = self.model_combo.itemData(1) if self.model_combo.count() > 1 else None
|
||||||
|
if first_model_id is not None:
|
||||||
|
self.model_combo.setCurrentIndex(1)
|
||||||
|
|
||||||
|
def _on_model_selected(self, index: int) -> None:
|
||||||
|
model_id = self.model_combo.itemData(index)
|
||||||
|
if not model_id:
|
||||||
|
self._selected_model_id = None
|
||||||
|
self._clear_metrics()
|
||||||
|
self._clear_plots()
|
||||||
|
self.plots_status.setText("Select a model to see validation plots.")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._selected_model_id = int(model_id)
|
||||||
|
model = self._get_model_by_id(self._selected_model_id)
|
||||||
|
if not model:
|
||||||
|
self._clear_metrics()
|
||||||
|
self._clear_plots()
|
||||||
|
self.plots_status.setText("Selected model not found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._render_metrics(model)
|
||||||
|
self._render_plots(model)
|
||||||
|
|
||||||
|
def _get_model_by_id(self, model_id: int) -> Optional[Dict[str, Any]]:
|
||||||
|
for model in self._models:
|
||||||
|
if model.get("id") == model_id:
|
||||||
|
return model
|
||||||
|
try:
|
||||||
|
return self.db_manager.get_model_by_id(model_id)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ==================== Internal: metrics ====================
|
||||||
|
|
||||||
|
def _clear_metrics(self) -> None:
|
||||||
|
for label in self.metric_labels.values():
|
||||||
|
label.setText("–")
|
||||||
|
self.per_class_table.setRowCount(0)
|
||||||
|
|
||||||
|
def _render_metrics(self, model: Dict[str, Any]) -> None:
|
||||||
|
self._clear_metrics()
|
||||||
|
|
||||||
|
metrics: Dict[str, Any] = model.get("metrics") or {}
|
||||||
|
# Training tab stores metrics under results['metrics'] in training results payload.
|
||||||
|
if isinstance(metrics, dict) and "metrics" in metrics and isinstance(metrics.get("metrics"), dict):
|
||||||
|
metrics = metrics.get("metrics") or {}
|
||||||
|
|
||||||
|
def set_metric(key: str, value: Any) -> None:
|
||||||
|
if key not in self.metric_labels:
|
||||||
|
return
|
||||||
|
if value is None:
|
||||||
|
self.metric_labels[key].setText("–")
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
self.metric_labels[key].setText(f"{float(value):.4f}")
|
||||||
|
except Exception:
|
||||||
|
self.metric_labels[key].setText(str(value))
|
||||||
|
|
||||||
|
set_metric("mAP50", metrics.get("mAP50"))
|
||||||
|
set_metric("mAP50-95", metrics.get("mAP50-95") or metrics.get("mAP50_95") or metrics.get("mAP50-95"))
|
||||||
|
set_metric("precision", metrics.get("precision"))
|
||||||
|
set_metric("recall", metrics.get("recall"))
|
||||||
|
set_metric("fitness", metrics.get("fitness"))
|
||||||
|
|
||||||
|
# Optional per-class metrics
|
||||||
|
class_metrics = metrics.get("class_metrics") if isinstance(metrics, dict) else None
|
||||||
|
if isinstance(class_metrics, dict) and class_metrics:
|
||||||
|
items = sorted(class_metrics.items(), key=lambda kv: str(kv[0]))
|
||||||
|
self.per_class_table.setRowCount(len(items))
|
||||||
|
for row, (cls_name, cls_stats) in enumerate(items):
|
||||||
|
ap = (cls_stats or {}).get("ap")
|
||||||
|
ap50 = (cls_stats or {}).get("ap50")
|
||||||
|
self.per_class_table.setItem(row, 0, QTableWidgetItem(str(cls_name)))
|
||||||
|
self.per_class_table.setItem(row, 1, QTableWidgetItem(self._format_float(ap)))
|
||||||
|
self.per_class_table.setItem(row, 2, QTableWidgetItem(self._format_float(ap50)))
|
||||||
|
else:
|
||||||
|
self.per_class_table.setRowCount(0)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_float(value: Any) -> str:
|
||||||
|
if value is None:
|
||||||
|
return "–"
|
||||||
|
try:
|
||||||
|
return f"{float(value):.4f}"
|
||||||
|
except Exception:
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
# ==================== Internal: plots ====================
|
||||||
|
|
||||||
|
def _clear_plots(self) -> None:
|
||||||
|
# Remove legacy grid widgets (from the initial implementation).
|
||||||
|
for widget in self._plot_widgets:
|
||||||
|
widget.setParent(None)
|
||||||
|
widget.deleteLater()
|
||||||
|
self._plot_widgets = []
|
||||||
|
|
||||||
|
self._plot_items = []
|
||||||
|
|
||||||
|
if hasattr(self, "plots_list"):
|
||||||
|
self.plots_list.blockSignals(True)
|
||||||
|
self.plots_list.clear()
|
||||||
|
self.plots_list.blockSignals(False)
|
||||||
|
|
||||||
|
if hasattr(self, "plot_view"):
|
||||||
|
self.plot_view.clear()
|
||||||
|
if hasattr(self, "selected_plot_title"):
|
||||||
|
self.selected_plot_title.setText("No image selected.")
|
||||||
|
if hasattr(self, "selected_plot_path"):
|
||||||
|
self.selected_plot_path.setText("")
|
||||||
|
|
||||||
|
def _render_plots(self, model: Dict[str, Any]) -> None:
|
||||||
|
self._clear_plots()
|
||||||
|
|
||||||
|
plot_dirs = self._infer_run_directories(model)
|
||||||
|
plot_items = self._discover_plot_items(plot_dirs)
|
||||||
|
|
||||||
|
if not plot_items:
|
||||||
|
dirs_text = "\n".join(str(p) for p in plot_dirs if p)
|
||||||
|
self.plots_status.setText(
|
||||||
|
"No validation plot images found for this model.\n\n"
|
||||||
|
"Searched directories:\n" + (dirs_text or "(none)")
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
self._plot_items = list(plot_items)
|
||||||
|
self.plots_status.setText(f"Found {len(plot_items)} plot image(s). Select one to view/zoom.")
|
||||||
|
|
||||||
|
self.plots_list.blockSignals(True)
|
||||||
|
self.plots_list.clear()
|
||||||
|
for idx, item in enumerate(self._plot_items):
|
||||||
|
qitem = QListWidgetItem(item.label)
|
||||||
|
qitem.setData(Qt.ItemDataRole.UserRole, idx)
|
||||||
|
|
||||||
|
pix = QPixmap(str(item.path))
|
||||||
|
if not pix.isNull():
|
||||||
|
thumb = pix.scaled(
|
||||||
|
self.plots_list.iconSize(),
|
||||||
|
Qt.AspectRatioMode.KeepAspectRatio,
|
||||||
|
Qt.TransformationMode.SmoothTransformation,
|
||||||
|
)
|
||||||
|
qitem.setIcon(thumb)
|
||||||
|
self.plots_list.addItem(qitem)
|
||||||
|
self.plots_list.blockSignals(False)
|
||||||
|
|
||||||
|
if self.plots_list.count() > 0:
|
||||||
|
self.plots_list.setCurrentRow(0)
|
||||||
|
|
||||||
|
def _on_plot_item_selected(self) -> None:
|
||||||
|
if not self._plot_items:
|
||||||
|
return
|
||||||
|
|
||||||
|
selected = self.plots_list.selectedItems()
|
||||||
|
if not selected:
|
||||||
|
return
|
||||||
|
|
||||||
|
idx = selected[0].data(Qt.ItemDataRole.UserRole)
|
||||||
|
try:
|
||||||
|
idx_int = int(idx)
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
if idx_int < 0 or idx_int >= len(self._plot_items):
|
||||||
|
return
|
||||||
|
|
||||||
|
plot = self._plot_items[idx_int]
|
||||||
|
self.selected_plot_title.setText(plot.label)
|
||||||
|
self.selected_plot_path.setText(str(plot.path))
|
||||||
|
|
||||||
|
pix = QPixmap(str(plot.path))
|
||||||
|
if pix.isNull():
|
||||||
|
self.plot_view.clear()
|
||||||
|
return
|
||||||
|
self.plot_view.set_pixmap(pix, fit=True)
|
||||||
|
|
||||||
|
def _infer_run_directories(self, model: Dict[str, Any]) -> List[Path]:
|
||||||
|
dirs: List[Path] = []
|
||||||
|
|
||||||
|
# 1) Infer from model_path: .../<run>/weights/best.pt -> <run>
|
||||||
|
model_path = model.get("model_path")
|
||||||
|
if model_path:
|
||||||
|
try:
|
||||||
|
p = Path(str(model_path)).expanduser()
|
||||||
|
if p.name.lower().endswith(".pt"):
|
||||||
|
# If it lives under weights/, use parent.parent.
|
||||||
|
if p.parent.name == "weights" and p.parent.parent.exists():
|
||||||
|
dirs.append(p.parent.parent)
|
||||||
|
elif p.parent.exists():
|
||||||
|
dirs.append(p.parent)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 2) Look at training_params.stage_results[].results.save_dir
|
||||||
|
training_params = model.get("training_params") or {}
|
||||||
|
stage_results = None
|
||||||
|
if isinstance(training_params, dict):
|
||||||
|
stage_results = training_params.get("stage_results")
|
||||||
|
if isinstance(stage_results, list):
|
||||||
|
for stage in stage_results:
|
||||||
|
results = (stage or {}).get("results")
|
||||||
|
save_dir = (results or {}).get("save_dir") if isinstance(results, dict) else None
|
||||||
|
if save_dir:
|
||||||
|
try:
|
||||||
|
save_path = Path(str(save_dir)).expanduser()
|
||||||
|
if save_path.exists():
|
||||||
|
dirs.append(save_path)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Deduplicate while preserving order.
|
||||||
|
unique: List[Path] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for d in dirs:
|
||||||
|
try:
|
||||||
|
resolved = str(d.resolve())
|
||||||
|
except Exception:
|
||||||
|
resolved = str(d)
|
||||||
|
if resolved not in seen and d.exists() and d.is_dir():
|
||||||
|
seen.add(resolved)
|
||||||
|
unique.append(d)
|
||||||
|
return unique
|
||||||
|
|
||||||
|
def _discover_plot_items(self, directories: Sequence[Path]) -> List[_PlotItem]:
|
||||||
|
# Prefer canonical Ultralytics filenames first, then fall back to any png/jpg.
|
||||||
|
preferred_names = [
|
||||||
|
"results.png",
|
||||||
|
"results.jpg",
|
||||||
|
"confusion_matrix.png",
|
||||||
|
"confusion_matrix_normalized.png",
|
||||||
|
"labels.jpg",
|
||||||
|
"labels.png",
|
||||||
|
"BoxPR_curve.png",
|
||||||
|
"BoxP_curve.png",
|
||||||
|
"BoxR_curve.png",
|
||||||
|
"BoxF1_curve.png",
|
||||||
|
"MaskPR_curve.png",
|
||||||
|
"MaskP_curve.png",
|
||||||
|
"MaskR_curve.png",
|
||||||
|
"MaskF1_curve.png",
|
||||||
|
"val_batch0_pred.jpg",
|
||||||
|
"val_batch0_labels.jpg",
|
||||||
|
]
|
||||||
|
|
||||||
|
found: List[_PlotItem] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
|
||||||
|
for d in directories:
|
||||||
|
# 1) Preferred
|
||||||
|
for name in preferred_names:
|
||||||
|
p = d / name
|
||||||
|
if p.exists() and p.is_file():
|
||||||
|
key = str(p)
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
found.append(_PlotItem(label=f"{name} (from {d.name})", path=p))
|
||||||
|
|
||||||
|
# 2) Curated globs
|
||||||
|
for pattern in ("train_batch*.jpg", "val_batch*.jpg", "*curve*.png"):
|
||||||
|
for p in sorted(d.glob(pattern)):
|
||||||
|
if not p.is_file():
|
||||||
|
continue
|
||||||
|
key = str(p)
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
found.append(_PlotItem(label=f"{p.name} (from {d.name})", path=p))
|
||||||
|
|
||||||
|
# 3) Fallback: any top-level png/jpg (excluding weights dir contents)
|
||||||
|
for ext in ("*.png", "*.jpg", "*.jpeg", "*.webp"):
|
||||||
|
for p in sorted(d.glob(ext)):
|
||||||
|
if not p.is_file():
|
||||||
|
continue
|
||||||
|
key = str(p)
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
found.append(_PlotItem(label=f"{p.name} (from {d.name})", path=p))
|
||||||
|
|
||||||
|
# Keep list bounded to avoid UI overload for huge runs.
|
||||||
|
return found[:60]
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
"""GUI widgets for the microscopy object detection application."""
|
||||||
|
|
||||||
|
from src.gui.widgets.image_display_widget import ImageDisplayWidget
|
||||||
|
from src.gui.widgets.annotation_canvas_widget import AnnotationCanvasWidget
|
||||||
|
from src.gui.widgets.annotation_tools_widget import AnnotationToolsWidget
|
||||||
|
|
||||||
|
__all__ = ["ImageDisplayWidget", "AnnotationCanvasWidget", "AnnotationToolsWidget"]
|
||||||
|
|||||||
930
src/gui/widgets/annotation_canvas_widget.py
Normal file
930
src/gui/widgets/annotation_canvas_widget.py
Normal file
@@ -0,0 +1,930 @@
|
|||||||
|
"""
|
||||||
|
Annotation canvas widget for drawing annotations on images.
|
||||||
|
Currently supports polyline drawing tool with color selection for manual annotation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
|
||||||
|
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea
|
||||||
|
from PySide6.QtGui import (
|
||||||
|
QPixmap,
|
||||||
|
QImage,
|
||||||
|
QPainter,
|
||||||
|
QPen,
|
||||||
|
QColor,
|
||||||
|
QKeyEvent,
|
||||||
|
QMouseEvent,
|
||||||
|
QPaintEvent,
|
||||||
|
QPolygonF,
|
||||||
|
)
|
||||||
|
from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect, QTimer
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from src.utils.image import Image, ImageLoadError
|
||||||
|
from src.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def perpendicular_distance(
|
||||||
|
point: Tuple[float, float],
|
||||||
|
start: Tuple[float, float],
|
||||||
|
end: Tuple[float, float],
|
||||||
|
) -> float:
|
||||||
|
"""Perpendicular distance from `point` to the line defined by `start`->`end`."""
|
||||||
|
(x, y), (x1, y1), (x2, y2) = point, start, end
|
||||||
|
dx = x2 - x1
|
||||||
|
dy = y2 - y1
|
||||||
|
if dx == 0.0 and dy == 0.0:
|
||||||
|
return math.hypot(x - x1, y - y1)
|
||||||
|
num = abs(dy * x - dx * y + x2 * y1 - y2 * x1)
|
||||||
|
den = math.hypot(dx, dy)
|
||||||
|
return num / den
|
||||||
|
|
||||||
|
|
||||||
|
def rdp(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float, float]]:
|
||||||
|
"""
|
||||||
|
Recursive Ramer-Douglas-Peucker (RDP) polyline simplification.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
points: List of (x, y) points.
|
||||||
|
epsilon: Maximum allowed perpendicular distance in pixels.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Simplified list of (x, y) points including first and last.
|
||||||
|
"""
|
||||||
|
if len(points) <= 2:
|
||||||
|
return list(points)
|
||||||
|
|
||||||
|
start = points[0]
|
||||||
|
end = points[-1]
|
||||||
|
max_dist = -1.0
|
||||||
|
index = -1
|
||||||
|
|
||||||
|
for i in range(1, len(points) - 1):
|
||||||
|
d = perpendicular_distance(points[i], start, end)
|
||||||
|
if d > max_dist:
|
||||||
|
max_dist = d
|
||||||
|
index = i
|
||||||
|
|
||||||
|
if max_dist > epsilon:
|
||||||
|
# Recursive split
|
||||||
|
left = rdp(points[: index + 1], epsilon)
|
||||||
|
right = rdp(points[index:], epsilon)
|
||||||
|
# Concatenate but avoid duplicate at split point
|
||||||
|
return left[:-1] + right
|
||||||
|
|
||||||
|
# Keep only start and end
|
||||||
|
return [start, end]
|
||||||
|
|
||||||
|
|
||||||
|
def simplify_polyline(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float, float]]:
|
||||||
|
"""
|
||||||
|
Simplify a polyline with RDP while preserving closure semantics.
|
||||||
|
|
||||||
|
If the polyline is closed (first == last), the duplicate last point is removed
|
||||||
|
before simplification and then re-added after simplification.
|
||||||
|
"""
|
||||||
|
if not points:
|
||||||
|
return []
|
||||||
|
|
||||||
|
pts = [(float(x), float(y)) for x, y in points]
|
||||||
|
closed = False
|
||||||
|
|
||||||
|
if len(pts) >= 2 and pts[0] == pts[-1]:
|
||||||
|
closed = True
|
||||||
|
pts = pts[:-1] # remove duplicate last for simplification
|
||||||
|
|
||||||
|
if len(pts) <= 2:
|
||||||
|
simplified = list(pts)
|
||||||
|
else:
|
||||||
|
simplified = rdp(pts, epsilon)
|
||||||
|
|
||||||
|
if closed and simplified:
|
||||||
|
if simplified[0] != simplified[-1]:
|
||||||
|
simplified.append(simplified[0])
|
||||||
|
|
||||||
|
return simplified
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationCanvasWidget(QWidget):
|
||||||
|
"""
|
||||||
|
Widget for displaying images and drawing annotations with zoom and drawing tools.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Display images with zoom functionality
|
||||||
|
- Polyline tool for drawing annotations
|
||||||
|
- Configurable pen color and width
|
||||||
|
- Mouse-based drawing interface
|
||||||
|
- Zoom in/out with mouse wheel and keyboard
|
||||||
|
|
||||||
|
Signals:
|
||||||
|
zoom_changed: Emitted when zoom level changes (float zoom_scale)
|
||||||
|
annotation_drawn: Emitted when a new stroke is completed (list of points)
|
||||||
|
"""
|
||||||
|
|
||||||
|
zoom_changed = Signal(float)
|
||||||
|
annotation_drawn = Signal(list) # List of (x, y) points in normalized coordinates
|
||||||
|
# Emitted when the user selects an existing polyline on the canvas.
|
||||||
|
# Carries the associated annotation_id (int) or None if selection is cleared
|
||||||
|
annotation_selected = Signal(object)
|
||||||
|
|
||||||
|
def __init__(self, parent=None):
|
||||||
|
"""Initialize the annotation canvas widget."""
|
||||||
|
super().__init__(parent)
|
||||||
|
|
||||||
|
self.current_image = None
|
||||||
|
self.original_pixmap = None
|
||||||
|
self.annotation_pixmap = None # Overlay for annotations
|
||||||
|
self.zoom_scale = 1.0
|
||||||
|
self.zoom_min = 0.1
|
||||||
|
self.zoom_max = 10.0
|
||||||
|
self.zoom_step = 0.1
|
||||||
|
self.zoom_wheel_step = 0.15
|
||||||
|
|
||||||
|
# Auto-fit behavior (opt-in): when enabled, newly loaded images (and resizes)
|
||||||
|
# will scale to fill the available viewport while preserving aspect ratio.
|
||||||
|
self._auto_fit_to_view: bool = False
|
||||||
|
|
||||||
|
# Drawing / interaction state
|
||||||
|
self.is_drawing = False
|
||||||
|
self.polyline_enabled = False
|
||||||
|
self.polyline_pen_color = QColor(255, 0, 0, 128) # Default red with 50% alpha
|
||||||
|
self.polyline_pen_width = 3
|
||||||
|
self.show_bboxes: bool = True # Control visibility of bounding boxes
|
||||||
|
|
||||||
|
# Current stroke and stored polylines (in image coordinates, pixel units)
|
||||||
|
self.current_stroke: List[Tuple[float, float]] = []
|
||||||
|
self.polylines: List[List[Tuple[float, float]]] = []
|
||||||
|
self.stroke_meta: List[Dict[str, Any]] = [] # per-polyline style (color, width)
|
||||||
|
# Optional DB annotation_id for each stored polyline (None for temporary / unsaved)
|
||||||
|
self.polyline_annotation_ids: List[Optional[int]] = []
|
||||||
|
# Indices in self.polylines of the currently selected polylines (multi-select)
|
||||||
|
self.selected_polyline_indices: List[int] = []
|
||||||
|
|
||||||
|
# Stored bounding boxes in normalized coordinates (x_min, y_min, x_max, y_max)
|
||||||
|
self.bboxes: List[List[float]] = []
|
||||||
|
self.bbox_meta: List[Dict[str, Any]] = [] # per-bbox style (color, width)
|
||||||
|
|
||||||
|
# Legacy collection of strokes in normalized coordinates (kept for API compatibility)
|
||||||
|
self.all_strokes: List[dict] = []
|
||||||
|
|
||||||
|
# RDP simplification parameters (in pixels)
|
||||||
|
self.simplify_on_finish: bool = True
|
||||||
|
self.simplify_epsilon: float = 2.0
|
||||||
|
self.sample_threshold: float = 2.0 # minimum movement to sample a new point
|
||||||
|
|
||||||
|
self._setup_ui()
|
||||||
|
|
||||||
|
def set_auto_fit_to_view(self, enabled: bool):
|
||||||
|
"""Enable/disable automatic zoom-to-fit behavior."""
|
||||||
|
self._auto_fit_to_view = bool(enabled)
|
||||||
|
if self._auto_fit_to_view and self.original_pixmap is not None:
|
||||||
|
QTimer.singleShot(0, self.fit_to_view)
|
||||||
|
|
||||||
|
def fit_to_view(self, padding_px: int = 6):
|
||||||
|
"""Zoom the image so it fits the scroll area's viewport (aspect preserved)."""
|
||||||
|
if self.original_pixmap is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
viewport = self.scroll_area.viewport().size()
|
||||||
|
available_w = max(1, int(viewport.width()) - int(padding_px))
|
||||||
|
available_h = max(1, int(viewport.height()) - int(padding_px))
|
||||||
|
|
||||||
|
img_w = max(1, int(self.original_pixmap.width()))
|
||||||
|
img_h = max(1, int(self.original_pixmap.height()))
|
||||||
|
|
||||||
|
scale_w = available_w / img_w
|
||||||
|
scale_h = available_h / img_h
|
||||||
|
new_scale = min(scale_w, scale_h)
|
||||||
|
new_scale = max(self.zoom_min, min(self.zoom_max, float(new_scale)))
|
||||||
|
|
||||||
|
if abs(new_scale - self.zoom_scale) < 1e-4:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.zoom_scale = new_scale
|
||||||
|
self._apply_zoom()
|
||||||
|
|
||||||
|
def _setup_ui(self):
|
||||||
|
"""Setup user interface."""
|
||||||
|
layout = QVBoxLayout()
|
||||||
|
layout.setContentsMargins(0, 0, 0, 0)
|
||||||
|
|
||||||
|
# Scroll area for canvas
|
||||||
|
self.scroll_area = QScrollArea()
|
||||||
|
self.scroll_area.setWidgetResizable(True)
|
||||||
|
self.scroll_area.setMinimumHeight(400)
|
||||||
|
|
||||||
|
self.canvas_label = QLabel("No image loaded")
|
||||||
|
self.canvas_label.setAlignment(Qt.AlignCenter)
|
||||||
|
self.canvas_label.setStyleSheet("QLabel { background-color: #2b2b2b; color: #888; }")
|
||||||
|
self.canvas_label.setScaledContents(False)
|
||||||
|
self.canvas_label.setMouseTracking(True)
|
||||||
|
|
||||||
|
self.scroll_area.setWidget(self.canvas_label)
|
||||||
|
self.scroll_area.viewport().installEventFilter(self)
|
||||||
|
|
||||||
|
layout.addWidget(self.scroll_area)
|
||||||
|
self.setLayout(layout)
|
||||||
|
|
||||||
|
self.setFocusPolicy(Qt.StrongFocus)
|
||||||
|
|
||||||
|
def load_image(self, image: Image):
|
||||||
|
"""
|
||||||
|
Load and display an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Image object to display
|
||||||
|
"""
|
||||||
|
self.current_image = image
|
||||||
|
self.zoom_scale = 1.0
|
||||||
|
self.clear_annotations()
|
||||||
|
self._display_image()
|
||||||
|
|
||||||
|
# Defer fit-to-view until the widget has a valid viewport size.
|
||||||
|
if self._auto_fit_to_view:
|
||||||
|
QTimer.singleShot(0, self.fit_to_view)
|
||||||
|
|
||||||
|
logger.debug(f"Loaded image into annotation canvas: {image.width}x{image.height}")
|
||||||
|
|
||||||
|
def resizeEvent(self, event):
|
||||||
|
"""Optionally keep the image fitted when the widget is resized."""
|
||||||
|
super().resizeEvent(event)
|
||||||
|
if self._auto_fit_to_view and self.original_pixmap is not None:
|
||||||
|
QTimer.singleShot(0, self.fit_to_view)
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
"""Clear the displayed image and all annotations."""
|
||||||
|
self.current_image = None
|
||||||
|
self.original_pixmap = None
|
||||||
|
self.annotation_pixmap = None
|
||||||
|
self.zoom_scale = 1.0
|
||||||
|
self.clear_annotations()
|
||||||
|
self.canvas_label.setText("No image loaded")
|
||||||
|
self.canvas_label.setPixmap(QPixmap())
|
||||||
|
|
||||||
|
def clear_annotations(self):
|
||||||
|
"""Clear all drawn annotations."""
|
||||||
|
self.all_strokes = []
|
||||||
|
self.current_stroke = []
|
||||||
|
self.polylines = []
|
||||||
|
self.stroke_meta = []
|
||||||
|
self.polyline_annotation_ids = []
|
||||||
|
self.selected_polyline_indices = []
|
||||||
|
self.bboxes = []
|
||||||
|
self.bbox_meta = []
|
||||||
|
self.is_drawing = False
|
||||||
|
if self.annotation_pixmap:
|
||||||
|
self.annotation_pixmap.fill(Qt.transparent)
|
||||||
|
self._update_display()
|
||||||
|
|
||||||
|
def _display_image(self):
|
||||||
|
"""Display the current image in the canvas."""
|
||||||
|
if self.current_image is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get image data in a format compatible with Qt
|
||||||
|
if self.current_image.channels in (3, 4):
|
||||||
|
image_data = self.current_image.get_rgb()
|
||||||
|
else:
|
||||||
|
image_data = self.current_image.get_qt_rgb()
|
||||||
|
|
||||||
|
height, width = image_data.shape[:2]
|
||||||
|
bytes_per_line = image_data.strides[0]
|
||||||
|
|
||||||
|
qimage = QImage(
|
||||||
|
image_data.data,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
bytes_per_line,
|
||||||
|
QImage.Format_RGBX32FPx4, # self.current_image.qtimage_format,
|
||||||
|
).copy() # Copy so Qt owns the buffer even after numpy array goes out of scope
|
||||||
|
|
||||||
|
self.original_pixmap = QPixmap.fromImage(qimage)
|
||||||
|
|
||||||
|
# Create transparent overlay for annotations
|
||||||
|
self.annotation_pixmap = QPixmap(self.original_pixmap.size())
|
||||||
|
self.annotation_pixmap.fill(Qt.transparent)
|
||||||
|
|
||||||
|
self._apply_zoom()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error displaying image: {e}")
|
||||||
|
raise ImageLoadError(f"Failed to display image: {str(e)}")
|
||||||
|
|
||||||
|
def _apply_zoom(self):
|
||||||
|
"""Apply current zoom level to the displayed image."""
|
||||||
|
if self.original_pixmap is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
scaled_width = int(self.original_pixmap.width() * self.zoom_scale)
|
||||||
|
scaled_height = int(self.original_pixmap.height() * self.zoom_scale)
|
||||||
|
|
||||||
|
# Scale both image and annotations
|
||||||
|
scaled_image = self.original_pixmap.scaled(
|
||||||
|
scaled_width,
|
||||||
|
scaled_height,
|
||||||
|
Qt.KeepAspectRatio,
|
||||||
|
(Qt.SmoothTransformation if self.zoom_scale >= 1.0 else Qt.FastTransformation),
|
||||||
|
)
|
||||||
|
|
||||||
|
scaled_annotations = self.annotation_pixmap.scaled(
|
||||||
|
scaled_width,
|
||||||
|
scaled_height,
|
||||||
|
Qt.KeepAspectRatio,
|
||||||
|
(Qt.SmoothTransformation if self.zoom_scale >= 1.0 else Qt.FastTransformation),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Composite image and annotations
|
||||||
|
combined = QPixmap(scaled_image.size())
|
||||||
|
painter = QPainter(combined)
|
||||||
|
painter.drawPixmap(0, 0, scaled_image)
|
||||||
|
painter.drawPixmap(0, 0, scaled_annotations)
|
||||||
|
painter.end()
|
||||||
|
|
||||||
|
self.canvas_label.setPixmap(combined)
|
||||||
|
self.canvas_label.setScaledContents(False)
|
||||||
|
self.canvas_label.adjustSize()
|
||||||
|
|
||||||
|
self.zoom_changed.emit(self.zoom_scale)
|
||||||
|
|
||||||
|
def _update_display(self):
|
||||||
|
"""Update display after drawing."""
|
||||||
|
self._apply_zoom()
|
||||||
|
|
||||||
|
def set_polyline_enabled(self, enabled: bool):
|
||||||
|
"""Enable or disable polyline tool."""
|
||||||
|
self.polyline_enabled = enabled
|
||||||
|
if enabled:
|
||||||
|
self.canvas_label.setCursor(Qt.CrossCursor)
|
||||||
|
else:
|
||||||
|
self.canvas_label.setCursor(Qt.ArrowCursor)
|
||||||
|
|
||||||
|
def set_polyline_pen_color(self, color: QColor):
|
||||||
|
"""Set polyline pen color."""
|
||||||
|
self.polyline_pen_color = color
|
||||||
|
|
||||||
|
def set_polyline_pen_width(self, width: int):
|
||||||
|
"""Set polyline pen width."""
|
||||||
|
self.polyline_pen_width = max(1, width)
|
||||||
|
|
||||||
|
def get_zoom_percentage(self) -> int:
|
||||||
|
"""Get current zoom level as percentage."""
|
||||||
|
return int(self.zoom_scale * 100)
|
||||||
|
|
||||||
|
def zoom_in(self):
|
||||||
|
"""Zoom in on the image."""
|
||||||
|
if self.original_pixmap is None:
|
||||||
|
return
|
||||||
|
new_scale = self.zoom_scale + self.zoom_step
|
||||||
|
if new_scale <= self.zoom_max:
|
||||||
|
self.zoom_scale = new_scale
|
||||||
|
self._apply_zoom()
|
||||||
|
|
||||||
|
def zoom_out(self):
|
||||||
|
"""Zoom out from the image."""
|
||||||
|
if self.original_pixmap is None:
|
||||||
|
return
|
||||||
|
new_scale = self.zoom_scale - self.zoom_step
|
||||||
|
if new_scale >= self.zoom_min:
|
||||||
|
self.zoom_scale = new_scale
|
||||||
|
self._apply_zoom()
|
||||||
|
|
||||||
|
def reset_zoom(self):
|
||||||
|
"""Reset zoom to 100%."""
|
||||||
|
if self.original_pixmap is None:
|
||||||
|
return
|
||||||
|
self.zoom_scale = 1.0
|
||||||
|
self._apply_zoom()
|
||||||
|
|
||||||
|
def _canvas_to_image_coords(self, pos: QPoint) -> Optional[Tuple[int, int]]:
|
||||||
|
"""Convert canvas coordinates to image coordinates, accounting for zoom and centering."""
|
||||||
|
if self.original_pixmap is None or self.canvas_label.pixmap() is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get the displayed pixmap size (after zoom)
|
||||||
|
displayed_pixmap = self.canvas_label.pixmap()
|
||||||
|
displayed_width = displayed_pixmap.width()
|
||||||
|
displayed_height = displayed_pixmap.height()
|
||||||
|
|
||||||
|
# Calculate offset due to label centering (label might be larger than pixmap)
|
||||||
|
label_width = self.canvas_label.width()
|
||||||
|
label_height = self.canvas_label.height()
|
||||||
|
offset_x = max(0, (label_width - displayed_width) // 2)
|
||||||
|
offset_y = max(0, (label_height - displayed_height) // 2)
|
||||||
|
|
||||||
|
# Adjust position for offset and convert to image coordinates
|
||||||
|
x = (pos.x() - offset_x) / self.zoom_scale
|
||||||
|
y = (pos.y() - offset_y) / self.zoom_scale
|
||||||
|
|
||||||
|
# Check bounds
|
||||||
|
if 0 <= x < self.original_pixmap.width() and 0 <= y < self.original_pixmap.height():
|
||||||
|
return (int(x), int(y))
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _find_polyline_at(self, img_x: float, img_y: float, threshold_px: float = 5.0) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
Find index of polyline whose geometry is within threshold_px of (img_x, img_y).
|
||||||
|
Returns the index in self.polylines, or None if none is close enough.
|
||||||
|
"""
|
||||||
|
best_index: Optional[int] = None
|
||||||
|
best_dist: float = float("inf")
|
||||||
|
|
||||||
|
for idx, polyline in enumerate(self.polylines):
|
||||||
|
if len(polyline) < 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Quick bounding-box check to skip obviously distant polylines
|
||||||
|
xs = [p[0] for p in polyline]
|
||||||
|
ys = [p[1] for p in polyline]
|
||||||
|
if img_x < min(xs) - threshold_px or img_x > max(xs) + threshold_px:
|
||||||
|
continue
|
||||||
|
if img_y < min(ys) - threshold_px or img_y > max(ys) + threshold_px:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Precise distance to all segments
|
||||||
|
for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]):
|
||||||
|
d = perpendicular_distance((img_x, img_y), (float(x1), float(y1)), (float(x2), float(y2)))
|
||||||
|
if d < best_dist:
|
||||||
|
best_dist = d
|
||||||
|
best_index = idx
|
||||||
|
|
||||||
|
if best_index is not None and best_dist <= threshold_px:
|
||||||
|
return best_index
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _image_to_normalized_coords(self, x: int, y: int) -> Tuple[float, float]:
|
||||||
|
"""Convert image coordinates to normalized coordinates (0-1)."""
|
||||||
|
if self.original_pixmap is None:
|
||||||
|
return (0.0, 0.0)
|
||||||
|
|
||||||
|
norm_x = x / self.original_pixmap.width()
|
||||||
|
norm_y = y / self.original_pixmap.height()
|
||||||
|
return (norm_x, norm_y)
|
||||||
|
|
||||||
|
def _add_polyline(
|
||||||
|
self,
|
||||||
|
img_points: List[Tuple[float, float]],
|
||||||
|
color: QColor,
|
||||||
|
width: int,
|
||||||
|
annotation_id: Optional[int] = None,
|
||||||
|
):
|
||||||
|
"""Store a polyline in image coordinates and redraw annotations."""
|
||||||
|
if not img_points or len(img_points) < 2:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Ensure all points are tuples of floats
|
||||||
|
normalized_points = [(float(x), float(y)) for x, y in img_points]
|
||||||
|
self.polylines.append(normalized_points)
|
||||||
|
self.stroke_meta.append({"color": QColor(color), "width": int(width)})
|
||||||
|
self.polyline_annotation_ids.append(annotation_id)
|
||||||
|
|
||||||
|
self._redraw_annotations()
|
||||||
|
|
||||||
|
def _redraw_annotations(self):
|
||||||
|
"""Redraw all stored polylines and (optionally) bounding boxes onto the annotation pixmap."""
|
||||||
|
if self.annotation_pixmap is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Clear existing overlay
|
||||||
|
self.annotation_pixmap.fill(Qt.transparent)
|
||||||
|
|
||||||
|
painter = QPainter(self.annotation_pixmap)
|
||||||
|
|
||||||
|
# Draw polylines
|
||||||
|
for idx, (polyline, meta) in enumerate(zip(self.polylines, self.stroke_meta)):
|
||||||
|
pen_color: QColor = meta.get("color", self.polyline_pen_color)
|
||||||
|
width: int = meta.get("width", self.polyline_pen_width)
|
||||||
|
|
||||||
|
if idx in self.selected_polyline_indices:
|
||||||
|
# Highlight selected polylines in a distinct color / width
|
||||||
|
highlight_color = QColor(255, 255, 0, 200) # yellow, semi-opaque
|
||||||
|
pen = QPen(
|
||||||
|
highlight_color,
|
||||||
|
width + 1,
|
||||||
|
Qt.SolidLine,
|
||||||
|
Qt.RoundCap,
|
||||||
|
Qt.RoundJoin,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pen = QPen(
|
||||||
|
pen_color,
|
||||||
|
width,
|
||||||
|
Qt.SolidLine,
|
||||||
|
Qt.RoundCap,
|
||||||
|
Qt.RoundJoin,
|
||||||
|
)
|
||||||
|
|
||||||
|
painter.setPen(pen)
|
||||||
|
# Use QPolygonF for efficient polygon rendering (single call vs N-1 calls)
|
||||||
|
# drawPolygon() automatically closes the shape, ensuring proper visual closure
|
||||||
|
polygon = QPolygonF([QPointF(x, y) for x, y in polyline])
|
||||||
|
painter.drawPolygon(polygon)
|
||||||
|
|
||||||
|
# Draw bounding boxes (dashed) if enabled
|
||||||
|
if self.show_bboxes and self.original_pixmap is not None and self.bboxes:
|
||||||
|
img_width = float(self.original_pixmap.width())
|
||||||
|
img_height = float(self.original_pixmap.height())
|
||||||
|
|
||||||
|
for bbox, meta in zip(self.bboxes, self.bbox_meta):
|
||||||
|
if len(bbox) != 4:
|
||||||
|
continue
|
||||||
|
|
||||||
|
x_min_norm, y_min_norm, x_max_norm, y_max_norm = bbox
|
||||||
|
x_min = int(x_min_norm * img_width)
|
||||||
|
y_min = int(y_min_norm * img_height)
|
||||||
|
x_max = int(x_max_norm * img_width)
|
||||||
|
y_max = int(y_max_norm * img_height)
|
||||||
|
|
||||||
|
rect_width = x_max - x_min
|
||||||
|
rect_height = y_max - y_min
|
||||||
|
|
||||||
|
pen_color: QColor = meta.get("color", QColor(255, 0, 0, 128))
|
||||||
|
width: int = meta.get("width", self.polyline_pen_width)
|
||||||
|
pen = QPen(
|
||||||
|
pen_color,
|
||||||
|
width,
|
||||||
|
Qt.DashLine,
|
||||||
|
Qt.SquareCap,
|
||||||
|
Qt.MiterJoin,
|
||||||
|
)
|
||||||
|
painter.setPen(pen)
|
||||||
|
painter.drawRect(x_min, y_min, rect_width, rect_height)
|
||||||
|
|
||||||
|
label_text = meta.get("label")
|
||||||
|
if label_text:
|
||||||
|
painter.save()
|
||||||
|
font = painter.font()
|
||||||
|
font.setPointSizeF(max(10.0, width + 4))
|
||||||
|
painter.setFont(font)
|
||||||
|
metrics = painter.fontMetrics()
|
||||||
|
text_width = metrics.horizontalAdvance(label_text)
|
||||||
|
text_height = metrics.height()
|
||||||
|
padding = 4
|
||||||
|
bg_width = text_width + padding * 2
|
||||||
|
bg_height = text_height + padding * 2
|
||||||
|
canvas_width = self.original_pixmap.width()
|
||||||
|
canvas_height = self.original_pixmap.height()
|
||||||
|
bg_x = max(0, min(x_min, canvas_width - bg_width))
|
||||||
|
bg_y = y_min - bg_height
|
||||||
|
if bg_y < 0:
|
||||||
|
bg_y = min(y_min, canvas_height - bg_height)
|
||||||
|
bg_y = max(0, bg_y)
|
||||||
|
background_rect = QRect(bg_x, bg_y, bg_width, bg_height)
|
||||||
|
background_color = QColor(pen_color)
|
||||||
|
background_color.setAlpha(220)
|
||||||
|
painter.fillRect(background_rect, background_color)
|
||||||
|
text_color = QColor(0, 0, 0)
|
||||||
|
if background_color.lightness() < 128:
|
||||||
|
text_color = QColor(255, 255, 255)
|
||||||
|
painter.setPen(text_color)
|
||||||
|
painter.drawText(
|
||||||
|
background_rect.adjusted(padding, padding, -padding, -padding),
|
||||||
|
Qt.AlignLeft | Qt.AlignVCenter,
|
||||||
|
label_text,
|
||||||
|
)
|
||||||
|
painter.restore()
|
||||||
|
|
||||||
|
painter.end()
|
||||||
|
|
||||||
|
self._update_display()
|
||||||
|
|
||||||
|
def mousePressEvent(self, event: QMouseEvent):
|
||||||
|
"""Handle mouse press events for drawing and selecting polylines."""
|
||||||
|
if self.annotation_pixmap is None:
|
||||||
|
super().mousePressEvent(event)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Map click to image coordinates
|
||||||
|
label_pos = self.canvas_label.mapFromGlobal(event.globalPos())
|
||||||
|
img_coords = self._canvas_to_image_coords(label_pos)
|
||||||
|
|
||||||
|
# Left button + drawing tool enabled -> start a new stroke
|
||||||
|
if event.button() == Qt.LeftButton and self.polyline_enabled:
|
||||||
|
if img_coords:
|
||||||
|
self.is_drawing = True
|
||||||
|
self.current_stroke = [(float(img_coords[0]), float(img_coords[1]))]
|
||||||
|
return
|
||||||
|
|
||||||
|
# Left button + drawing tool disabled -> attempt selection of existing polyline
|
||||||
|
if event.button() == Qt.LeftButton and not self.polyline_enabled:
|
||||||
|
if img_coords:
|
||||||
|
idx = self._find_polyline_at(float(img_coords[0]), float(img_coords[1]))
|
||||||
|
if idx is not None:
|
||||||
|
if event.modifiers() & Qt.ShiftModifier:
|
||||||
|
# Multi-select mode: add to current selection (if not already selected)
|
||||||
|
if idx not in self.selected_polyline_indices:
|
||||||
|
self.selected_polyline_indices.append(idx)
|
||||||
|
else:
|
||||||
|
# Single-select mode: replace current selection
|
||||||
|
self.selected_polyline_indices = [idx]
|
||||||
|
|
||||||
|
# Build list of selected annotation IDs (ignore None entries)
|
||||||
|
selected_ids: List[int] = []
|
||||||
|
for sel_idx in self.selected_polyline_indices:
|
||||||
|
if 0 <= sel_idx < len(self.polyline_annotation_ids):
|
||||||
|
ann_id = self.polyline_annotation_ids[sel_idx]
|
||||||
|
if isinstance(ann_id, int):
|
||||||
|
selected_ids.append(ann_id)
|
||||||
|
|
||||||
|
if selected_ids:
|
||||||
|
self.annotation_selected.emit(selected_ids)
|
||||||
|
else:
|
||||||
|
# No valid DB-backed annotations in selection
|
||||||
|
self.annotation_selected.emit(None)
|
||||||
|
else:
|
||||||
|
# Clicked on empty space -> clear selection
|
||||||
|
self.selected_polyline_indices = []
|
||||||
|
self.annotation_selected.emit(None)
|
||||||
|
|
||||||
|
self._redraw_annotations()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Fallback for other buttons / cases
|
||||||
|
super().mousePressEvent(event)
|
||||||
|
|
||||||
|
def mouseMoveEvent(self, event: QMouseEvent):
|
||||||
|
"""Handle mouse move events for drawing."""
|
||||||
|
if not self.is_drawing or not self.polyline_enabled or self.annotation_pixmap is None:
|
||||||
|
super().mouseMoveEvent(event)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get accurate position using global coordinates
|
||||||
|
label_pos = self.canvas_label.mapFromGlobal(event.globalPos())
|
||||||
|
img_coords = self._canvas_to_image_coords(label_pos)
|
||||||
|
|
||||||
|
if img_coords and len(self.current_stroke) > 0:
|
||||||
|
last_point = self.current_stroke[-1]
|
||||||
|
dx = img_coords[0] - last_point[0]
|
||||||
|
dy = img_coords[1] - last_point[1]
|
||||||
|
|
||||||
|
# Only sample a new point if we moved enough pixels
|
||||||
|
if math.hypot(dx, dy) < self.sample_threshold:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Draw line from last point to current point for interactive feedback
|
||||||
|
painter = QPainter(self.annotation_pixmap)
|
||||||
|
pen = QPen(
|
||||||
|
self.polyline_pen_color,
|
||||||
|
self.polyline_pen_width,
|
||||||
|
Qt.SolidLine,
|
||||||
|
Qt.RoundCap,
|
||||||
|
Qt.RoundJoin,
|
||||||
|
)
|
||||||
|
painter.setPen(pen)
|
||||||
|
painter.drawLine(
|
||||||
|
int(last_point[0]),
|
||||||
|
int(last_point[1]),
|
||||||
|
int(img_coords[0]),
|
||||||
|
int(img_coords[1]),
|
||||||
|
)
|
||||||
|
painter.end()
|
||||||
|
|
||||||
|
self.current_stroke.append((float(img_coords[0]), float(img_coords[1])))
|
||||||
|
self._update_display()
|
||||||
|
|
||||||
|
def mouseReleaseEvent(self, event: QMouseEvent):
|
||||||
|
"""Handle mouse release events to complete a stroke."""
|
||||||
|
if not self.is_drawing or event.button() != Qt.LeftButton:
|
||||||
|
super().mouseReleaseEvent(event)
|
||||||
|
return
|
||||||
|
|
||||||
|
self.is_drawing = False
|
||||||
|
|
||||||
|
if len(self.current_stroke) > 1 and self.original_pixmap is not None:
|
||||||
|
# Ensure the stroke is closed by connecting end -> start
|
||||||
|
raw_points = list(self.current_stroke)
|
||||||
|
if raw_points[0] != raw_points[-1]:
|
||||||
|
raw_points.append(raw_points[0])
|
||||||
|
|
||||||
|
# Optional RDP simplification (in image pixel space)
|
||||||
|
if self.simplify_on_finish:
|
||||||
|
simplified = simplify_polyline(raw_points, self.simplify_epsilon)
|
||||||
|
else:
|
||||||
|
simplified = raw_points
|
||||||
|
|
||||||
|
if len(simplified) >= 2:
|
||||||
|
# Store polyline and redraw all annotations
|
||||||
|
self._add_polyline(simplified, self.polyline_pen_color, self.polyline_pen_width)
|
||||||
|
|
||||||
|
# Convert to normalized coordinates for metadata + signal
|
||||||
|
normalized_stroke = [self._image_to_normalized_coords(int(x), int(y)) for (x, y) in simplified]
|
||||||
|
self.all_strokes.append(
|
||||||
|
{
|
||||||
|
"points": normalized_stroke,
|
||||||
|
"color": self.polyline_pen_color.name(),
|
||||||
|
"alpha": self.polyline_pen_color.alpha(),
|
||||||
|
"width": self.polyline_pen_width,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Emit signal with normalized coordinates
|
||||||
|
self.annotation_drawn.emit(normalized_stroke)
|
||||||
|
logger.debug(
|
||||||
|
f"Completed stroke with {len(simplified)} points " f"(normalized len={len(normalized_stroke)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.current_stroke = []
|
||||||
|
|
||||||
|
def get_all_strokes(self) -> List[dict]:
|
||||||
|
"""Get all drawn strokes with metadata."""
|
||||||
|
return self.all_strokes
|
||||||
|
|
||||||
|
def get_annotation_parameters(self) -> Optional[List[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Get all annotation parameters including bounding box and polyline.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dictionaries, each containing:
|
||||||
|
- 'bbox': [x_min, y_min, x_max, y_max] in normalized image coordinates
|
||||||
|
- 'polyline': List of [y_norm, x_norm] points describing the polygon
|
||||||
|
"""
|
||||||
|
if self.original_pixmap is None or not self.polylines:
|
||||||
|
return None
|
||||||
|
|
||||||
|
img_width = float(self.original_pixmap.width())
|
||||||
|
img_height = float(self.original_pixmap.height())
|
||||||
|
|
||||||
|
params: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
for idx, polyline in enumerate(self.polylines):
|
||||||
|
if len(polyline) < 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
xs = [p[0] for p in polyline]
|
||||||
|
ys = [p[1] for p in polyline]
|
||||||
|
|
||||||
|
x_min_norm = min(xs) / img_width
|
||||||
|
x_max_norm = max(xs) / img_width
|
||||||
|
y_min_norm = min(ys) / img_height
|
||||||
|
y_max_norm = max(ys) / img_height
|
||||||
|
|
||||||
|
# Store polyline as [y_norm, x_norm] to match DB convention and
|
||||||
|
# the expectations of draw_saved_polyline().
|
||||||
|
normalized_polyline = [[y / img_height, x / img_width] for (x, y) in polyline]
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Polyline {idx}: {len(polyline)} points, "
|
||||||
|
f"bbox=({x_min_norm:.3f}, {y_min_norm:.3f})-({x_max_norm:.3f}, {y_max_norm:.3f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
params.append(
|
||||||
|
{
|
||||||
|
"bbox": [x_min_norm, y_min_norm, x_max_norm, y_max_norm],
|
||||||
|
"polyline": normalized_polyline,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return params or None
|
||||||
|
|
||||||
|
def draw_saved_polyline(
|
||||||
|
self,
|
||||||
|
polyline: List[List[float]],
|
||||||
|
color: str,
|
||||||
|
width: int = 1,
|
||||||
|
annotation_id: Optional[int] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Draw a polyline from database coordinates onto the annotation canvas.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
polyline: List of [x, y] coordinate pairs in normalized coordinates (0-1)
|
||||||
|
color: Color hex string (e.g., '#FF0000')
|
||||||
|
width: Line width in pixels
|
||||||
|
"""
|
||||||
|
if not self.annotation_pixmap or not self.original_pixmap:
|
||||||
|
logger.warning("Cannot draw polyline: no image loaded")
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(polyline) < 2:
|
||||||
|
logger.warning("Polyline has less than 2 points, cannot draw")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Convert normalized coordinates to image coordinates
|
||||||
|
# Polyline is stored as [[y_norm, x_norm], ...] (row_norm, col_norm format)
|
||||||
|
img_width = self.original_pixmap.width()
|
||||||
|
img_height = self.original_pixmap.height()
|
||||||
|
|
||||||
|
logger.debug(f"Loading polyline with {len(polyline)} points")
|
||||||
|
logger.debug(f" Image size: {img_width}x{img_height}")
|
||||||
|
logger.debug(f" First 3 normalized points from DB: {polyline[:3]}")
|
||||||
|
|
||||||
|
img_coords: List[Tuple[float, float]] = []
|
||||||
|
for y_norm, x_norm in polyline:
|
||||||
|
x = float(x_norm * img_width)
|
||||||
|
y = float(y_norm * img_height)
|
||||||
|
img_coords.append((x, y))
|
||||||
|
|
||||||
|
logger.debug(f" First 3 pixel coords: {img_coords[:3]}")
|
||||||
|
|
||||||
|
# Store and redraw using common pipeline
|
||||||
|
pen_color = QColor(color)
|
||||||
|
pen_color.setAlpha(255) # Add semi-transparency
|
||||||
|
self._add_polyline(img_coords, pen_color, width, annotation_id=annotation_id)
|
||||||
|
|
||||||
|
# Store in all_strokes for consistency (uses normalized coordinates)
|
||||||
|
self.all_strokes.append({"points": polyline, "color": color, "alpha": 255, "width": width})
|
||||||
|
|
||||||
|
logger.debug(f"Drew saved polyline with {len(polyline)} points in color {color}")
|
||||||
|
|
||||||
|
def draw_saved_bbox(
|
||||||
|
self,
|
||||||
|
bbox: List[float],
|
||||||
|
color: str,
|
||||||
|
width: int = 3,
|
||||||
|
label: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Draw a bounding box from database coordinates onto the annotation canvas.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bbox: Bounding box as [x_min_norm, y_min_norm, x_max_norm, y_max_norm]
|
||||||
|
in normalized coordinates (0-1)
|
||||||
|
color: Color hex string (e.g., '#FF0000')
|
||||||
|
width: Line width in pixels
|
||||||
|
label: Optional text label to render near the bounding box
|
||||||
|
"""
|
||||||
|
if not self.annotation_pixmap or not self.original_pixmap:
|
||||||
|
logger.warning("Cannot draw bounding box: no image loaded")
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(bbox) != 4:
|
||||||
|
logger.warning(f"Invalid bounding box format: expected 4 values, got {len(bbox)}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Convert normalized coordinates to image coordinates (for logging/debug)
|
||||||
|
img_width = self.original_pixmap.width()
|
||||||
|
img_height = self.original_pixmap.height()
|
||||||
|
|
||||||
|
x_min_norm, y_min_norm, x_max_norm, y_max_norm = bbox
|
||||||
|
x_min = int(x_min_norm * img_width)
|
||||||
|
y_min = int(y_min_norm * img_height)
|
||||||
|
x_max = int(x_max_norm * img_width)
|
||||||
|
y_max = int(y_max_norm * img_height)
|
||||||
|
|
||||||
|
logger.debug(f"Drawing bounding box: {bbox}")
|
||||||
|
logger.debug(f" Image size: {img_width}x{img_height}")
|
||||||
|
logger.debug(f" Pixel coords: ({x_min}, {y_min}) to ({x_max}, {y_max})")
|
||||||
|
|
||||||
|
# Store bounding box (normalized) and its style; actual drawing happens
|
||||||
|
# in _redraw_annotations() together with all polylines.
|
||||||
|
pen_color = QColor(color)
|
||||||
|
pen_color.setAlpha(128) # Add semi-transparency
|
||||||
|
self.bboxes.append([float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)])
|
||||||
|
self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label})
|
||||||
|
|
||||||
|
# Store in all_strokes for consistency
|
||||||
|
self.all_strokes.append({"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label})
|
||||||
|
|
||||||
|
# Redraw overlay (polylines + all bounding boxes)
|
||||||
|
self._redraw_annotations()
|
||||||
|
logger.debug(f"Drew saved bounding box in color {color}")
|
||||||
|
|
||||||
|
def set_show_bboxes(self, show: bool):
|
||||||
|
"""
|
||||||
|
Enable or disable drawing of bounding boxes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
show: If True, draw bounding boxes; if False, hide them.
|
||||||
|
"""
|
||||||
|
self.show_bboxes = bool(show)
|
||||||
|
logger.debug(f"Set show_bboxes to {self.show_bboxes}")
|
||||||
|
self._redraw_annotations()
|
||||||
|
|
||||||
|
def keyPressEvent(self, event: QKeyEvent):
|
||||||
|
"""Handle keyboard events for zooming."""
|
||||||
|
if event.key() in (Qt.Key_Plus, Qt.Key_Equal):
|
||||||
|
self.zoom_in()
|
||||||
|
event.accept()
|
||||||
|
elif event.key() == Qt.Key_Minus:
|
||||||
|
self.zoom_out()
|
||||||
|
event.accept()
|
||||||
|
elif event.key() == Qt.Key_0 and event.modifiers() == Qt.ControlModifier:
|
||||||
|
self.reset_zoom()
|
||||||
|
event.accept()
|
||||||
|
else:
|
||||||
|
super().keyPressEvent(event)
|
||||||
|
|
||||||
|
def eventFilter(self, obj, event: QEvent) -> bool:
|
||||||
|
"""Event filter to capture wheel events for zooming."""
|
||||||
|
if event.type() == QEvent.Wheel:
|
||||||
|
wheel_event = event
|
||||||
|
if self.original_pixmap is not None:
|
||||||
|
delta = wheel_event.angleDelta().y()
|
||||||
|
|
||||||
|
if delta > 0:
|
||||||
|
new_scale = self.zoom_scale + self.zoom_wheel_step
|
||||||
|
if new_scale <= self.zoom_max:
|
||||||
|
self.zoom_scale = new_scale
|
||||||
|
self._apply_zoom()
|
||||||
|
else:
|
||||||
|
new_scale = self.zoom_scale - self.zoom_wheel_step
|
||||||
|
if new_scale >= self.zoom_min:
|
||||||
|
self.zoom_scale = new_scale
|
||||||
|
self._apply_zoom()
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
return super().eventFilter(obj, event)
|
||||||
478
src/gui/widgets/annotation_tools_widget.py
Normal file
478
src/gui/widgets/annotation_tools_widget.py
Normal file
@@ -0,0 +1,478 @@
|
|||||||
|
"""
|
||||||
|
Annotation tools widget for controlling annotation parameters.
|
||||||
|
Includes polyline tool, color picker, class selection, and annotation management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from PySide6.QtWidgets import (
|
||||||
|
QWidget,
|
||||||
|
QVBoxLayout,
|
||||||
|
QHBoxLayout,
|
||||||
|
QLabel,
|
||||||
|
QGroupBox,
|
||||||
|
QPushButton,
|
||||||
|
QComboBox,
|
||||||
|
QSpinBox,
|
||||||
|
QDoubleSpinBox,
|
||||||
|
QCheckBox,
|
||||||
|
QColorDialog,
|
||||||
|
QInputDialog,
|
||||||
|
QMessageBox,
|
||||||
|
)
|
||||||
|
from PySide6.QtGui import QColor, QIcon, QPixmap, QPainter
|
||||||
|
from PySide6.QtCore import Qt, Signal
|
||||||
|
from typing import Optional, Dict
|
||||||
|
|
||||||
|
from src.database.db_manager import DatabaseManager
|
||||||
|
from src.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationToolsWidget(QWidget):
|
||||||
|
"""
|
||||||
|
Widget for annotation tool controls.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Enable/disable polyline tool
|
||||||
|
- Color selection for polyline pen
|
||||||
|
- Object class selection
|
||||||
|
- Add new object classes
|
||||||
|
- Pen width control
|
||||||
|
- Clear annotations
|
||||||
|
|
||||||
|
Signals:
|
||||||
|
polyline_enabled_changed: Emitted when polyline tool is enabled/disabled (bool)
|
||||||
|
polyline_pen_color_changed: Emitted when polyline pen color changes (QColor)
|
||||||
|
polyline_pen_width_changed: Emitted when polyline pen width changes (int)
|
||||||
|
class_selected: Emitted when object class is selected (dict)
|
||||||
|
clear_annotations_requested: Emitted when clear button is pressed
|
||||||
|
"""
|
||||||
|
|
||||||
|
polyline_enabled_changed = Signal(bool)
|
||||||
|
polyline_pen_color_changed = Signal(QColor)
|
||||||
|
polyline_pen_width_changed = Signal(int)
|
||||||
|
simplify_on_finish_changed = Signal(bool)
|
||||||
|
simplify_epsilon_changed = Signal(float)
|
||||||
|
# Toggle visibility of bounding boxes on the canvas
|
||||||
|
show_bboxes_changed = Signal(bool)
|
||||||
|
class_selected = Signal(dict)
|
||||||
|
class_color_changed = Signal()
|
||||||
|
clear_annotations_requested = Signal()
|
||||||
|
# Request deletion of the currently selected annotation on the canvas
|
||||||
|
delete_selected_annotation_requested = Signal()
|
||||||
|
|
||||||
|
def __init__(self, db_manager: DatabaseManager, parent=None):
|
||||||
|
"""
|
||||||
|
Initialize annotation tools widget.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_manager: Database manager instance
|
||||||
|
parent: Parent widget
|
||||||
|
"""
|
||||||
|
super().__init__(parent)
|
||||||
|
self.db_manager = db_manager
|
||||||
|
self.polyline_enabled = False
|
||||||
|
self.current_color = QColor(255, 0, 0, 128) # Red with 50% alpha
|
||||||
|
self.current_class = None
|
||||||
|
|
||||||
|
self._setup_ui()
|
||||||
|
self._load_object_classes()
|
||||||
|
|
||||||
|
def _setup_ui(self):
|
||||||
|
"""Setup user interface."""
|
||||||
|
layout = QVBoxLayout()
|
||||||
|
|
||||||
|
# Polyline Tool Group
|
||||||
|
polyline_group = QGroupBox("Polyline Tool")
|
||||||
|
polyline_layout = QVBoxLayout()
|
||||||
|
|
||||||
|
# Enable/Disable polyline tool
|
||||||
|
button_layout = QHBoxLayout()
|
||||||
|
self.polyline_toggle_btn = QPushButton("Start Drawing Polyline")
|
||||||
|
self.polyline_toggle_btn.setCheckable(True)
|
||||||
|
self.polyline_toggle_btn.clicked.connect(self._on_polyline_toggle)
|
||||||
|
button_layout.addWidget(self.polyline_toggle_btn)
|
||||||
|
polyline_layout.addLayout(button_layout)
|
||||||
|
|
||||||
|
# Polyline pen width control
|
||||||
|
width_layout = QHBoxLayout()
|
||||||
|
width_layout.addWidget(QLabel("Pen Width:"))
|
||||||
|
self.polyline_pen_width_spin = QSpinBox()
|
||||||
|
self.polyline_pen_width_spin.setMinimum(1)
|
||||||
|
self.polyline_pen_width_spin.setMaximum(20)
|
||||||
|
self.polyline_pen_width_spin.setValue(3)
|
||||||
|
self.polyline_pen_width_spin.valueChanged.connect(
|
||||||
|
self._on_polyline_pen_width_changed
|
||||||
|
)
|
||||||
|
width_layout.addWidget(self.polyline_pen_width_spin)
|
||||||
|
width_layout.addStretch()
|
||||||
|
polyline_layout.addLayout(width_layout)
|
||||||
|
|
||||||
|
# Simplification controls (RDP)
|
||||||
|
simplify_layout = QHBoxLayout()
|
||||||
|
self.simplify_checkbox = QCheckBox("Simplify on finish")
|
||||||
|
self.simplify_checkbox.setChecked(True)
|
||||||
|
self.simplify_checkbox.stateChanged.connect(self._on_simplify_toggle)
|
||||||
|
simplify_layout.addWidget(self.simplify_checkbox)
|
||||||
|
|
||||||
|
simplify_layout.addWidget(QLabel("epsilon (px):"))
|
||||||
|
self.eps_spin = QDoubleSpinBox()
|
||||||
|
self.eps_spin.setRange(0.0, 1000.0)
|
||||||
|
self.eps_spin.setSingleStep(0.5)
|
||||||
|
self.eps_spin.setValue(2.0)
|
||||||
|
self.eps_spin.valueChanged.connect(self._on_eps_change)
|
||||||
|
simplify_layout.addWidget(self.eps_spin)
|
||||||
|
simplify_layout.addStretch()
|
||||||
|
polyline_layout.addLayout(simplify_layout)
|
||||||
|
|
||||||
|
polyline_group.setLayout(polyline_layout)
|
||||||
|
layout.addWidget(polyline_group)
|
||||||
|
|
||||||
|
# Object Class Group
|
||||||
|
class_group = QGroupBox("Object Class")
|
||||||
|
class_layout = QVBoxLayout()
|
||||||
|
|
||||||
|
# Class selection dropdown
|
||||||
|
self.class_combo = QComboBox()
|
||||||
|
self.class_combo.currentIndexChanged.connect(self._on_class_selected)
|
||||||
|
class_layout.addWidget(self.class_combo)
|
||||||
|
|
||||||
|
# Add / manage classes
|
||||||
|
class_button_layout = QHBoxLayout()
|
||||||
|
self.add_class_btn = QPushButton("Add New Class")
|
||||||
|
self.add_class_btn.clicked.connect(self._on_add_class)
|
||||||
|
class_button_layout.addWidget(self.add_class_btn)
|
||||||
|
|
||||||
|
self.refresh_classes_btn = QPushButton("Refresh")
|
||||||
|
self.refresh_classes_btn.clicked.connect(self._load_object_classes)
|
||||||
|
class_button_layout.addWidget(self.refresh_classes_btn)
|
||||||
|
class_layout.addLayout(class_button_layout)
|
||||||
|
|
||||||
|
# Class color (associated with selected object class)
|
||||||
|
color_layout = QHBoxLayout()
|
||||||
|
color_layout.addWidget(QLabel("Class Color:"))
|
||||||
|
self.color_btn = QPushButton()
|
||||||
|
self.color_btn.setFixedSize(40, 30)
|
||||||
|
self.color_btn.clicked.connect(self._on_color_picker)
|
||||||
|
self._update_color_button()
|
||||||
|
color_layout.addWidget(self.color_btn)
|
||||||
|
color_layout.addStretch()
|
||||||
|
class_layout.addLayout(color_layout)
|
||||||
|
|
||||||
|
# Selected class info
|
||||||
|
self.class_info_label = QLabel("No class selected")
|
||||||
|
self.class_info_label.setWordWrap(True)
|
||||||
|
self.class_info_label.setStyleSheet(
|
||||||
|
"QLabel { color: #888; font-style: italic; }"
|
||||||
|
)
|
||||||
|
class_layout.addWidget(self.class_info_label)
|
||||||
|
|
||||||
|
class_group.setLayout(class_layout)
|
||||||
|
layout.addWidget(class_group)
|
||||||
|
|
||||||
|
# Actions Group
|
||||||
|
actions_group = QGroupBox("Actions")
|
||||||
|
actions_layout = QVBoxLayout()
|
||||||
|
|
||||||
|
# Show / hide bounding boxes
|
||||||
|
self.show_bboxes_checkbox = QCheckBox("Show bounding boxes")
|
||||||
|
self.show_bboxes_checkbox.setChecked(True)
|
||||||
|
self.show_bboxes_checkbox.stateChanged.connect(self._on_show_bboxes_toggle)
|
||||||
|
actions_layout.addWidget(self.show_bboxes_checkbox)
|
||||||
|
|
||||||
|
self.clear_btn = QPushButton("Clear All Annotations")
|
||||||
|
self.clear_btn.clicked.connect(self._on_clear_annotations)
|
||||||
|
actions_layout.addWidget(self.clear_btn)
|
||||||
|
|
||||||
|
# Delete currently selected annotation (enabled when a selection exists)
|
||||||
|
self.delete_selected_btn = QPushButton("Delete Selected Annotation")
|
||||||
|
self.delete_selected_btn.clicked.connect(self._on_delete_selected_annotation)
|
||||||
|
self.delete_selected_btn.setEnabled(False)
|
||||||
|
actions_layout.addWidget(self.delete_selected_btn)
|
||||||
|
|
||||||
|
actions_group.setLayout(actions_layout)
|
||||||
|
layout.addWidget(actions_group)
|
||||||
|
|
||||||
|
layout.addStretch()
|
||||||
|
self.setLayout(layout)
|
||||||
|
|
||||||
|
def _update_color_button(self):
|
||||||
|
"""Update the color button appearance with current color."""
|
||||||
|
pixmap = QPixmap(40, 30)
|
||||||
|
pixmap.fill(self.current_color)
|
||||||
|
|
||||||
|
# Add border
|
||||||
|
painter = QPainter(pixmap)
|
||||||
|
painter.setPen(Qt.black)
|
||||||
|
painter.drawRect(0, 0, pixmap.width() - 1, pixmap.height() - 1)
|
||||||
|
painter.end()
|
||||||
|
|
||||||
|
self.color_btn.setIcon(QIcon(pixmap))
|
||||||
|
self.color_btn.setStyleSheet(f"background-color: {self.current_color.name()};")
|
||||||
|
|
||||||
|
def _load_object_classes(self):
|
||||||
|
"""Load object classes from database and populate combo box."""
|
||||||
|
try:
|
||||||
|
classes = self.db_manager.get_object_classes()
|
||||||
|
|
||||||
|
# Clear and repopulate combo box
|
||||||
|
self.class_combo.clear()
|
||||||
|
self.class_combo.addItem("-- Select Class / Show All --", None)
|
||||||
|
|
||||||
|
for cls in classes:
|
||||||
|
self.class_combo.addItem(cls["class_name"], cls)
|
||||||
|
|
||||||
|
logger.debug(f"Loaded {len(classes)} object classes")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading object classes: {e}")
|
||||||
|
QMessageBox.warning(
|
||||||
|
self, "Error", f"Failed to load object classes:\n{str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _on_polyline_toggle(self, checked: bool):
|
||||||
|
"""Handle polyline tool enable/disable."""
|
||||||
|
self.polyline_enabled = checked
|
||||||
|
|
||||||
|
if checked:
|
||||||
|
self.polyline_toggle_btn.setText("Stop Drawing Polyline")
|
||||||
|
self.polyline_toggle_btn.setStyleSheet(
|
||||||
|
"QPushButton { background-color: #4CAF50; }"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.polyline_toggle_btn.setText("Start Drawing Polyline")
|
||||||
|
self.polyline_toggle_btn.setStyleSheet("")
|
||||||
|
|
||||||
|
self.polyline_enabled_changed.emit(self.polyline_enabled)
|
||||||
|
logger.debug(f"Polyline tool {'enabled' if checked else 'disabled'}")
|
||||||
|
|
||||||
|
def _on_polyline_pen_width_changed(self, width: int):
|
||||||
|
"""Handle polyline pen width changes."""
|
||||||
|
self.polyline_pen_width_changed.emit(width)
|
||||||
|
logger.debug(f"Polyline pen width changed to {width}")
|
||||||
|
|
||||||
|
def _on_simplify_toggle(self, state: int):
|
||||||
|
"""Handle simplify-on-finish checkbox toggle."""
|
||||||
|
enabled = bool(state)
|
||||||
|
self.simplify_on_finish_changed.emit(enabled)
|
||||||
|
logger.debug(f"Simplify on finish set to {enabled}")
|
||||||
|
|
||||||
|
def _on_eps_change(self, val: float):
|
||||||
|
"""Handle epsilon (RDP tolerance) value changes."""
|
||||||
|
epsilon = float(val)
|
||||||
|
self.simplify_epsilon_changed.emit(epsilon)
|
||||||
|
logger.debug(f"Simplification epsilon changed to {epsilon}")
|
||||||
|
|
||||||
|
def _on_show_bboxes_toggle(self, state: int):
|
||||||
|
"""Handle 'Show bounding boxes' checkbox toggle."""
|
||||||
|
show = bool(state)
|
||||||
|
self.show_bboxes_changed.emit(show)
|
||||||
|
logger.debug(f"Show bounding boxes set to {show}")
|
||||||
|
|
||||||
|
def _on_color_picker(self):
|
||||||
|
"""Open color picker dialog and update the selected object's class color."""
|
||||||
|
if not self.current_class:
|
||||||
|
QMessageBox.warning(
|
||||||
|
self,
|
||||||
|
"No Class Selected",
|
||||||
|
"Please select an object class before changing its color.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Use current class color (without alpha) as the base
|
||||||
|
base_color = QColor(self.current_class.get("color", self.current_color.name()))
|
||||||
|
color = QColorDialog.getColor(
|
||||||
|
base_color,
|
||||||
|
self,
|
||||||
|
"Select Class Color",
|
||||||
|
QColorDialog.ShowAlphaChannel, # Allow alpha in UI, but store RGB in DB
|
||||||
|
)
|
||||||
|
|
||||||
|
if not color.isValid():
|
||||||
|
return
|
||||||
|
|
||||||
|
# Normalize to opaque RGB for storage
|
||||||
|
new_color = QColor(color)
|
||||||
|
new_color.setAlpha(255)
|
||||||
|
hex_color = new_color.name()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Update in database
|
||||||
|
self.db_manager.update_object_class(
|
||||||
|
class_id=self.current_class["id"], color=hex_color
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to update class color in database: {e}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Error",
|
||||||
|
f"Failed to update class color in database:\n{str(e)}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Update local class data and combo box item data
|
||||||
|
self.current_class["color"] = hex_color
|
||||||
|
current_index = self.class_combo.currentIndex()
|
||||||
|
if current_index >= 0:
|
||||||
|
self.class_combo.setItemData(current_index, dict(self.current_class))
|
||||||
|
|
||||||
|
# Update info label text
|
||||||
|
info_text = f"Class: {self.current_class['class_name']}\nColor: {hex_color}"
|
||||||
|
if self.current_class.get("description"):
|
||||||
|
info_text += f"\nDescription: {self.current_class['description']}"
|
||||||
|
self.class_info_label.setText(info_text)
|
||||||
|
|
||||||
|
# Use semi-transparent version for polyline pen / button preview
|
||||||
|
class_color = QColor(hex_color)
|
||||||
|
class_color.setAlpha(128)
|
||||||
|
self.current_color = class_color
|
||||||
|
self._update_color_button()
|
||||||
|
self.polyline_pen_color_changed.emit(class_color)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Updated class '{self.current_class['class_name']}' color to "
|
||||||
|
f"{hex_color} (polyline pen alpha={class_color.alpha()})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Notify listeners (e.g., AnnotationTab) so they can reload/redraw
|
||||||
|
self.class_color_changed.emit()
|
||||||
|
|
||||||
|
def _on_class_selected(self, index: int):
|
||||||
|
"""Handle object class selection (including '-- Select Class --')."""
|
||||||
|
class_data = self.class_combo.currentData()
|
||||||
|
|
||||||
|
if class_data:
|
||||||
|
self.current_class = class_data
|
||||||
|
|
||||||
|
# Update info label
|
||||||
|
info_text = (
|
||||||
|
f"Class: {class_data['class_name']}\n" f"Color: {class_data['color']}"
|
||||||
|
)
|
||||||
|
if class_data.get("description"):
|
||||||
|
info_text += f"\nDescription: {class_data['description']}"
|
||||||
|
|
||||||
|
self.class_info_label.setText(info_text)
|
||||||
|
|
||||||
|
# Update polyline pen color to match class color with semi-transparency
|
||||||
|
class_color = QColor(class_data["color"])
|
||||||
|
if class_color.isValid():
|
||||||
|
# Add 50% alpha for semi-transparency
|
||||||
|
class_color.setAlpha(128)
|
||||||
|
self.current_color = class_color
|
||||||
|
self._update_color_button()
|
||||||
|
self.polyline_pen_color_changed.emit(class_color)
|
||||||
|
|
||||||
|
self.class_selected.emit(class_data)
|
||||||
|
logger.debug(f"Selected class: {class_data['class_name']}")
|
||||||
|
else:
|
||||||
|
# "-- Select Class --" chosen: clear current class and show all annotations
|
||||||
|
self.current_class = None
|
||||||
|
self.class_info_label.setText("No class selected")
|
||||||
|
self.class_selected.emit(None)
|
||||||
|
logger.debug("Class selection cleared: showing annotations for all classes")
|
||||||
|
|
||||||
|
def _on_add_class(self):
|
||||||
|
"""Handle adding a new object class."""
|
||||||
|
# Get class name
|
||||||
|
class_name, ok = QInputDialog.getText(
|
||||||
|
self, "Add Object Class", "Enter class name:"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not ok or not class_name.strip():
|
||||||
|
return
|
||||||
|
|
||||||
|
class_name = class_name.strip()
|
||||||
|
|
||||||
|
# Check if class already exists
|
||||||
|
existing = self.db_manager.get_object_class_by_name(class_name)
|
||||||
|
if existing:
|
||||||
|
QMessageBox.warning(
|
||||||
|
self, "Class Exists", f"A class named '{class_name}' already exists."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get color
|
||||||
|
color = QColorDialog.getColor(self.current_color, self, "Select Class Color")
|
||||||
|
|
||||||
|
if not color.isValid():
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get optional description
|
||||||
|
description, ok = QInputDialog.getText(
|
||||||
|
self, "Class Description", "Enter class description (optional):"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not ok:
|
||||||
|
description = None
|
||||||
|
|
||||||
|
# Add to database
|
||||||
|
try:
|
||||||
|
class_id = self.db_manager.add_object_class(
|
||||||
|
class_name, color.name(), description.strip() if description else None
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Added new object class: {class_name} (ID: {class_id})")
|
||||||
|
|
||||||
|
# Reload classes and select the new one
|
||||||
|
self._load_object_classes()
|
||||||
|
|
||||||
|
# Find and select the newly added class
|
||||||
|
for i in range(self.class_combo.count()):
|
||||||
|
class_data = self.class_combo.itemData(i)
|
||||||
|
if class_data and class_data.get("id") == class_id:
|
||||||
|
self.class_combo.setCurrentIndex(i)
|
||||||
|
break
|
||||||
|
|
||||||
|
QMessageBox.information(
|
||||||
|
self, "Success", f"Class '{class_name}' added successfully!"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error adding object class: {e}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self, "Error", f"Failed to add object class:\n{str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _on_clear_annotations(self):
|
||||||
|
"""Handle clear annotations button."""
|
||||||
|
reply = QMessageBox.question(
|
||||||
|
self,
|
||||||
|
"Clear Annotations",
|
||||||
|
"Are you sure you want to clear all annotations?",
|
||||||
|
QMessageBox.Yes | QMessageBox.No,
|
||||||
|
QMessageBox.No,
|
||||||
|
)
|
||||||
|
|
||||||
|
if reply == QMessageBox.Yes:
|
||||||
|
self.clear_annotations_requested.emit()
|
||||||
|
logger.debug("Clear annotations requested")
|
||||||
|
|
||||||
|
def _on_delete_selected_annotation(self):
|
||||||
|
"""Handle delete selected annotation button."""
|
||||||
|
self.delete_selected_annotation_requested.emit()
|
||||||
|
logger.debug("Delete selected annotation requested")
|
||||||
|
|
||||||
|
def set_has_selected_annotation(self, has_selection: bool):
|
||||||
|
"""
|
||||||
|
Enable/disable actions that require a selected annotation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
has_selection: True if an annotation is currently selected on the canvas.
|
||||||
|
"""
|
||||||
|
self.delete_selected_btn.setEnabled(bool(has_selection))
|
||||||
|
|
||||||
|
def get_current_class(self) -> Optional[Dict]:
|
||||||
|
"""Get currently selected object class."""
|
||||||
|
return self.current_class
|
||||||
|
|
||||||
|
def get_polyline_pen_color(self) -> QColor:
|
||||||
|
"""Get current polyline pen color."""
|
||||||
|
return self.current_color
|
||||||
|
|
||||||
|
def get_polyline_pen_width(self) -> int:
|
||||||
|
"""Get current polyline pen width."""
|
||||||
|
return self.polyline_pen_width_spin.value()
|
||||||
|
|
||||||
|
def is_polyline_enabled(self) -> bool:
|
||||||
|
"""Check if polyline tool is enabled."""
|
||||||
|
return self.polyline_enabled
|
||||||
282
src/gui/widgets/image_display_widget.py
Normal file
282
src/gui/widgets/image_display_widget.py
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
"""
|
||||||
|
Image display widget with zoom functionality for the microscopy object detection application.
|
||||||
|
Reusable widget for displaying images with zoom controls.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea
|
||||||
|
from PySide6.QtGui import QPixmap, QImage, QKeyEvent
|
||||||
|
from PySide6.QtCore import Qt, QEvent, Signal
|
||||||
|
from pathlib import Path
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from src.utils.image import Image, ImageLoadError
|
||||||
|
from src.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageDisplayWidget(QWidget):
|
||||||
|
"""
|
||||||
|
Reusable widget for displaying images with zoom functionality.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Display images from Image objects
|
||||||
|
- Zoom in/out with mouse wheel
|
||||||
|
- Zoom in/out with +/- keyboard keys
|
||||||
|
- Reset zoom with Ctrl+0
|
||||||
|
- Scroll area for large images
|
||||||
|
|
||||||
|
Signals:
|
||||||
|
zoom_changed: Emitted when zoom level changes (float zoom_scale)
|
||||||
|
"""
|
||||||
|
|
||||||
|
zoom_changed = Signal(float) # Emitted when zoom level changes
|
||||||
|
|
||||||
|
def __init__(self, parent=None):
|
||||||
|
"""
|
||||||
|
Initialize the image display widget.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parent: Parent widget
|
||||||
|
"""
|
||||||
|
super().__init__(parent)
|
||||||
|
|
||||||
|
self.current_image = None
|
||||||
|
self.original_pixmap = None # Store original pixmap for zoom
|
||||||
|
self.zoom_scale = 1.0 # Current zoom scale
|
||||||
|
self.zoom_min = 0.1 # Minimum zoom (10%)
|
||||||
|
self.zoom_max = 10.0 # Maximum zoom (1000%)
|
||||||
|
self.zoom_step = 0.1 # Zoom step for +/- keys
|
||||||
|
self.zoom_wheel_step = 0.15 # Zoom step for mouse wheel
|
||||||
|
|
||||||
|
self._setup_ui()
|
||||||
|
|
||||||
|
def _setup_ui(self):
|
||||||
|
"""Setup user interface."""
|
||||||
|
layout = QVBoxLayout()
|
||||||
|
layout.setContentsMargins(0, 0, 0, 0)
|
||||||
|
|
||||||
|
# Scroll area for image
|
||||||
|
self.scroll_area = QScrollArea()
|
||||||
|
self.scroll_area.setWidgetResizable(True)
|
||||||
|
self.scroll_area.setMinimumHeight(400)
|
||||||
|
|
||||||
|
self.image_label = QLabel("No image loaded")
|
||||||
|
self.image_label.setAlignment(Qt.AlignCenter)
|
||||||
|
self.image_label.setStyleSheet(
|
||||||
|
"QLabel { background-color: #2b2b2b; color: #888; }"
|
||||||
|
)
|
||||||
|
self.image_label.setScaledContents(False)
|
||||||
|
|
||||||
|
# Enable mouse tracking for wheel events
|
||||||
|
self.image_label.setMouseTracking(True)
|
||||||
|
self.scroll_area.setWidget(self.image_label)
|
||||||
|
|
||||||
|
# Install event filter to capture wheel events on scroll area
|
||||||
|
self.scroll_area.viewport().installEventFilter(self)
|
||||||
|
|
||||||
|
layout.addWidget(self.scroll_area)
|
||||||
|
self.setLayout(layout)
|
||||||
|
|
||||||
|
# Set focus policy to receive keyboard events
|
||||||
|
self.setFocusPolicy(Qt.StrongFocus)
|
||||||
|
|
||||||
|
def load_image(self, image: Image):
|
||||||
|
"""
|
||||||
|
Load and display an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Image object to display
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImageLoadError: If image cannot be displayed
|
||||||
|
"""
|
||||||
|
self.current_image = image
|
||||||
|
|
||||||
|
# Reset zoom when loading new image
|
||||||
|
self.zoom_scale = 1.0
|
||||||
|
|
||||||
|
# Convert to QPixmap and display
|
||||||
|
self._display_image()
|
||||||
|
|
||||||
|
logger.debug(f"Loaded image into display widget: {image.width}x{image.height}")
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
"""Clear the displayed image."""
|
||||||
|
self.current_image = None
|
||||||
|
self.original_pixmap = None
|
||||||
|
self.zoom_scale = 1.0
|
||||||
|
self.image_label.setText("No image loaded")
|
||||||
|
self.image_label.setPixmap(QPixmap())
|
||||||
|
logger.debug("Cleared image display")
|
||||||
|
|
||||||
|
def _display_image(self):
|
||||||
|
"""Display the current image in the image label."""
|
||||||
|
if self.current_image is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get RGB image data
|
||||||
|
if self.current_image.channels == 3:
|
||||||
|
image_data = self.current_image.get_rgb()
|
||||||
|
height, width, channels = image_data.shape
|
||||||
|
else:
|
||||||
|
image_data = self.current_image.get_grayscale()
|
||||||
|
height, width = image_data.shape
|
||||||
|
channels = 1
|
||||||
|
|
||||||
|
# Ensure data is contiguous for proper QImage display
|
||||||
|
image_data = np.ascontiguousarray(image_data)
|
||||||
|
|
||||||
|
# Use actual stride from numpy array for correct display
|
||||||
|
bytes_per_line = image_data.strides[0]
|
||||||
|
|
||||||
|
qimage = QImage(
|
||||||
|
image_data.data,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
bytes_per_line,
|
||||||
|
self.current_image.qtimage_format,
|
||||||
|
).copy() # Copy to ensure Qt owns its memory after this scope
|
||||||
|
|
||||||
|
# Convert to pixmap
|
||||||
|
pixmap = QPixmap.fromImage(qimage)
|
||||||
|
|
||||||
|
# Store original pixmap for zooming
|
||||||
|
self.original_pixmap = pixmap
|
||||||
|
|
||||||
|
# Apply zoom and display
|
||||||
|
self._apply_zoom()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error displaying image: {e}")
|
||||||
|
raise ImageLoadError(f"Failed to display image: {str(e)}")
|
||||||
|
|
||||||
|
def _apply_zoom(self):
|
||||||
|
"""Apply current zoom level to the displayed image."""
|
||||||
|
if self.original_pixmap is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Calculate scaled size
|
||||||
|
scaled_width = int(self.original_pixmap.width() * self.zoom_scale)
|
||||||
|
scaled_height = int(self.original_pixmap.height() * self.zoom_scale)
|
||||||
|
|
||||||
|
# Scale pixmap
|
||||||
|
scaled_pixmap = self.original_pixmap.scaled(
|
||||||
|
scaled_width,
|
||||||
|
scaled_height,
|
||||||
|
Qt.KeepAspectRatio,
|
||||||
|
(
|
||||||
|
Qt.SmoothTransformation
|
||||||
|
if self.zoom_scale >= 1.0
|
||||||
|
else Qt.FastTransformation
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Display in label
|
||||||
|
self.image_label.setPixmap(scaled_pixmap)
|
||||||
|
self.image_label.setScaledContents(False)
|
||||||
|
self.image_label.adjustSize()
|
||||||
|
|
||||||
|
# Emit zoom changed signal
|
||||||
|
self.zoom_changed.emit(self.zoom_scale)
|
||||||
|
|
||||||
|
def zoom_in(self):
|
||||||
|
"""Zoom in on the image."""
|
||||||
|
if self.original_pixmap is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
new_scale = self.zoom_scale + self.zoom_step
|
||||||
|
if new_scale <= self.zoom_max:
|
||||||
|
self.zoom_scale = new_scale
|
||||||
|
self._apply_zoom()
|
||||||
|
logger.debug(f"Zoomed in to {int(self.zoom_scale * 100)}%")
|
||||||
|
|
||||||
|
def zoom_out(self):
|
||||||
|
"""Zoom out from the image."""
|
||||||
|
if self.original_pixmap is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
new_scale = self.zoom_scale - self.zoom_step
|
||||||
|
if new_scale >= self.zoom_min:
|
||||||
|
self.zoom_scale = new_scale
|
||||||
|
self._apply_zoom()
|
||||||
|
logger.debug(f"Zoomed out to {int(self.zoom_scale * 100)}%")
|
||||||
|
|
||||||
|
def reset_zoom(self):
|
||||||
|
"""Reset zoom to 100%."""
|
||||||
|
if self.original_pixmap is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.zoom_scale = 1.0
|
||||||
|
self._apply_zoom()
|
||||||
|
logger.debug("Reset zoom to 100%")
|
||||||
|
|
||||||
|
def set_zoom(self, scale: float):
|
||||||
|
"""
|
||||||
|
Set zoom to a specific scale.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scale: Zoom scale (1.0 = 100%)
|
||||||
|
"""
|
||||||
|
if self.original_pixmap is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Clamp to min/max
|
||||||
|
scale = max(self.zoom_min, min(self.zoom_max, scale))
|
||||||
|
|
||||||
|
self.zoom_scale = scale
|
||||||
|
self._apply_zoom()
|
||||||
|
logger.debug(f"Set zoom to {int(self.zoom_scale * 100)}%")
|
||||||
|
|
||||||
|
def get_zoom_percentage(self) -> int:
|
||||||
|
"""
|
||||||
|
Get current zoom level as percentage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Zoom level as integer percentage (e.g., 100 for 100%)
|
||||||
|
"""
|
||||||
|
return int(self.zoom_scale * 100)
|
||||||
|
|
||||||
|
def keyPressEvent(self, event: QKeyEvent):
|
||||||
|
"""Handle keyboard events for zooming."""
|
||||||
|
if event.key() in (Qt.Key_Plus, Qt.Key_Equal):
|
||||||
|
# + or = key (= is the unshifted + on many keyboards)
|
||||||
|
self.zoom_in()
|
||||||
|
event.accept()
|
||||||
|
elif event.key() == Qt.Key_Minus:
|
||||||
|
# - key
|
||||||
|
self.zoom_out()
|
||||||
|
event.accept()
|
||||||
|
elif event.key() == Qt.Key_0 and event.modifiers() == Qt.ControlModifier:
|
||||||
|
# Ctrl+0 to reset zoom
|
||||||
|
self.reset_zoom()
|
||||||
|
event.accept()
|
||||||
|
else:
|
||||||
|
super().keyPressEvent(event)
|
||||||
|
|
||||||
|
def eventFilter(self, obj, event: QEvent) -> bool:
|
||||||
|
"""Event filter to capture wheel events for zooming."""
|
||||||
|
if event.type() == QEvent.Wheel:
|
||||||
|
wheel_event = event
|
||||||
|
if self.original_pixmap is not None:
|
||||||
|
# Get wheel angle delta
|
||||||
|
delta = wheel_event.angleDelta().y()
|
||||||
|
|
||||||
|
# Zoom in/out based on wheel direction
|
||||||
|
if delta > 0:
|
||||||
|
# Scroll up = zoom in
|
||||||
|
new_scale = self.zoom_scale + self.zoom_wheel_step
|
||||||
|
if new_scale <= self.zoom_max:
|
||||||
|
self.zoom_scale = new_scale
|
||||||
|
self._apply_zoom()
|
||||||
|
else:
|
||||||
|
# Scroll down = zoom out
|
||||||
|
new_scale = self.zoom_scale - self.zoom_wheel_step
|
||||||
|
if new_scale >= self.zoom_min:
|
||||||
|
self.zoom_scale = new_scale
|
||||||
|
self._apply_zoom()
|
||||||
|
|
||||||
|
return True # Event handled
|
||||||
|
|
||||||
|
return super().eventFilter(obj, event)
|
||||||
49
src/gui_launcher.py
Normal file
49
src/gui_launcher.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""GUI launcher module for microscopy object detection application."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from PySide6.QtWidgets import QApplication
|
||||||
|
from PySide6.QtCore import Qt
|
||||||
|
|
||||||
|
from src import __version__
|
||||||
|
from src.gui.main_window import MainWindow
|
||||||
|
from src.utils.logger import setup_logging
|
||||||
|
from src.utils.config_manager import ConfigManager
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Launch the GUI application."""
|
||||||
|
# Setup logging
|
||||||
|
config_manager = ConfigManager()
|
||||||
|
log_config = config_manager.get_section("logging")
|
||||||
|
setup_logging(
|
||||||
|
log_file=log_config.get("file", "logs/app.log"),
|
||||||
|
level=log_config.get("level", "INFO"),
|
||||||
|
log_format=log_config.get("format"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Enable High DPI scaling
|
||||||
|
QApplication.setHighDpiScaleFactorRoundingPolicy(
|
||||||
|
Qt.HighDpiScaleFactorRoundingPolicy.PassThrough
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create Qt application
|
||||||
|
app = QApplication(sys.argv)
|
||||||
|
app.setApplicationName("Microscopy Object Detection")
|
||||||
|
app.setOrganizationName("MicroscopyLab")
|
||||||
|
app.setApplicationVersion(__version__)
|
||||||
|
|
||||||
|
# Set application style
|
||||||
|
app.setStyle("Fusion")
|
||||||
|
|
||||||
|
# Create and show main window
|
||||||
|
window = MainWindow()
|
||||||
|
window.show()
|
||||||
|
|
||||||
|
# Run application
|
||||||
|
sys.exit(app.exec())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -5,12 +5,12 @@ Handles detection inference and result storage.
|
|||||||
|
|
||||||
from typing import List, Dict, Optional, Callable
|
from typing import List, Dict, Optional, Callable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from PIL import Image
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from src.model.yolo_wrapper import YOLOWrapper
|
from src.model.yolo_wrapper import YOLOWrapper
|
||||||
from src.database.db_manager import DatabaseManager
|
from src.database.db_manager import DatabaseManager
|
||||||
|
from src.utils.image import Image
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
from src.utils.file_utils import get_relative_path
|
from src.utils.file_utils import get_relative_path
|
||||||
|
|
||||||
@@ -42,6 +42,7 @@ class InferenceEngine:
|
|||||||
relative_path: str,
|
relative_path: str,
|
||||||
conf: float = 0.25,
|
conf: float = 0.25,
|
||||||
save_to_db: bool = True,
|
save_to_db: bool = True,
|
||||||
|
repository_root: Optional[str] = None,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""
|
"""
|
||||||
Detect objects in a single image.
|
Detect objects in a single image.
|
||||||
@@ -51,48 +52,79 @@ class InferenceEngine:
|
|||||||
relative_path: Relative path from repository root
|
relative_path: Relative path from repository root
|
||||||
conf: Confidence threshold
|
conf: Confidence threshold
|
||||||
save_to_db: Whether to save results to database
|
save_to_db: Whether to save results to database
|
||||||
|
repository_root: Base directory used to compute relative_path (if known)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with detection results
|
Dictionary with detection results
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Normalize storage path (fall back to absolute path when repo root is unknown)
|
||||||
|
stored_relative_path = relative_path
|
||||||
|
if not repository_root:
|
||||||
|
stored_relative_path = str(Path(image_path).resolve())
|
||||||
|
|
||||||
# Get image dimensions
|
# Get image dimensions
|
||||||
img = Image.open(image_path)
|
img = Image(image_path)
|
||||||
width, height = img.size
|
width = img.width
|
||||||
img.close()
|
height = img.height
|
||||||
|
|
||||||
# Perform detection
|
# Perform detection
|
||||||
detections = self.yolo.predict(image_path, conf=conf)
|
detections = self.yolo.predict(image_path, conf=conf)
|
||||||
|
|
||||||
# Add/get image in database
|
# Add/get image in database
|
||||||
image_id = self.db_manager.get_or_create_image(
|
image_id = self.db_manager.get_or_create_image(
|
||||||
relative_path=relative_path,
|
relative_path=stored_relative_path,
|
||||||
filename=Path(image_path).name,
|
filename=Path(image_path).name,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save detections to database
|
inserted_count = 0
|
||||||
if save_to_db and detections:
|
deleted_count = 0
|
||||||
detection_records = []
|
|
||||||
for det in detections:
|
|
||||||
# Use normalized bbox from detection
|
|
||||||
bbox_normalized = det[
|
|
||||||
"bbox_normalized"
|
|
||||||
] # [x_min, y_min, x_max, y_max]
|
|
||||||
|
|
||||||
record = {
|
# Save detections to database, replacing any previous results for this image/model
|
||||||
"image_id": image_id,
|
if save_to_db:
|
||||||
"model_id": self.model_id,
|
deleted_count = self.db_manager.delete_detections_for_image(
|
||||||
"class_name": det["class_name"],
|
image_id, self.model_id
|
||||||
"bbox": tuple(bbox_normalized),
|
)
|
||||||
"confidence": det["confidence"],
|
if detections:
|
||||||
"metadata": {"class_id": det["class_id"]},
|
detection_records = []
|
||||||
}
|
for det in detections:
|
||||||
detection_records.append(record)
|
# Use normalized bbox from detection
|
||||||
|
bbox_normalized = det[
|
||||||
|
"bbox_normalized"
|
||||||
|
] # [x_min, y_min, x_max, y_max]
|
||||||
|
|
||||||
self.db_manager.add_detections_batch(detection_records)
|
metadata = {
|
||||||
logger.info(f"Saved {len(detection_records)} detections to database")
|
"class_id": det["class_id"],
|
||||||
|
"source_path": str(Path(image_path).resolve()),
|
||||||
|
}
|
||||||
|
if repository_root:
|
||||||
|
metadata["repository_root"] = str(
|
||||||
|
Path(repository_root).resolve()
|
||||||
|
)
|
||||||
|
|
||||||
|
record = {
|
||||||
|
"image_id": image_id,
|
||||||
|
"model_id": self.model_id,
|
||||||
|
"class_name": det["class_name"],
|
||||||
|
"bbox": tuple(bbox_normalized),
|
||||||
|
"confidence": det["confidence"],
|
||||||
|
"segmentation_mask": det.get("segmentation_mask"),
|
||||||
|
"metadata": metadata,
|
||||||
|
}
|
||||||
|
detection_records.append(record)
|
||||||
|
|
||||||
|
inserted_count = self.db_manager.add_detections_batch(
|
||||||
|
detection_records
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Saved {inserted_count} detections to database (replaced {deleted_count})"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"Detection run removed {deleted_count} stale entries but produced no new detections"
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
@@ -141,7 +173,12 @@ class InferenceEngine:
|
|||||||
rel_path = get_relative_path(image_path, repository_root)
|
rel_path = get_relative_path(image_path, repository_root)
|
||||||
|
|
||||||
# Perform detection
|
# Perform detection
|
||||||
result = self.detect_single(image_path, rel_path, conf)
|
result = self.detect_single(
|
||||||
|
image_path,
|
||||||
|
rel_path,
|
||||||
|
conf=conf,
|
||||||
|
repository_root=repository_root,
|
||||||
|
)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
# Update progress
|
# Update progress
|
||||||
@@ -160,6 +197,7 @@ class InferenceEngine:
|
|||||||
conf: float = 0.25,
|
conf: float = 0.25,
|
||||||
bbox_thickness: int = 2,
|
bbox_thickness: int = 2,
|
||||||
bbox_colors: Optional[Dict[str, str]] = None,
|
bbox_colors: Optional[Dict[str, str]] = None,
|
||||||
|
draw_masks: bool = True,
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
"""
|
"""
|
||||||
Detect objects and return annotated image.
|
Detect objects and return annotated image.
|
||||||
@@ -169,6 +207,7 @@ class InferenceEngine:
|
|||||||
conf: Confidence threshold
|
conf: Confidence threshold
|
||||||
bbox_thickness: Thickness of bounding boxes
|
bbox_thickness: Thickness of bounding boxes
|
||||||
bbox_colors: Dictionary mapping class names to hex colors
|
bbox_colors: Dictionary mapping class names to hex colors
|
||||||
|
draw_masks: Whether to draw segmentation masks (if available)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (detections, annotated_image_array)
|
Tuple of (detections, annotated_image_array)
|
||||||
@@ -189,12 +228,8 @@ class InferenceEngine:
|
|||||||
bbox_colors = {}
|
bbox_colors = {}
|
||||||
default_color = self._hex_to_bgr(bbox_colors.get("default", "#00FF00"))
|
default_color = self._hex_to_bgr(bbox_colors.get("default", "#00FF00"))
|
||||||
|
|
||||||
# Draw bounding boxes
|
# Draw detections
|
||||||
for det in detections:
|
for det in detections:
|
||||||
# Get absolute coordinates
|
|
||||||
bbox_abs = det["bbox_absolute"]
|
|
||||||
x1, y1, x2, y2 = [int(v) for v in bbox_abs]
|
|
||||||
|
|
||||||
# Get color for this class
|
# Get color for this class
|
||||||
class_name = det["class_name"]
|
class_name = det["class_name"]
|
||||||
color_hex = bbox_colors.get(
|
color_hex = bbox_colors.get(
|
||||||
@@ -202,7 +237,33 @@ class InferenceEngine:
|
|||||||
)
|
)
|
||||||
color = self._hex_to_bgr(color_hex)
|
color = self._hex_to_bgr(color_hex)
|
||||||
|
|
||||||
# Draw box
|
# Draw segmentation mask if available and requested
|
||||||
|
if draw_masks and det.get("segmentation_mask"):
|
||||||
|
mask_normalized = det["segmentation_mask"]
|
||||||
|
if mask_normalized and len(mask_normalized) > 0:
|
||||||
|
# Convert normalized coordinates to absolute pixels
|
||||||
|
mask_points = np.array(
|
||||||
|
[
|
||||||
|
[int(pt[0] * width), int(pt[1] * height)]
|
||||||
|
for pt in mask_normalized
|
||||||
|
],
|
||||||
|
dtype=np.int32,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a semi-transparent overlay
|
||||||
|
overlay = img.copy()
|
||||||
|
cv2.fillPoly(overlay, [mask_points], color)
|
||||||
|
# Blend with original image (30% opacity)
|
||||||
|
cv2.addWeighted(overlay, 0.3, img, 0.7, 0, img)
|
||||||
|
|
||||||
|
# Draw mask contour
|
||||||
|
cv2.polylines(img, [mask_points], True, color, bbox_thickness)
|
||||||
|
|
||||||
|
# Get absolute coordinates for bounding box
|
||||||
|
bbox_abs = det["bbox_absolute"]
|
||||||
|
x1, y1, x2, y2 = [int(v) for v in bbox_abs]
|
||||||
|
|
||||||
|
# Draw bounding box
|
||||||
cv2.rectangle(img, (x1, y1), (x2, y2), color, bbox_thickness)
|
cv2.rectangle(img, (x1, y1), (x2, y2), color, bbox_thickness)
|
||||||
|
|
||||||
# Prepare label
|
# Prepare label
|
||||||
|
|||||||
@@ -1,13 +1,21 @@
|
|||||||
"""
|
"""YOLO model wrapper for the microscopy object detection application.
|
||||||
YOLO model wrapper for the microscopy object detection application.
|
|
||||||
Provides a clean interface to YOLOv8 for training, validation, and inference.
|
Notes on 16-bit TIFF support:
|
||||||
|
- Ultralytics training defaults assume 8-bit images and normalize by dividing by 255.
|
||||||
|
- This project can patch Ultralytics at runtime to decode TIFFs via `tifffile` and
|
||||||
|
normalize `uint16` correctly.
|
||||||
|
|
||||||
|
See [`apply_ultralytics_16bit_tiff_patches()`](src/utils/ultralytics_16bit_patch.py:1).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from ultralytics import YOLO
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, List, Dict, Callable, Any
|
from typing import Optional, List, Dict, Callable, Any
|
||||||
import torch
|
import torch
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
from src.utils.image import Image
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
|
from src.utils.ultralytics_16bit_patch import apply_ultralytics_16bit_tiff_patches
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -16,7 +24,7 @@ logger = get_logger(__name__)
|
|||||||
class YOLOWrapper:
|
class YOLOWrapper:
|
||||||
"""Wrapper for YOLOv8 model operations."""
|
"""Wrapper for YOLOv8 model operations."""
|
||||||
|
|
||||||
def __init__(self, model_path: str = "yolov8s.pt"):
|
def __init__(self, model_path: str = "yolov8s-seg.pt"):
|
||||||
"""
|
"""
|
||||||
Initialize YOLO model.
|
Initialize YOLO model.
|
||||||
|
|
||||||
@@ -28,6 +36,9 @@ class YOLOWrapper:
|
|||||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
logger.info(f"YOLOWrapper initialized with device: {self.device}")
|
logger.info(f"YOLOWrapper initialized with device: {self.device}")
|
||||||
|
|
||||||
|
# Apply Ultralytics runtime patches early (before first import/instantiation of YOLO datasets/trainers).
|
||||||
|
apply_ultralytics_16bit_tiff_patches()
|
||||||
|
|
||||||
def load_model(self) -> bool:
|
def load_model(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Load YOLO model from path.
|
Load YOLO model from path.
|
||||||
@@ -37,6 +48,9 @@ class YOLOWrapper:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"Loading YOLO model from {self.model_path}")
|
logger.info(f"Loading YOLO model from {self.model_path}")
|
||||||
|
# Import YOLO lazily to ensure runtime patches are applied first.
|
||||||
|
from ultralytics import YOLO
|
||||||
|
|
||||||
self.model = YOLO(self.model_path)
|
self.model = YOLO(self.model_path)
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
logger.info("Model loaded successfully")
|
logger.info("Model loaded successfully")
|
||||||
@@ -55,6 +69,7 @@ class YOLOWrapper:
|
|||||||
save_dir: str = "data/models",
|
save_dir: str = "data/models",
|
||||||
name: str = "custom_model",
|
name: str = "custom_model",
|
||||||
resume: bool = False,
|
resume: bool = False,
|
||||||
|
callbacks: Optional[Dict[str, Callable]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@@ -69,19 +84,29 @@ class YOLOWrapper:
|
|||||||
save_dir: Directory to save trained model
|
save_dir: Directory to save trained model
|
||||||
name: Name for the training run
|
name: Name for the training run
|
||||||
resume: Resume training from last checkpoint
|
resume: Resume training from last checkpoint
|
||||||
|
callbacks: Optional Ultralytics callback dictionary
|
||||||
**kwargs: Additional training arguments
|
**kwargs: Additional training arguments
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with training results
|
Dictionary with training results
|
||||||
"""
|
"""
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
self.load_model()
|
if not self.load_model():
|
||||||
|
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Starting training: {name}")
|
logger.info(f"Starting training: {name}")
|
||||||
logger.info(
|
logger.info(f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}")
|
||||||
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
|
|
||||||
)
|
# Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255.
|
||||||
|
# Users can override by passing explicit kwargs.
|
||||||
|
kwargs.setdefault("mosaic", 0.0)
|
||||||
|
kwargs.setdefault("mixup", 0.0)
|
||||||
|
kwargs.setdefault("cutmix", 0.0)
|
||||||
|
kwargs.setdefault("copy_paste", 0.0)
|
||||||
|
kwargs.setdefault("hsv_h", 0.0)
|
||||||
|
kwargs.setdefault("hsv_s", 0.0)
|
||||||
|
kwargs.setdefault("hsv_v", 0.0)
|
||||||
|
|
||||||
# Train the model
|
# Train the model
|
||||||
results = self.model.train(
|
results = self.model.train(
|
||||||
@@ -117,13 +142,12 @@ class YOLOWrapper:
|
|||||||
Dictionary with validation metrics
|
Dictionary with validation metrics
|
||||||
"""
|
"""
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
self.load_model()
|
if not self.load_model():
|
||||||
|
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Starting validation on {split} split")
|
logger.info(f"Starting validation on {split} split")
|
||||||
results = self.model.val(
|
results = self.model.val(data=data_yaml, split=split, device=self.device, **kwargs)
|
||||||
data=data_yaml, split=split, device=self.device, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Validation completed successfully")
|
logger.info("Validation completed successfully")
|
||||||
return self._format_validation_results(results)
|
return self._format_validation_results(results)
|
||||||
@@ -158,10 +182,13 @@ class YOLOWrapper:
|
|||||||
List of detection dictionaries
|
List of detection dictionaries
|
||||||
"""
|
"""
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
self.load_model()
|
if not self.load_model():
|
||||||
|
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
||||||
|
|
||||||
|
prepared_source, cleanup_path = self._prepare_source(source)
|
||||||
|
imgsz = 1088
|
||||||
try:
|
try:
|
||||||
logger.info(f"Running inference on {source}")
|
logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
|
||||||
results = self.model.predict(
|
results = self.model.predict(
|
||||||
source=source,
|
source=source,
|
||||||
conf=conf,
|
conf=conf,
|
||||||
@@ -170,6 +197,7 @@ class YOLOWrapper:
|
|||||||
save_txt=save_txt,
|
save_txt=save_txt,
|
||||||
save_conf=save_conf,
|
save_conf=save_conf,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
|
imgsz=imgsz,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -180,10 +208,14 @@ class YOLOWrapper:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during inference: {e}")
|
logger.error(f"Error during inference: {e}")
|
||||||
raise
|
raise
|
||||||
|
finally:
|
||||||
|
if 0: # cleanup_path:
|
||||||
|
try:
|
||||||
|
os.remove(cleanup_path)
|
||||||
|
except OSError as cleanup_error:
|
||||||
|
logger.warning(f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}")
|
||||||
|
|
||||||
def export(
|
def export(self, format: str = "onnx", output_path: Optional[str] = None, **kwargs) -> str:
|
||||||
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
|
|
||||||
) -> str:
|
|
||||||
"""
|
"""
|
||||||
Export model to different format.
|
Export model to different format.
|
||||||
|
|
||||||
@@ -196,7 +228,8 @@ class YOLOWrapper:
|
|||||||
Path to exported model
|
Path to exported model
|
||||||
"""
|
"""
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
self.load_model()
|
if not self.load_model():
|
||||||
|
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Exporting model to {format} format")
|
logger.info(f"Exporting model to {format} format")
|
||||||
@@ -208,13 +241,35 @@ class YOLOWrapper:
|
|||||||
logger.error(f"Error exporting model: {e}")
|
logger.error(f"Error exporting model: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def _prepare_source(self, source):
|
||||||
|
"""Convert single-channel images to RGB temporarily for inference."""
|
||||||
|
cleanup_path = None
|
||||||
|
|
||||||
|
if isinstance(source, (str, Path)):
|
||||||
|
source_path = Path(source)
|
||||||
|
if source_path.is_file():
|
||||||
|
try:
|
||||||
|
img_obj = Image(source_path)
|
||||||
|
suffix = source_path.suffix or ".png"
|
||||||
|
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
|
||||||
|
tmp_path = tmp.name
|
||||||
|
tmp.close()
|
||||||
|
img_obj.save(tmp_path)
|
||||||
|
cleanup_path = tmp_path
|
||||||
|
logger.info(f"Converted image {source_path} to RGB for inference at {tmp_path}")
|
||||||
|
return tmp_path, cleanup_path
|
||||||
|
except Exception as convert_error:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return source, cleanup_path
|
||||||
|
|
||||||
def _format_training_results(self, results) -> Dict[str, Any]:
|
def _format_training_results(self, results) -> Dict[str, Any]:
|
||||||
"""Format training results into dictionary."""
|
"""Format training results into dictionary."""
|
||||||
try:
|
try:
|
||||||
# Get the results dict
|
# Get the results dict
|
||||||
results_dict = (
|
results_dict = results.results_dict if hasattr(results, "results_dict") else {}
|
||||||
results.results_dict if hasattr(results, "results_dict") else {}
|
|
||||||
)
|
|
||||||
|
|
||||||
formatted = {
|
formatted = {
|
||||||
"success": True,
|
"success": True,
|
||||||
@@ -247,9 +302,7 @@ class YOLOWrapper:
|
|||||||
"mAP50-95": float(box_metrics.map),
|
"mAP50-95": float(box_metrics.map),
|
||||||
"precision": float(box_metrics.mp),
|
"precision": float(box_metrics.mp),
|
||||||
"recall": float(box_metrics.mr),
|
"recall": float(box_metrics.mr),
|
||||||
"fitness": (
|
"fitness": (float(results.fitness) if hasattr(results, "fitness") else 0.0),
|
||||||
float(results.fitness) if hasattr(results, "fitness") else 0.0
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add per-class metrics if available
|
# Add per-class metrics if available
|
||||||
@@ -259,11 +312,7 @@ class YOLOWrapper:
|
|||||||
if idx < len(box_metrics.ap):
|
if idx < len(box_metrics.ap):
|
||||||
class_metrics[name] = {
|
class_metrics[name] = {
|
||||||
"ap": float(box_metrics.ap[idx]),
|
"ap": float(box_metrics.ap[idx]),
|
||||||
"ap50": (
|
"ap50": (float(box_metrics.ap50[idx]) if hasattr(box_metrics, "ap50") else 0.0),
|
||||||
float(box_metrics.ap50[idx])
|
|
||||||
if hasattr(box_metrics, "ap50")
|
|
||||||
else 0.0
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
formatted["class_metrics"] = class_metrics
|
formatted["class_metrics"] = class_metrics
|
||||||
|
|
||||||
@@ -282,6 +331,10 @@ class YOLOWrapper:
|
|||||||
boxes = result.boxes
|
boxes = result.boxes
|
||||||
image_path = str(result.path)
|
image_path = str(result.path)
|
||||||
orig_shape = result.orig_shape # (height, width)
|
orig_shape = result.orig_shape # (height, width)
|
||||||
|
height, width = orig_shape
|
||||||
|
|
||||||
|
# Check if this is a segmentation model with masks
|
||||||
|
has_masks = hasattr(result, "masks") and result.masks is not None
|
||||||
|
|
||||||
for i in range(len(boxes)):
|
for i in range(len(boxes)):
|
||||||
# Get normalized coordinates
|
# Get normalized coordinates
|
||||||
@@ -292,13 +345,32 @@ class YOLOWrapper:
|
|||||||
"class_id": int(boxes.cls[i]),
|
"class_id": int(boxes.cls[i]),
|
||||||
"class_name": result.names[int(boxes.cls[i])],
|
"class_name": result.names[int(boxes.cls[i])],
|
||||||
"confidence": float(boxes.conf[i]),
|
"confidence": float(boxes.conf[i]),
|
||||||
"bbox_normalized": [
|
"bbox_normalized": [float(v) for v in xyxyn], # [x_min, y_min, x_max, y_max]
|
||||||
float(v) for v in xyxyn
|
"bbox_absolute": [float(v) for v in boxes.xyxy[i].cpu().numpy()], # Absolute pixels
|
||||||
], # [x_min, y_min, x_max, y_max]
|
|
||||||
"bbox_absolute": [
|
|
||||||
float(v) for v in boxes.xyxy[i].cpu().numpy()
|
|
||||||
], # Absolute pixels
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Extract segmentation mask if available
|
||||||
|
if has_masks:
|
||||||
|
try:
|
||||||
|
# Get the mask for this detection
|
||||||
|
mask_data = result.masks.xy[i] # Polygon coordinates in absolute pixels
|
||||||
|
|
||||||
|
# Convert to normalized coordinates
|
||||||
|
if len(mask_data) > 0:
|
||||||
|
mask_normalized = []
|
||||||
|
for point in mask_data:
|
||||||
|
x_norm = float(point[0]) / width
|
||||||
|
y_norm = float(point[1]) / height
|
||||||
|
mask_normalized.append([x_norm, y_norm])
|
||||||
|
detection["segmentation_mask"] = mask_normalized
|
||||||
|
else:
|
||||||
|
detection["segmentation_mask"] = None
|
||||||
|
except Exception as mask_error:
|
||||||
|
logger.warning(f"Error extracting mask for detection {i}: {mask_error}")
|
||||||
|
detection["segmentation_mask"] = None
|
||||||
|
else:
|
||||||
|
detection["segmentation_mask"] = None
|
||||||
|
|
||||||
detections.append(detection)
|
detections.append(detection)
|
||||||
|
|
||||||
return detections
|
return detections
|
||||||
@@ -308,9 +380,7 @@ class YOLOWrapper:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_bbox_format(
|
def convert_bbox_format(bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy") -> List[float]:
|
||||||
bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
|
|
||||||
) -> List[float]:
|
|
||||||
"""
|
"""
|
||||||
Convert bounding box between formats.
|
Convert bounding box between formats.
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
"""
|
||||||
|
Utility modules for the microscopy object detection application.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.utils.image import Image, ImageLoadError
|
||||||
|
|
||||||
|
__all__ = ["Image", "ImageLoadError"]
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import yaml
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
|
from src.utils.image import Image
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -46,18 +47,15 @@ class ConfigManager:
|
|||||||
"database": {"path": "data/detections.db"},
|
"database": {"path": "data/detections.db"},
|
||||||
"image_repository": {
|
"image_repository": {
|
||||||
"base_path": "",
|
"base_path": "",
|
||||||
"allowed_extensions": [
|
"allowed_extensions": Image.SUPPORTED_EXTENSIONS,
|
||||||
".jpg",
|
|
||||||
".jpeg",
|
|
||||||
".png",
|
|
||||||
".tif",
|
|
||||||
".tiff",
|
|
||||||
".bmp",
|
|
||||||
],
|
|
||||||
},
|
},
|
||||||
"models": {
|
"models": {
|
||||||
"default_base_model": "yolov8s.pt",
|
"default_base_model": "yolov8s-seg.pt",
|
||||||
"models_directory": "data/models",
|
"models_directory": "data/models",
|
||||||
|
"base_model_choices": [
|
||||||
|
"yolov8s-seg.pt",
|
||||||
|
"yolo11s-seg.pt",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
"training": {
|
"training": {
|
||||||
"default_epochs": 100,
|
"default_epochs": 100,
|
||||||
@@ -65,6 +63,20 @@ class ConfigManager:
|
|||||||
"default_imgsz": 640,
|
"default_imgsz": 640,
|
||||||
"default_patience": 50,
|
"default_patience": 50,
|
||||||
"default_lr0": 0.01,
|
"default_lr0": 0.01,
|
||||||
|
"two_stage": {
|
||||||
|
"enabled": False,
|
||||||
|
"stage1": {
|
||||||
|
"epochs": 20,
|
||||||
|
"lr0": 0.0005,
|
||||||
|
"patience": 10,
|
||||||
|
"freeze": 10,
|
||||||
|
},
|
||||||
|
"stage2": {
|
||||||
|
"epochs": 150,
|
||||||
|
"lr0": 0.0003,
|
||||||
|
"patience": 30,
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"detection": {
|
"detection": {
|
||||||
"default_confidence": 0.25,
|
"default_confidence": 0.25,
|
||||||
@@ -213,6 +225,4 @@ class ConfigManager:
|
|||||||
|
|
||||||
def get_allowed_extensions(self) -> list:
|
def get_allowed_extensions(self) -> list:
|
||||||
"""Get list of allowed image file extensions."""
|
"""Get list of allowed image file extensions."""
|
||||||
return self.get(
|
return self.get("image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS)
|
||||||
"image_repository.allowed_extensions", [".jpg", ".jpeg", ".png"]
|
|
||||||
)
|
|
||||||
|
|||||||
103
src/utils/create_mask_from_detection.py
Normal file
103
src/utils/create_mask_from_detection.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from skimage.draw import polygon
|
||||||
|
from tifffile import TiffFile
|
||||||
|
|
||||||
|
from src.database.db_manager import DatabaseManager
|
||||||
|
|
||||||
|
|
||||||
|
def read_image(image_path: Path) -> np.ndarray:
|
||||||
|
metadata = {}
|
||||||
|
with TiffFile(image_path) as tif:
|
||||||
|
image = tif.asarray()
|
||||||
|
metadata = tif.imagej_metadata
|
||||||
|
return image, metadata
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
polygon_vertices = np.array([[10, 10], [50, 10], [50, 50], [10, 50]])
|
||||||
|
image = np.zeros((100, 100), dtype=np.uint8)
|
||||||
|
rr, cc = polygon(polygon_vertices[:, 0], polygon_vertices[:, 1])
|
||||||
|
image[rr, cc] = 255
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
db = DatabaseManager()
|
||||||
|
model_name = "c17"
|
||||||
|
model_id = db.get_models(filters={"model_name": model_name})[0]["id"]
|
||||||
|
print(f"Model name {model_name}, id {model_id}")
|
||||||
|
detections = db.get_detections(filters={"model_id": model_id})
|
||||||
|
|
||||||
|
file_stems = set()
|
||||||
|
|
||||||
|
for detection in detections:
|
||||||
|
file_stems.add(detection["image_filename"].split("_")[0])
|
||||||
|
|
||||||
|
print("Files:", file_stems)
|
||||||
|
|
||||||
|
for stem in file_stems:
|
||||||
|
print(stem)
|
||||||
|
detections = db.get_detections(filters={"model_id": model_id, "i.filename": f"LIKE %{stem}%"})
|
||||||
|
annotations = []
|
||||||
|
for detection in detections:
|
||||||
|
source_path = Path(detection["metadata"]["source_path"])
|
||||||
|
image, metadata = read_image(source_path)
|
||||||
|
|
||||||
|
offset = np.array(list(map(int, metadata["tile_section"].split(","))))[::-1]
|
||||||
|
scale = np.array(list(map(int, metadata["patch_size"].split(","))))[::-1]
|
||||||
|
# tile_size = np.array(list(map(int, metadata["tile_size"].split(","))))
|
||||||
|
segmentation = np.array(detection["segmentation_mask"]) # * tile_size
|
||||||
|
|
||||||
|
# print(source_path, image, metadata, segmentation.shape)
|
||||||
|
# print(offset)
|
||||||
|
# print(scale)
|
||||||
|
# print(segmentation)
|
||||||
|
|
||||||
|
# segmentation = (segmentation + offset * tile_size) / (tile_size * scale)
|
||||||
|
segmentation = (segmentation + offset) / scale
|
||||||
|
|
||||||
|
yolo_annotation = f"{detection['metadata']['class_id']} " + " ".join(
|
||||||
|
[f"{x:.6f} {y:.6f}" for x, y in segmentation]
|
||||||
|
)
|
||||||
|
annotations.append(yolo_annotation)
|
||||||
|
# print(segmentation)
|
||||||
|
# print(yolo_annotation)
|
||||||
|
|
||||||
|
# aa
|
||||||
|
print(
|
||||||
|
" ",
|
||||||
|
detection["model_name"],
|
||||||
|
detection["image_id"],
|
||||||
|
detection["image_filename"],
|
||||||
|
source_path,
|
||||||
|
metadata["label_path"],
|
||||||
|
)
|
||||||
|
# section_i_section_j = detection["image_filename"].split("_")[1].split(".")[0]
|
||||||
|
# print(" ", section_i_section_j)
|
||||||
|
|
||||||
|
label_path = metadata["label_path"]
|
||||||
|
print(" ", label_path)
|
||||||
|
with open(label_path, "w") as f:
|
||||||
|
f.write("\n".join(annotations))
|
||||||
|
|
||||||
|
exit()
|
||||||
|
|
||||||
|
for detection in detections:
|
||||||
|
print(detection["model_name"], detection["image_id"], detection["image_filename"])
|
||||||
|
|
||||||
|
print(detections[0])
|
||||||
|
# polygon_vertices = np.array([[10, 10], [50, 10], [50, 50], [10, 50]])
|
||||||
|
|
||||||
|
# image = np.zeros((100, 100), dtype=np.uint8)
|
||||||
|
|
||||||
|
# rr, cc = polygon(polygon_vertices[:, 0], polygon_vertices[:, 1])
|
||||||
|
|
||||||
|
# image[rr, cc] = 255
|
||||||
|
|
||||||
|
# import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
# plt.imshow(image, cmap='gray')
|
||||||
|
# plt.show()
|
||||||
@@ -28,7 +28,9 @@ def get_image_files(
|
|||||||
List of absolute paths to image files
|
List of absolute paths to image files
|
||||||
"""
|
"""
|
||||||
if allowed_extensions is None:
|
if allowed_extensions is None:
|
||||||
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
|
from src.utils.image import Image
|
||||||
|
|
||||||
|
allowed_extensions = Image.SUPPORTED_EXTENSIONS
|
||||||
|
|
||||||
# Normalize extensions to lowercase
|
# Normalize extensions to lowercase
|
||||||
allowed_extensions = [ext.lower() for ext in allowed_extensions]
|
allowed_extensions = [ext.lower() for ext in allowed_extensions]
|
||||||
@@ -204,7 +206,9 @@ def is_image_file(
|
|||||||
True if file is an image
|
True if file is an image
|
||||||
"""
|
"""
|
||||||
if allowed_extensions is None:
|
if allowed_extensions is None:
|
||||||
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
|
from src.utils.image import Image
|
||||||
|
|
||||||
|
allowed_extensions = Image.SUPPORTED_EXTENSIONS
|
||||||
|
|
||||||
extension = Path(file_path).suffix.lower()
|
extension = Path(file_path).suffix.lower()
|
||||||
return extension in [ext.lower() for ext in allowed_extensions]
|
return extension in [ext.lower() for ext in allowed_extensions]
|
||||||
|
|||||||
370
src/utils/image.py
Normal file
370
src/utils/image.py
Normal file
@@ -0,0 +1,370 @@
|
|||||||
|
"""
|
||||||
|
Image loading and management utilities for the microscopy object detection application.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
from src.utils.logger import get_logger
|
||||||
|
from src.utils.file_utils import validate_file_path, is_image_file
|
||||||
|
|
||||||
|
from PySide6.QtGui import QImage
|
||||||
|
|
||||||
|
from tifffile import imread, imwrite
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Convert a grayscale image to a pseudo-RGB image using a gamma correction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arr: Input grayscale image as numpy array
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Pseudo-RGB image as numpy array
|
||||||
|
"""
|
||||||
|
if arr.ndim != 2:
|
||||||
|
raise ValueError("Input array must be a grayscale image with shape (H, W)")
|
||||||
|
|
||||||
|
a1 = arr.copy().astype(np.float32)
|
||||||
|
a1 -= np.percentile(a1, 2)
|
||||||
|
a1[a1 < 0] = 0
|
||||||
|
p999 = np.percentile(a1, 99.9)
|
||||||
|
a1[a1 > p999] = p999
|
||||||
|
a1 /= a1.max()
|
||||||
|
|
||||||
|
if 1:
|
||||||
|
a2 = a1.copy()
|
||||||
|
a2 = a2**gamma
|
||||||
|
a2 /= a2.max()
|
||||||
|
|
||||||
|
a3 = a1.copy()
|
||||||
|
p9999 = np.percentile(a3, 99.99)
|
||||||
|
a3[a3 > p9999] = p9999
|
||||||
|
a3 /= a3.max()
|
||||||
|
|
||||||
|
# return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
|
||||||
|
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
|
||||||
|
out = np.stack([a1, a2, a3], axis=0)
|
||||||
|
# print(any(np.isnan(out).flatten()))
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ImageLoadError(Exception):
|
||||||
|
"""Exception raised when an image cannot be loaded."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Image:
|
||||||
|
"""
|
||||||
|
A class for loading and managing images from file paths.
|
||||||
|
|
||||||
|
Supports multiple image formats: .jpg, .jpeg, .png, .tif, .tiff, .bmp
|
||||||
|
Provides access to image data in multiple formats (OpenCV/numpy, PIL).
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
path: Path to the image file
|
||||||
|
data: Image data as numpy array (OpenCV format, BGR)
|
||||||
|
pil_image: Image data as PIL Image (RGB)
|
||||||
|
width: Image width in pixels
|
||||||
|
height: Image height in pixels
|
||||||
|
channels: Number of color channels
|
||||||
|
format: Image file format
|
||||||
|
size_bytes: File size in bytes
|
||||||
|
"""
|
||||||
|
|
||||||
|
SUPPORTED_EXTENSIONS = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
|
||||||
|
|
||||||
|
def __init__(self, image_path: Union[str, Path]):
|
||||||
|
"""
|
||||||
|
Initialize an Image object by loading from a file path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_path: Path to the image file (string or Path object)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImageLoadError: If the image cannot be loaded or is invalid
|
||||||
|
"""
|
||||||
|
self.path = Path(image_path)
|
||||||
|
self._data: Optional[np.ndarray] = None
|
||||||
|
self._width: int = 0
|
||||||
|
self._height: int = 0
|
||||||
|
self._channels: int = 0
|
||||||
|
self._format: str = ""
|
||||||
|
self._size_bytes: int = 0
|
||||||
|
self._dtype: Optional[np.dtype] = None
|
||||||
|
|
||||||
|
# Load the image
|
||||||
|
self._load()
|
||||||
|
|
||||||
|
def _load(self) -> None:
|
||||||
|
"""
|
||||||
|
Load the image from disk.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImageLoadError: If the image cannot be loaded
|
||||||
|
"""
|
||||||
|
# Validate path
|
||||||
|
if not validate_file_path(str(self.path), must_exist=True):
|
||||||
|
raise ImageLoadError(f"Invalid or non-existent file path: {self.path}")
|
||||||
|
|
||||||
|
# Check file extension
|
||||||
|
if not is_image_file(str(self.path), self.SUPPORTED_EXTENSIONS):
|
||||||
|
ext = self.path.suffix.lower()
|
||||||
|
raise ImageLoadError(
|
||||||
|
f"Unsupported image format: {ext}. " f"Supported formats: {', '.join(self.SUPPORTED_EXTENSIONS)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.path.suffix.lower() in [".tif", ".tiff"]:
|
||||||
|
self._data = imread(str(self.path))
|
||||||
|
else:
|
||||||
|
# raise NotImplementedError("RGB is not implemented")
|
||||||
|
# Load with OpenCV (returns BGR format)
|
||||||
|
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
|
||||||
|
|
||||||
|
if self._data is None:
|
||||||
|
raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")
|
||||||
|
|
||||||
|
# Extract metadata
|
||||||
|
# print(self._data.shape)
|
||||||
|
if len(self._data.shape) == 2:
|
||||||
|
self._height, self._width = self._data.shape[:2]
|
||||||
|
self._channels = 1
|
||||||
|
else:
|
||||||
|
self._height, self._width = self._data.shape[1:]
|
||||||
|
self._channels = self._data.shape[0]
|
||||||
|
# self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
|
||||||
|
self._format = self.path.suffix.lower().lstrip(".")
|
||||||
|
self._size_bytes = self.path.stat().st_size
|
||||||
|
self._dtype = self._data.dtype
|
||||||
|
|
||||||
|
if 0:
|
||||||
|
logger.info(
|
||||||
|
f"Successfully loaded image: {self.path.name} "
|
||||||
|
f"({self._width}x{self._height}, {self._channels} channels, "
|
||||||
|
f"{self._format.upper()})"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading image {self.path}: {e}")
|
||||||
|
raise ImageLoadError(f"Failed to load image: {e}") from e
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Get image data as numpy array (OpenCV format, BGR or grayscale).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image data as numpy array
|
||||||
|
"""
|
||||||
|
if self._data is None:
|
||||||
|
raise ImageLoadError("Image data not available")
|
||||||
|
return self._data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def width(self) -> int:
|
||||||
|
"""Get image width in pixels."""
|
||||||
|
return self._width
|
||||||
|
|
||||||
|
@property
|
||||||
|
def height(self) -> int:
|
||||||
|
"""Get image height in pixels."""
|
||||||
|
return self._height
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self) -> Tuple[int, int, int]:
|
||||||
|
"""
|
||||||
|
Get image shape as (height, width, channels).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (height, width, channels)
|
||||||
|
"""
|
||||||
|
print("shape", self._height, self._width, self._channels)
|
||||||
|
return (self._height, self._width, self._channels)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def channels(self) -> int:
|
||||||
|
"""Get number of color channels."""
|
||||||
|
return self._channels
|
||||||
|
|
||||||
|
@property
|
||||||
|
def format(self) -> str:
|
||||||
|
"""Get image file format (e.g., 'jpg', 'png')."""
|
||||||
|
return self._format
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size_bytes(self) -> int:
|
||||||
|
"""Get file size in bytes."""
|
||||||
|
return self._size_bytes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size_mb(self) -> float:
|
||||||
|
"""Get file size in megabytes."""
|
||||||
|
return self._size_bytes / (1024 * 1024)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self) -> np.dtype:
|
||||||
|
"""Get the data type of the image array."""
|
||||||
|
|
||||||
|
if self._dtype is None:
|
||||||
|
raise ImageLoadError("Image dtype not available")
|
||||||
|
return self._dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def qtimage_format(self) -> QImage.Format:
|
||||||
|
"""
|
||||||
|
Get the appropriate QImage format for the image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
QImage.Format enum value
|
||||||
|
"""
|
||||||
|
if self._channels == 3:
|
||||||
|
return QImage.Format_RGB888
|
||||||
|
elif self._channels == 4:
|
||||||
|
return QImage.Format_RGBA8888
|
||||||
|
elif self._channels == 1:
|
||||||
|
if self._dtype == np.uint16:
|
||||||
|
return QImage.Format_Grayscale16
|
||||||
|
elif self._dtype == np.uint8:
|
||||||
|
return QImage.Format_Grayscale8
|
||||||
|
elif self._dtype == np.float32:
|
||||||
|
return QImage.Format_BGR30
|
||||||
|
else:
|
||||||
|
raise ImageLoadError(f"Unsupported number of channels: {self._channels}")
|
||||||
|
|
||||||
|
def get_rgb(self) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Get image data as RGB numpy array.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image data in RGB format as numpy array
|
||||||
|
"""
|
||||||
|
if self.channels == 1:
|
||||||
|
img = get_pseudo_rgb(self.data)
|
||||||
|
self._dtype = img.dtype
|
||||||
|
return img, True
|
||||||
|
|
||||||
|
elif self._channels == 3:
|
||||||
|
return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB), False
|
||||||
|
elif self._channels == 4:
|
||||||
|
return cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA), False
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
# else:
|
||||||
|
# return self._data
|
||||||
|
|
||||||
|
def get_qt_rgb(self) -> np.ascontiguousarray:
|
||||||
|
# we keep data as (C, H, W)
|
||||||
|
_img, pseudo = self.get_rgb()
|
||||||
|
|
||||||
|
if pseudo:
|
||||||
|
img = np.zeros((self.height, self.width, 4), dtype=np.float32)
|
||||||
|
img[..., 0] = _img[0] # R gradient
|
||||||
|
img[..., 1] = _img[1] # G gradient
|
||||||
|
img[..., 2] = _img[2] # B constant
|
||||||
|
img[..., 3] = 1.0 # A = 1.0 (opaque)
|
||||||
|
|
||||||
|
return np.ascontiguousarray(img)
|
||||||
|
else:
|
||||||
|
return np.ascontiguousarray(_img)
|
||||||
|
|
||||||
|
def get_grayscale(self) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Get image as grayscale numpy array.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Grayscale image as numpy array
|
||||||
|
"""
|
||||||
|
if self._channels == 1:
|
||||||
|
return self._data
|
||||||
|
else:
|
||||||
|
return cv2.cvtColor(self._data, cv2.COLOR_BGR2GRAY)
|
||||||
|
|
||||||
|
def copy(self) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Get a copy of the image data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Copy of image data as numpy array
|
||||||
|
"""
|
||||||
|
return self._data.copy()
|
||||||
|
|
||||||
|
def resize(self, width: int, height: int) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Resize the image to specified dimensions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
width: Target width in pixels
|
||||||
|
height: Target height in pixels
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Resized image as numpy array (does not modify original)
|
||||||
|
"""
|
||||||
|
return cv2.resize(self._data, (width, height))
|
||||||
|
|
||||||
|
def is_grayscale(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if image is grayscale.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if image is grayscale (1 channel)
|
||||||
|
"""
|
||||||
|
return self._channels == 1
|
||||||
|
|
||||||
|
def is_color(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if image is color.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if image has 3 or more channels
|
||||||
|
"""
|
||||||
|
return self._channels >= 3
|
||||||
|
|
||||||
|
def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None:
|
||||||
|
|
||||||
|
if self.channels == 1:
|
||||||
|
if pseudo_rgb:
|
||||||
|
img = get_pseudo_rgb(self.data)
|
||||||
|
print("Image.save", img.shape)
|
||||||
|
else:
|
||||||
|
img = np.repeat(self.data, 3, axis=2)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Only grayscale images are supported for now.")
|
||||||
|
|
||||||
|
imwrite(path, data=img)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""String representation of the Image object."""
|
||||||
|
return (
|
||||||
|
f"Image(path='{self.path.name}', "
|
||||||
|
# Display as HxWxC to match the conventional NumPy shape semantics.
|
||||||
|
f"shape=({self._height}x{self._width}x{self._channels}), "
|
||||||
|
f"format={self._format}, "
|
||||||
|
f"size={self.size_mb:.2f}MB)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""String representation of the Image object."""
|
||||||
|
return self.__repr__()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--path", type=str, required=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
img = Image(args.path)
|
||||||
|
img.save(args.path + "test.tif")
|
||||||
|
print(img)
|
||||||
168
src/utils/image_converters.py
Normal file
168
src/utils/image_converters.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from roifile import ImagejRoi
|
||||||
|
from tifffile import TiffFile, TiffWriter
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
class UT:
|
||||||
|
"""
|
||||||
|
Docstring for UT
|
||||||
|
|
||||||
|
Operetta files along with rois drawn in ImageJ
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, roifile_fn: Path, no_labels: bool):
|
||||||
|
self.roifile_fn = roifile_fn
|
||||||
|
print("is file", self.roifile_fn.is_file())
|
||||||
|
self.rois = None
|
||||||
|
if no_labels:
|
||||||
|
self.rois = ImagejRoi.fromfile(self.roifile_fn)
|
||||||
|
print(self.roifile_fn.stem)
|
||||||
|
print(self.roifile_fn.parent.parts[-1])
|
||||||
|
if "Roi-" in self.roifile_fn.stem:
|
||||||
|
self.stem = self.roifile_fn.stem.split("Roi-")[1]
|
||||||
|
else:
|
||||||
|
self.stem = self.roifile_fn.parent.parts[-1]
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.roifile_fn = roifile_fn / roifile_fn.parts[-1]
|
||||||
|
self.stem = self.roifile_fn.stem
|
||||||
|
|
||||||
|
print(self.roifile_fn)
|
||||||
|
|
||||||
|
print(self.stem)
|
||||||
|
self.image, self.image_props = self._load_images()
|
||||||
|
|
||||||
|
def _load_images(self):
|
||||||
|
"""Loading sequence of tif files
|
||||||
|
array sequence is CZYX
|
||||||
|
"""
|
||||||
|
print("Loading images:", self.roifile_fn.parent, self.stem)
|
||||||
|
fns = list(self.roifile_fn.parent.glob(f"{self.stem.lower()}*.tif*"))
|
||||||
|
stems = [fn.stem.split(self.stem)[-1] for fn in fns]
|
||||||
|
n_ch = len(set([stem.split("-ch")[-1].split("t")[0] for stem in stems]))
|
||||||
|
n_p = len(set([stem.split("-")[0] for stem in stems]))
|
||||||
|
n_t = len(set([stem.split("t")[1] for stem in stems]))
|
||||||
|
|
||||||
|
with TiffFile(fns[0]) as tif:
|
||||||
|
img = tif.asarray()
|
||||||
|
w, h = img.shape
|
||||||
|
dtype = img.dtype
|
||||||
|
self.image_props = {
|
||||||
|
"channels": n_ch,
|
||||||
|
"planes": n_p,
|
||||||
|
"tiles": n_t,
|
||||||
|
"width": w,
|
||||||
|
"height": h,
|
||||||
|
"dtype": dtype,
|
||||||
|
}
|
||||||
|
print("Image props", self.image_props)
|
||||||
|
|
||||||
|
image_stack = np.zeros((n_ch, n_p, w, h), dtype=dtype)
|
||||||
|
for fn in fns:
|
||||||
|
with TiffFile(fn) as tif:
|
||||||
|
img = tif.asarray()
|
||||||
|
stem = fn.stem.split(self.stem)[-1]
|
||||||
|
ch = int(stem.split("-ch")[-1].split("t")[0])
|
||||||
|
p = int(stem.split("-")[0].split("p")[1])
|
||||||
|
t = int(stem.split("t")[1])
|
||||||
|
print(fn.stem, "ch", ch, "p", p, "t", t)
|
||||||
|
image_stack[ch - 1, p - 1] = img
|
||||||
|
|
||||||
|
print(image_stack.shape)
|
||||||
|
|
||||||
|
return image_stack, self.image_props
|
||||||
|
|
||||||
|
@property
|
||||||
|
def width(self):
|
||||||
|
return self.image_props["width"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def height(self):
|
||||||
|
return self.image_props["height"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def nchannels(self):
|
||||||
|
return self.image_props["channels"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def nplanes(self):
|
||||||
|
return self.image_props["planes"]
|
||||||
|
|
||||||
|
def export_rois(
|
||||||
|
self,
|
||||||
|
path: Path,
|
||||||
|
subfolder: str = "labels",
|
||||||
|
class_index: int = 0,
|
||||||
|
):
|
||||||
|
"""Export rois to a file"""
|
||||||
|
with open(path / subfolder / f"{self.stem}.txt", "w") as f:
|
||||||
|
for i, roi in enumerate(self.rois):
|
||||||
|
rc = roi.subpixel_coordinates
|
||||||
|
if rc is None:
|
||||||
|
print(f"No coordinates: {self.roifile_fn}, element {i}, out of {len(self.rois)}")
|
||||||
|
continue
|
||||||
|
xmn, ymn = rc.min(axis=0)
|
||||||
|
xmx, ymx = rc.max(axis=0)
|
||||||
|
xc = (xmn + xmx) / 2
|
||||||
|
yc = (ymn + ymx) / 2
|
||||||
|
bw = xmx - xmn
|
||||||
|
bh = ymx - ymn
|
||||||
|
coords = f"{xc/self.width} {yc/self.height} {bw/self.width} {bh/self.height} "
|
||||||
|
for x, y in rc:
|
||||||
|
coords += f"{x/self.width} {y/self.height} "
|
||||||
|
f.write(f"{class_index} {coords}\n")
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
def export_image(
|
||||||
|
self,
|
||||||
|
path: Path,
|
||||||
|
subfolder: str = "images",
|
||||||
|
plane_mode: str = "max projection",
|
||||||
|
channel: int = 0,
|
||||||
|
):
|
||||||
|
"""Export image to a file"""
|
||||||
|
|
||||||
|
if plane_mode == "max projection":
|
||||||
|
self.image = np.max(self.image[channel], axis=0)
|
||||||
|
print(self.image.shape)
|
||||||
|
|
||||||
|
print(path / subfolder / f"{self.stem}.tif")
|
||||||
|
with TiffWriter(path / subfolder / f"{self.stem}.tif") as tif:
|
||||||
|
tif.write(self.image)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("-i", "--input", nargs="*", type=Path)
|
||||||
|
parser.add_argument("-o", "--output", type=Path)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-labels",
|
||||||
|
action="store_false",
|
||||||
|
help="Source does not have labels, export only images",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# print(args)
|
||||||
|
# aa
|
||||||
|
|
||||||
|
for path in args.input:
|
||||||
|
print("Path:", path)
|
||||||
|
if not args.no_labels:
|
||||||
|
print("No labels")
|
||||||
|
ut = UT(path, args.no_labels)
|
||||||
|
ut.export_image(args.output, plane_mode="max projection", channel=0)
|
||||||
|
|
||||||
|
else:
|
||||||
|
for rfn in Path(path).glob("*.zip"):
|
||||||
|
# if Path(path).suffix == ".zip":
|
||||||
|
print("Roi FN:", rfn)
|
||||||
|
ut = UT(rfn, args.no_labels)
|
||||||
|
ut.export_rois(args.output, class_index=0)
|
||||||
|
ut.export_image(args.output, plane_mode="max projection", channel=0)
|
||||||
|
|
||||||
|
print()
|
||||||
368
src/utils/image_splitter.py
Normal file
368
src/utils/image_splitter.py
Normal file
@@ -0,0 +1,368 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from tifffile import imread, imwrite
|
||||||
|
from shapely.geometry import LineString
|
||||||
|
from copy import deepcopy
|
||||||
|
from scipy.ndimage import zoom
|
||||||
|
|
||||||
|
|
||||||
|
# debug
|
||||||
|
from src.utils.image import Image
|
||||||
|
from show_yolo_seg import draw_annotations
|
||||||
|
|
||||||
|
import pylab as plt
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
class Label:
|
||||||
|
def __init__(self, yolo_annotation: str):
|
||||||
|
class_id, bbox, polygon = self.parse_yolo_annotation(yolo_annotation)
|
||||||
|
self.class_id = class_id
|
||||||
|
self.bbox = bbox
|
||||||
|
self.polygon = polygon
|
||||||
|
|
||||||
|
def parse_yolo_annotation(self, yolo_annotation: str):
|
||||||
|
class_id, *coords = yolo_annotation.split()
|
||||||
|
class_id = int(class_id)
|
||||||
|
bbox = np.array(coords[:4], dtype=np.float32)
|
||||||
|
polygon = np.array(coords[4:], dtype=np.float32).reshape(-1, 2) if len(coords) > 4 else None
|
||||||
|
if not any(np.isclose(polygon[0], polygon[-1])):
|
||||||
|
polygon = np.vstack([polygon, polygon[0]])
|
||||||
|
return class_id, bbox, polygon
|
||||||
|
|
||||||
|
def offset_label(
|
||||||
|
self,
|
||||||
|
img_w,
|
||||||
|
img_h,
|
||||||
|
distance: float = 1.0,
|
||||||
|
cap_style: int = 2,
|
||||||
|
join_style: int = 2,
|
||||||
|
):
|
||||||
|
if self.polygon is None:
|
||||||
|
self.bbox = np.array(
|
||||||
|
[
|
||||||
|
self.bbox[0] - distance if self.bbox[0] - distance > 0 else 0,
|
||||||
|
self.bbox[1] - distance if self.bbox[1] - distance > 0 else 0,
|
||||||
|
self.bbox[2] + distance if self.bbox[2] + distance < 1 else 1,
|
||||||
|
self.bbox[3] + distance if self.bbox[3] + distance < 1 else 1,
|
||||||
|
],
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
return self.bbox
|
||||||
|
|
||||||
|
def coords_are_normalized(coords):
|
||||||
|
# If every coordinate is between 0 and 1 (inclusive-ish), assume normalized
|
||||||
|
print(coords)
|
||||||
|
# if not coords:
|
||||||
|
# return False
|
||||||
|
return all(max(coords.flatten)) <= 1.001
|
||||||
|
|
||||||
|
def poly_to_pts(coords, img_w, img_h):
|
||||||
|
# coords: [x1 y1 x2 y2 ...] either normalized or absolute
|
||||||
|
# if coords_are_normalized(coords):
|
||||||
|
coords = [coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))]
|
||||||
|
pts = np.array(coords, dtype=np.int32).reshape(-1, 2)
|
||||||
|
return pts
|
||||||
|
|
||||||
|
pts = poly_to_pts(self.polygon, img_w, img_h)
|
||||||
|
line = LineString(pts)
|
||||||
|
# Buffer distance in pixels
|
||||||
|
buffered = line.buffer(distance=distance, cap_style=cap_style, join_style=join_style)
|
||||||
|
self.polygon = np.array(buffered.exterior.coords, dtype=np.float32) / (img_w, img_h)
|
||||||
|
xmn, ymn = self.polygon.min(axis=0)
|
||||||
|
xmx, ymx = self.polygon.max(axis=0)
|
||||||
|
xc = (xmn + xmx) / 2
|
||||||
|
yc = (ymn + ymx) / 2
|
||||||
|
bw = xmx - xmn
|
||||||
|
bh = ymx - ymn
|
||||||
|
self.bbox = np.array([xc, yc, bw, bh], dtype=np.float32)
|
||||||
|
|
||||||
|
return self.bbox, self.polygon
|
||||||
|
|
||||||
|
def translate(self, x, y, scale_x, scale_y):
|
||||||
|
self.bbox[0] -= x
|
||||||
|
self.bbox[0] *= scale_x
|
||||||
|
self.bbox[1] -= y
|
||||||
|
self.bbox[1] *= scale_y
|
||||||
|
self.bbox[2] *= scale_x
|
||||||
|
self.bbox[3] *= scale_y
|
||||||
|
if self.polygon is not None:
|
||||||
|
self.polygon[:, 0] -= x
|
||||||
|
self.polygon[:, 0] *= scale_x
|
||||||
|
self.polygon[:, 1] -= y
|
||||||
|
self.polygon[:, 1] *= scale_y
|
||||||
|
|
||||||
|
def in_range(self, hrange, wrange):
|
||||||
|
xc, yc, h, w = self.bbox
|
||||||
|
x1 = xc - w / 2
|
||||||
|
y1 = yc - h / 2
|
||||||
|
x2 = xc + w / 2
|
||||||
|
y2 = yc + h / 2
|
||||||
|
truth_val = (
|
||||||
|
xc >= wrange[0]
|
||||||
|
and x1 <= wrange[1]
|
||||||
|
and x2 >= wrange[0]
|
||||||
|
and x2 <= wrange[1]
|
||||||
|
and y1 >= hrange[0]
|
||||||
|
and y1 <= hrange[1]
|
||||||
|
and y2 >= hrange[0]
|
||||||
|
and y2 <= hrange[1]
|
||||||
|
)
|
||||||
|
|
||||||
|
print(x1, x2, wrange, y1, y2, hrange, truth_val)
|
||||||
|
return truth_val
|
||||||
|
|
||||||
|
def to_string(self, bbox: list = None, polygon: list = None):
|
||||||
|
coords = ""
|
||||||
|
if bbox is None:
|
||||||
|
bbox = self.bbox
|
||||||
|
# coords += " ".join([f"{x:.6f}" for x in self.bbox])
|
||||||
|
if polygon is None:
|
||||||
|
polygon = self.polygon
|
||||||
|
if self.polygon is not None:
|
||||||
|
coords += " " + " ".join([f"{x:.6f} {y:.6f}" for x, y in self.polygon])
|
||||||
|
return f"{self.class_id} {coords}"
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"Class: {self.class_id}, BBox: {self.bbox}, Polygon: {self.polygon}"
|
||||||
|
|
||||||
|
|
||||||
|
class YoloLabelReader:
|
||||||
|
def __init__(self, label_path: Path):
|
||||||
|
self.label_path = label_path
|
||||||
|
self.labels = self._read_labels()
|
||||||
|
|
||||||
|
def _read_labels(self):
|
||||||
|
with open(self.label_path, "r") as f:
|
||||||
|
labels = [Label(line) for line in f.readlines()]
|
||||||
|
|
||||||
|
return labels
|
||||||
|
|
||||||
|
def get_labels(self, hrange, wrange):
|
||||||
|
"""hrange and wrange are tuples of (start, end) normalized to [0, 1]"""
|
||||||
|
labels = []
|
||||||
|
# print(hrange, wrange)
|
||||||
|
for lbl in self.labels:
|
||||||
|
# print(lbl)
|
||||||
|
if lbl.in_range(hrange, wrange):
|
||||||
|
labels.append(lbl)
|
||||||
|
return labels if len(labels) > 0 else None
|
||||||
|
|
||||||
|
def __get_item__(self, index):
|
||||||
|
return self.labels[index]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.labels)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.labels)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageSplitter:
|
||||||
|
def __init__(self, image_path: Path, label_path: Path):
|
||||||
|
self.image = imread(image_path)
|
||||||
|
self.image_path = image_path
|
||||||
|
self.label_path = label_path
|
||||||
|
if not label_path.exists():
|
||||||
|
print(f"Label file {label_path} not found")
|
||||||
|
self.labels = None
|
||||||
|
else:
|
||||||
|
self.labels = YoloLabelReader(label_path)
|
||||||
|
|
||||||
|
def split_into_tiles(self, patch_size: tuple = (2, 2)):
|
||||||
|
"""Split image into patches of size patch_size"""
|
||||||
|
hstep, wstep = (
|
||||||
|
self.image.shape[0] // patch_size[0],
|
||||||
|
self.image.shape[1] // patch_size[1],
|
||||||
|
)
|
||||||
|
h, w = self.image.shape[:2]
|
||||||
|
|
||||||
|
for i in range(patch_size[0]):
|
||||||
|
for j in range(patch_size[1]):
|
||||||
|
metadata = {
|
||||||
|
"image_path": str(self.image_path),
|
||||||
|
"label_path": str(self.label_path),
|
||||||
|
"tile_section": f"{i}, {j}",
|
||||||
|
"tile_size": f"{hstep}, {wstep}",
|
||||||
|
"patch_size": f"{patch_size[0]}, {patch_size[1]}",
|
||||||
|
}
|
||||||
|
tile_reference = f"i{i}j{j}"
|
||||||
|
hrange = (i * hstep / h, (i + 1) * hstep / h)
|
||||||
|
wrange = (j * wstep / w, (j + 1) * wstep / w)
|
||||||
|
tile = self.image[i * hstep : (i + 1) * hstep, j * wstep : (j + 1) * wstep]
|
||||||
|
|
||||||
|
labels = None
|
||||||
|
if self.labels is not None:
|
||||||
|
labels = deepcopy(self.labels.get_labels(hrange, wrange))
|
||||||
|
print(id(labels))
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
print(hrange[0], wrange[0])
|
||||||
|
for l in labels:
|
||||||
|
print(l.bbox)
|
||||||
|
[l.translate(wrange[0], hrange[0], 2, 2) for l in labels]
|
||||||
|
print("translated")
|
||||||
|
for l in labels:
|
||||||
|
print(l.bbox)
|
||||||
|
|
||||||
|
# print(labels)
|
||||||
|
yield tile_reference, tile, labels, metadata
|
||||||
|
|
||||||
|
def split_respective_to_label(self, padding: int = 67):
|
||||||
|
if self.labels is None:
|
||||||
|
raise ValueError("No labels found. Only images having labels can be split.")
|
||||||
|
|
||||||
|
for i, label in enumerate(self.labels):
|
||||||
|
tile_reference = f"_lbl-{i+1:02d}"
|
||||||
|
# print(label.bbox)
|
||||||
|
metadata = {"image_path": str(self.image_path), "label_path": str(self.label_path), "label_index": str(i)}
|
||||||
|
|
||||||
|
xc_norm, yc_norm, h_norm, w_norm = label.bbox # normalized coords
|
||||||
|
xc, yc, h, w = [
|
||||||
|
int(np.round(f))
|
||||||
|
for f in [
|
||||||
|
xc_norm * self.image.shape[1],
|
||||||
|
yc_norm * self.image.shape[0],
|
||||||
|
h_norm * self.image.shape[0],
|
||||||
|
w_norm * self.image.shape[1],
|
||||||
|
]
|
||||||
|
] # image coords
|
||||||
|
|
||||||
|
# print("img coords:", xc, yc, h, w)
|
||||||
|
pad_xneg = padding + 1 # int(w / 2) + padding
|
||||||
|
pad_xpos = padding # int(w / 2) + padding
|
||||||
|
pad_yneg = padding + 1 # int(h / 2) + padding
|
||||||
|
pad_ypos = padding # int(h / 2) + padding
|
||||||
|
if xc - pad_xneg < 0:
|
||||||
|
pad_xneg = xc
|
||||||
|
if pad_xpos + xc > self.image.shape[1]:
|
||||||
|
pad_xpos = self.image.shape[1] - xc
|
||||||
|
if yc - pad_yneg < 0:
|
||||||
|
pad_yneg = yc
|
||||||
|
if pad_ypos + yc > self.image.shape[0]:
|
||||||
|
pad_ypos = self.image.shape[0] - yc
|
||||||
|
|
||||||
|
# print("pads:", pad_xneg, pad_xpos, pad_yneg, pad_ypos)
|
||||||
|
|
||||||
|
tile = self.image[
|
||||||
|
yc - pad_yneg : yc + pad_ypos,
|
||||||
|
xc - pad_xneg : xc + pad_xpos,
|
||||||
|
]
|
||||||
|
ny, nx = tile.shape
|
||||||
|
x_offset = pad_xneg
|
||||||
|
y_offset = pad_yneg
|
||||||
|
|
||||||
|
# print("tile shape:", tile.shape)
|
||||||
|
|
||||||
|
yolo_annotation = f"{label.class_id} " # {x_offset/nx} {y_offset/ny} {h_norm} {w_norm} "
|
||||||
|
yolo_annotation += " ".join(
|
||||||
|
[
|
||||||
|
f"{(x*self.image.shape[1]-(xc - x_offset))/nx:.6f} {(y*self.image.shape[0]-(yc-y_offset))/ny:.6f}"
|
||||||
|
for x, y in label.polygon
|
||||||
|
]
|
||||||
|
)
|
||||||
|
print(yolo_annotation)
|
||||||
|
new_label = Label(yolo_annotation=yolo_annotation)
|
||||||
|
|
||||||
|
yield tile_reference, tile, [new_label], metadata
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
|
||||||
|
if args.output:
|
||||||
|
args.output.mkdir(exist_ok=True, parents=True)
|
||||||
|
(args.output / "images").mkdir(exist_ok=True)
|
||||||
|
(args.output / "images-zoomed").mkdir(exist_ok=True)
|
||||||
|
(args.output / "labels").mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
for image_path in (args.input / "images").glob("*.tif"):
|
||||||
|
data = ImageSplitter(
|
||||||
|
image_path=image_path,
|
||||||
|
label_path=(args.input / "labels" / image_path.stem).with_suffix(".txt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.split_around_label:
|
||||||
|
data = data.split_respective_to_label(padding=args.padding)
|
||||||
|
else:
|
||||||
|
data = data.split_into_tiles(patch_size=args.patch_size)
|
||||||
|
|
||||||
|
for tile_reference, tile, labels, metadata in data:
|
||||||
|
print()
|
||||||
|
print(tile_reference, tile.shape, labels, metadata) # len(labels) if labels else None)
|
||||||
|
|
||||||
|
# { debug
|
||||||
|
debug = False
|
||||||
|
if debug:
|
||||||
|
plt.figure(figsize=(10, 10 * tile.shape[0] / tile.shape[1]))
|
||||||
|
if labels is None:
|
||||||
|
plt.imshow(tile, cmap="gray")
|
||||||
|
plt.axis("off")
|
||||||
|
plt.title(f"{image_path.name} ({tile_reference})")
|
||||||
|
plt.show()
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(labels[0].bbox)
|
||||||
|
# Draw annotations
|
||||||
|
out = draw_annotations(
|
||||||
|
cv2.cvtColor((tile / tile.max() * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR),
|
||||||
|
[l.to_string() for l in labels],
|
||||||
|
alpha=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert BGR -> RGB for matplotlib display
|
||||||
|
out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
|
||||||
|
plt.imshow(out_rgb)
|
||||||
|
plt.axis("off")
|
||||||
|
plt.title(f"{image_path.name} ({tile_reference})")
|
||||||
|
plt.show()
|
||||||
|
# } debug
|
||||||
|
|
||||||
|
if args.output:
|
||||||
|
# imwrite(args.output / "images" / f"{image_path.stem}_{tile_reference}.tif", tile, metadata=metadata)
|
||||||
|
scale = 5
|
||||||
|
tile_zoomed = zoom(tile, zoom=scale)
|
||||||
|
metadata["scale"] = scale
|
||||||
|
imwrite(
|
||||||
|
args.output / "images" / f"{image_path.stem}_{tile_reference}.tif",
|
||||||
|
tile_zoomed,
|
||||||
|
metadata=metadata,
|
||||||
|
imagej=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f:
|
||||||
|
for label in labels:
|
||||||
|
# label.offset_label(tile.shape[1], tile.shape[0])
|
||||||
|
f.write(label.to_string() + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("-i", "--input", type=Path)
|
||||||
|
parser.add_argument("-o", "--output", type=Path)
|
||||||
|
parser.add_argument(
|
||||||
|
"-p",
|
||||||
|
"--patch-size",
|
||||||
|
nargs=2,
|
||||||
|
type=int,
|
||||||
|
default=[2, 2],
|
||||||
|
help="Number of patches along height and width, rows and columns, respectively",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-sal",
|
||||||
|
"--split-around-label",
|
||||||
|
action="store_true",
|
||||||
|
help="If enabled, the image will be split around the label and for each label, a separate image will be created.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--padding",
|
||||||
|
type=int,
|
||||||
|
default=67,
|
||||||
|
help="Padding around the label when splitting around the label.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args)
|
||||||
1
src/utils/show_yolo_seg.py
Symbolic link
1
src/utils/show_yolo_seg.py
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
../../tests/show_yolo_seg.py
|
||||||
157
src/utils/ultralytics_16bit_patch.py
Normal file
157
src/utils/ultralytics_16bit_patch.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
"""Ultralytics runtime patches for 16-bit TIFF training.
|
||||||
|
|
||||||
|
Goals:
|
||||||
|
- Use `tifffile` to decode `.tif/.tiff` reliably (OpenCV can silently drop bit-depth depending on codec).
|
||||||
|
- Preserve 16-bit data through the dataloader as `uint16` tensors.
|
||||||
|
- Fix Ultralytics trainer normalization (default divides by 255) to scale `uint16` correctly.
|
||||||
|
- Avoid uint8-forcing augmentations by recommending/setting hyp values (handled by caller).
|
||||||
|
|
||||||
|
This module is intended to be imported/called **before** instantiating/using YOLO.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from src.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
|
||||||
|
"""Apply runtime monkey-patches to Ultralytics to better support 16-bit TIFFs.
|
||||||
|
|
||||||
|
This function is safe to call multiple times.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
force: If True, re-apply patches even if already applied.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Import inside function to ensure patching occurs before YOLO model/dataset is created.
|
||||||
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# import tifffile
|
||||||
|
import torch
|
||||||
|
from src.utils.image import Image
|
||||||
|
|
||||||
|
from ultralytics.utils import patches as ul_patches
|
||||||
|
|
||||||
|
already_patched = getattr(ul_patches.imread, "__name__", "") == "tifffile_imread"
|
||||||
|
if already_patched and not force:
|
||||||
|
return
|
||||||
|
|
||||||
|
_original_imread = ul_patches.imread
|
||||||
|
|
||||||
|
def tifffile_imread(filename: str, flags: int = cv2.IMREAD_COLOR, pseudo_rgb: bool = True) -> Optional[np.ndarray]:
|
||||||
|
"""Replacement for [`ultralytics.utils.patches.imread()`](venv/lib/python3.12/site-packages/ultralytics/utils/patches.py:20).
|
||||||
|
|
||||||
|
- For `.tif/.tiff`, uses `tifffile.imread()` and preserves dtype (e.g. uint16).
|
||||||
|
- For other formats, falls back to Ultralytics' original implementation.
|
||||||
|
- Always returns HWC (3 dims). For grayscale, returns (H, W, 1) or (H, W, 3) depending on requested flags.
|
||||||
|
"""
|
||||||
|
# print("here")
|
||||||
|
# return _original_imread(filename, flags)
|
||||||
|
ext = os.path.splitext(filename)[1].lower()
|
||||||
|
if ext in (".tif", ".tiff"):
|
||||||
|
arr = Image(filename).get_qt_rgb()[:, :, :3]
|
||||||
|
|
||||||
|
# Normalize common shapes:
|
||||||
|
# - (H, W) -> (H, W, 1)
|
||||||
|
# - (C, H, W) -> (H, W, C) (heuristic)
|
||||||
|
if arr is None:
|
||||||
|
return None
|
||||||
|
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[0] < arr.shape[1]:
|
||||||
|
arr = np.transpose(arr, (1, 2, 0))
|
||||||
|
if arr.ndim == 2:
|
||||||
|
arr = arr[..., None]
|
||||||
|
|
||||||
|
# Ensure contiguous array for downstream OpenCV ops.
|
||||||
|
# logger.info(f"Loading with monkey-patched imread: {filename}")
|
||||||
|
arr = arr.astype(np.float32)
|
||||||
|
arr /= arr.max()
|
||||||
|
arr *= 2**8 - 1
|
||||||
|
arr = arr.astype(np.uint8)
|
||||||
|
# print(arr.shape, arr.dtype, any(np.isnan(arr).flatten()), np.where(np.isnan(arr)), arr.min(), arr.max())
|
||||||
|
return np.ascontiguousarray(arr)
|
||||||
|
|
||||||
|
# logger.info(f"Loading with original imread: {filename}")
|
||||||
|
return _original_imread(filename, flags)
|
||||||
|
|
||||||
|
# Patch the canonical reference.
|
||||||
|
ul_patches.imread = tifffile_imread
|
||||||
|
|
||||||
|
# Patch common module-level imports (some Ultralytics modules do `from ... import imread`).
|
||||||
|
# Importing these modules is safe and helps ensure the patched function is used.
|
||||||
|
try:
|
||||||
|
import ultralytics.data.base as _ul_base
|
||||||
|
|
||||||
|
_ul_base.imread = tifffile_imread
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
import ultralytics.data.loaders as _ul_loaders
|
||||||
|
|
||||||
|
_ul_loaders.imread = tifffile_imread
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Patch trainer normalization: default divides by 255 regardless of input dtype.
|
||||||
|
from ultralytics.models.yolo.detect import train as detect_train
|
||||||
|
|
||||||
|
_orig_preprocess_batch = detect_train.DetectionTrainer.preprocess_batch
|
||||||
|
|
||||||
|
def preprocess_batch_16bit(self, batch: dict) -> dict: # type: ignore[override]
|
||||||
|
# Start from upstream behavior to keep device placement + multiscale identical,
|
||||||
|
# but replace the 255 division with dtype-aware scaling.
|
||||||
|
# logger.info(f"Preprocessing batch with monkey-patched preprocess_batch")
|
||||||
|
for k, v in batch.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
|
||||||
|
|
||||||
|
img = batch.get("img")
|
||||||
|
if isinstance(img, torch.Tensor):
|
||||||
|
# Decide scaling denom based on dtype (avoid expensive reductions if possible).
|
||||||
|
if img.dtype == torch.uint8:
|
||||||
|
denom = 255.0
|
||||||
|
elif img.dtype == torch.uint16:
|
||||||
|
denom = 65535.0
|
||||||
|
elif img.dtype.is_floating_point:
|
||||||
|
# Assume already in 0-1 range if float.
|
||||||
|
denom = 1.0
|
||||||
|
else:
|
||||||
|
# Generic integer fallback.
|
||||||
|
try:
|
||||||
|
denom = float(torch.iinfo(img.dtype).max)
|
||||||
|
except Exception:
|
||||||
|
denom = 255.0
|
||||||
|
|
||||||
|
batch["img"] = img.float() / denom
|
||||||
|
|
||||||
|
# Multi-scale branch copied from upstream to avoid re-introducing `/255` scaling.
|
||||||
|
if getattr(self.args, "multi_scale", False):
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
imgs = batch["img"]
|
||||||
|
sz = (
|
||||||
|
random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
|
||||||
|
// self.stride
|
||||||
|
* self.stride
|
||||||
|
)
|
||||||
|
sf = sz / max(imgs.shape[2:])
|
||||||
|
if sf != 1:
|
||||||
|
ns = [math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]]
|
||||||
|
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
|
||||||
|
batch["img"] = imgs
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
detect_train.DetectionTrainer.preprocess_batch = preprocess_batch_16bit
|
||||||
|
|
||||||
|
# Tag function to make it easier to detect patch state.
|
||||||
|
setattr(detect_train.DetectionTrainer.preprocess_batch, "_ultralytics_16bit_patch", True)
|
||||||
231
tests/show_yolo_seg.py
Normal file
231
tests/show_yolo_seg.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
show_yolo_seg.py
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python show_yolo_seg.py /path/to/image.jpg /path/to/labels.txt
|
||||||
|
|
||||||
|
Supports:
|
||||||
|
- Segmentation polygons: "class x1 y1 x2 y2 ... xn yn"
|
||||||
|
- YOLO bbox lines as fallback: "class x_center y_center width height"
|
||||||
|
Coordinates can be normalized [0..1] or absolute pixels (auto-detected).
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
import random
|
||||||
|
from shapely.geometry import LineString
|
||||||
|
|
||||||
|
from src.utils.image import Image
|
||||||
|
|
||||||
|
|
||||||
|
def parse_label_line(line):
|
||||||
|
parts = line.strip().split()
|
||||||
|
if not parts:
|
||||||
|
return None
|
||||||
|
cls = int(float(parts[0]))
|
||||||
|
coords = [float(x) for x in parts[1:]]
|
||||||
|
return cls, coords
|
||||||
|
|
||||||
|
|
||||||
|
def coords_are_normalized(coords):
|
||||||
|
# If every coordinate is between 0 and 1 (inclusive-ish), assume normalized
|
||||||
|
if not coords:
|
||||||
|
return False
|
||||||
|
return max(coords) <= 1.001
|
||||||
|
|
||||||
|
|
||||||
|
def yolo_bbox_to_xyxy(coords, img_w, img_h):
|
||||||
|
# coords: [xc, yc, w, h] normalized or absolute
|
||||||
|
xc, yc, w, h = coords[:4]
|
||||||
|
if max(coords) <= 1.001:
|
||||||
|
xc *= img_w
|
||||||
|
yc *= img_h
|
||||||
|
w *= img_w
|
||||||
|
h *= img_h
|
||||||
|
x1 = int(round(xc - w / 2))
|
||||||
|
y1 = int(round(yc - h / 2))
|
||||||
|
x2 = int(round(xc + w / 2))
|
||||||
|
y2 = int(round(yc + h / 2))
|
||||||
|
return x1, y1, x2, y2
|
||||||
|
|
||||||
|
|
||||||
|
def poly_to_pts(coords, img_w, img_h):
|
||||||
|
# coords: [x1 y1 x2 y2 ...] either normalized or absolute
|
||||||
|
if coords_are_normalized(coords[4:]):
|
||||||
|
coords = [coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))]
|
||||||
|
pts = np.array(coords, dtype=np.int32).reshape(-1, 2)
|
||||||
|
return pts
|
||||||
|
|
||||||
|
|
||||||
|
def random_color_for_class(cls):
|
||||||
|
random.seed(cls) # deterministic per class
|
||||||
|
return (
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
255,
|
||||||
|
) # tuple(int(x) for x in np.array([random.randint(0, 255) for _ in range(3)]))
|
||||||
|
|
||||||
|
|
||||||
|
def draw_annotations(img, labels, alpha=0.4, draw_bbox_for_poly=True):
|
||||||
|
# img: BGR numpy array
|
||||||
|
overlay = img.copy()
|
||||||
|
h, w = img.shape[:2]
|
||||||
|
for line in labels:
|
||||||
|
if isinstance(line, str):
|
||||||
|
cls, coords = parse_label_line(line)
|
||||||
|
if isinstance(line, tuple):
|
||||||
|
cls, coords = line
|
||||||
|
|
||||||
|
if not coords:
|
||||||
|
continue
|
||||||
|
# polygon case (>=6 coordinates)
|
||||||
|
if len(coords) >= 6:
|
||||||
|
color = random_color_for_class(cls)
|
||||||
|
|
||||||
|
x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords[:4], w, h)
|
||||||
|
print(x1, y1, x2, y2)
|
||||||
|
cv2.rectangle(img, (x1, y1), (x2, y2), color, 1)
|
||||||
|
|
||||||
|
pts = poly_to_pts(coords[4:], w, h)
|
||||||
|
# line = LineString(pts)
|
||||||
|
# # Buffer distance in pixels
|
||||||
|
# buffered = line.buffer(3, cap_style=2, join_style=2)
|
||||||
|
# coords = np.array(buffered.exterior.coords, dtype=np.int32)
|
||||||
|
# cv2.fillPoly(overlay, [coords], color=(255, 255, 255))
|
||||||
|
|
||||||
|
# fill on overlay
|
||||||
|
cv2.fillPoly(overlay, [pts], color)
|
||||||
|
# outline on base image
|
||||||
|
cv2.polylines(img, [pts], isClosed=True, color=color, thickness=1)
|
||||||
|
# put class text at first point
|
||||||
|
x, y = int(pts[0, 0]), int(pts[0, 1]) - 6
|
||||||
|
if 0:
|
||||||
|
cv2.putText(
|
||||||
|
img,
|
||||||
|
str(cls),
|
||||||
|
(x, max(6, y)),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX,
|
||||||
|
0.6,
|
||||||
|
(255, 255, 255),
|
||||||
|
2,
|
||||||
|
cv2.LINE_AA,
|
||||||
|
)
|
||||||
|
|
||||||
|
# YOLO bbox case (4 coords)
|
||||||
|
elif len(coords) == 4:
|
||||||
|
x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords, w, h)
|
||||||
|
color = random_color_for_class(cls)
|
||||||
|
cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
|
||||||
|
cv2.putText(
|
||||||
|
img,
|
||||||
|
str(cls),
|
||||||
|
(x1, max(6, y1 - 4)),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX,
|
||||||
|
0.6,
|
||||||
|
(255, 255, 255),
|
||||||
|
2,
|
||||||
|
cv2.LINE_AA,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Unknown / invalid format, skip
|
||||||
|
continue
|
||||||
|
|
||||||
|
# blend overlay for filled polygons
|
||||||
|
cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def load_labels_file(label_path):
|
||||||
|
labels = []
|
||||||
|
with open(label_path, "r") as f:
|
||||||
|
for raw in f:
|
||||||
|
line = raw.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
parsed = parse_label_line(line)
|
||||||
|
if parsed:
|
||||||
|
labels.append(parsed)
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Show YOLO segmentation / polygon annotations")
|
||||||
|
parser.add_argument("image", type=str, help="Path to image file")
|
||||||
|
parser.add_argument("--labels", type=str, help="Path to YOLO label file (polygons)")
|
||||||
|
parser.add_argument("--alpha", type=float, default=0.4, help="Polygon fill alpha (0..1)")
|
||||||
|
parser.add_argument("--no-bbox", action="store_true", help="Don't draw bounding boxes for polygons")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(args)
|
||||||
|
|
||||||
|
img_path = Path(args.image)
|
||||||
|
if args.labels:
|
||||||
|
lbl_path = Path(args.labels)
|
||||||
|
else:
|
||||||
|
lbl_path = img_path.with_suffix(".txt")
|
||||||
|
lbl_path = Path(str(lbl_path).replace("images", "labels"))
|
||||||
|
|
||||||
|
if not img_path.exists():
|
||||||
|
print("Image not found:", img_path)
|
||||||
|
sys.exit(1)
|
||||||
|
if not lbl_path.exists():
|
||||||
|
print("Label file not found:", lbl_path)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
|
||||||
|
img = (Image(img_path).get_qt_rgb() * 255).astype(np.uint8)
|
||||||
|
|
||||||
|
if img is None:
|
||||||
|
print("Could not load image:", img_path)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
labels = load_labels_file(str(lbl_path))
|
||||||
|
if not labels:
|
||||||
|
print("No labels parsed from", lbl_path)
|
||||||
|
# continue and just show image
|
||||||
|
out = draw_annotations(img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox))
|
||||||
|
|
||||||
|
out_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
|
||||||
|
if 0:
|
||||||
|
plt.imshow(out_rgb.transpose(1, 0, 2))
|
||||||
|
else:
|
||||||
|
plt.imshow(out_rgb)
|
||||||
|
|
||||||
|
for label in labels:
|
||||||
|
lclass, coords = label
|
||||||
|
# print(lclass, coords)
|
||||||
|
bbox = coords[:4]
|
||||||
|
# print("bbox", bbox)
|
||||||
|
bbox = np.array(bbox) * np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
|
||||||
|
yc, xc, h, w = bbox
|
||||||
|
# print("bbox", bbox)
|
||||||
|
|
||||||
|
# polyline = np.array(coords[4:]).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
|
||||||
|
polyline = np.array(coords).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
|
||||||
|
# print("pl", coords[4:])
|
||||||
|
# print("pl", polyline)
|
||||||
|
|
||||||
|
# Convert BGR -> RGB for matplotlib display
|
||||||
|
# out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
|
||||||
|
# out_rgb = Image()
|
||||||
|
plt.plot(polyline[:, 0], polyline[:, 1], "y", linewidth=2)
|
||||||
|
if 0:
|
||||||
|
plt.plot(
|
||||||
|
[yc - h / 2, yc - h / 2, yc + h / 2, yc + h / 2, yc - h / 2],
|
||||||
|
[xc - w / 2, xc + w / 2, xc + w / 2, xc - w / 2, xc - w / 2],
|
||||||
|
"r",
|
||||||
|
linewidth=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# plt.axis("off")
|
||||||
|
plt.title(f"{img_path.name} ({lbl_path.name})")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
145
tests/test_image.py
Normal file
145
tests/test_image.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
"""
|
||||||
|
Tests for the Image class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
from src.utils.image import Image, ImageLoadError
|
||||||
|
|
||||||
|
|
||||||
|
class TestImage:
|
||||||
|
"""Test cases for the Image class."""
|
||||||
|
|
||||||
|
def test_load_nonexistent_file(self):
|
||||||
|
"""Test loading a non-existent file raises ImageLoadError."""
|
||||||
|
with pytest.raises(ImageLoadError):
|
||||||
|
Image("nonexistent_file.jpg")
|
||||||
|
|
||||||
|
def test_load_unsupported_format(self, tmp_path):
|
||||||
|
"""Test loading an unsupported format raises ImageLoadError."""
|
||||||
|
# Create a dummy file with unsupported extension
|
||||||
|
test_file = tmp_path / "test.txt"
|
||||||
|
test_file.write_text("not an image")
|
||||||
|
|
||||||
|
with pytest.raises(ImageLoadError):
|
||||||
|
Image(test_file)
|
||||||
|
|
||||||
|
def test_supported_extensions(self):
|
||||||
|
"""Test that supported extensions are correctly defined."""
|
||||||
|
expected_extensions = Image.SUPPORTED_EXTENSIONS
|
||||||
|
assert Image.SUPPORTED_EXTENSIONS == expected_extensions
|
||||||
|
|
||||||
|
def test_image_properties(self, tmp_path):
|
||||||
|
"""Test image properties after loading."""
|
||||||
|
# Create a simple test image using numpy and cv2
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
test_img = np.zeros((100, 200, 3), dtype=np.uint8)
|
||||||
|
test_img[:, :] = [255, 0, 0] # Blue in BGR
|
||||||
|
|
||||||
|
test_file = tmp_path / "test.jpg"
|
||||||
|
cv2.imwrite(str(test_file), test_img)
|
||||||
|
|
||||||
|
# Load the image
|
||||||
|
img = Image(test_file)
|
||||||
|
|
||||||
|
# Check properties
|
||||||
|
assert img.width == 200
|
||||||
|
assert img.height == 100
|
||||||
|
assert img.channels == 3
|
||||||
|
assert img.format == "jpg"
|
||||||
|
assert img.shape == (100, 200, 3)
|
||||||
|
assert img.size_bytes > 0
|
||||||
|
assert img.is_color()
|
||||||
|
assert not img.is_grayscale()
|
||||||
|
|
||||||
|
def test_get_rgb(self, tmp_path):
|
||||||
|
"""Test RGB conversion."""
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
# Create BGR image
|
||||||
|
test_img = np.zeros((50, 50, 3), dtype=np.uint8)
|
||||||
|
test_img[:, :] = [255, 0, 0] # Blue in BGR
|
||||||
|
|
||||||
|
test_file = tmp_path / "test_rgb.png"
|
||||||
|
cv2.imwrite(str(test_file), test_img)
|
||||||
|
|
||||||
|
img = Image(test_file)
|
||||||
|
rgb_data = img.get_rgb()
|
||||||
|
|
||||||
|
# RGB should have red channel at 255
|
||||||
|
assert rgb_data[0, 0, 0] == 0 # R
|
||||||
|
assert rgb_data[0, 0, 1] == 0 # G
|
||||||
|
assert rgb_data[0, 0, 2] == 255 # B (was BGR blue)
|
||||||
|
|
||||||
|
def test_get_grayscale(self, tmp_path):
|
||||||
|
"""Test grayscale conversion."""
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
test_img = np.zeros((50, 50, 3), dtype=np.uint8)
|
||||||
|
test_img[:, :] = [128, 128, 128]
|
||||||
|
|
||||||
|
test_file = tmp_path / "test_gray.png"
|
||||||
|
cv2.imwrite(str(test_file), test_img)
|
||||||
|
|
||||||
|
img = Image(test_file)
|
||||||
|
gray_data = img.get_grayscale()
|
||||||
|
|
||||||
|
assert len(gray_data.shape) == 2 # Should be 2D
|
||||||
|
assert gray_data.shape == (50, 50)
|
||||||
|
|
||||||
|
def test_copy(self, tmp_path):
|
||||||
|
"""Test copying image data."""
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
test_img = np.zeros((50, 50, 3), dtype=np.uint8)
|
||||||
|
|
||||||
|
test_file = tmp_path / "test_copy.png"
|
||||||
|
cv2.imwrite(str(test_file), test_img)
|
||||||
|
|
||||||
|
img = Image(test_file)
|
||||||
|
copied = img.copy()
|
||||||
|
|
||||||
|
# Modify copy
|
||||||
|
copied[0, 0] = [255, 255, 255]
|
||||||
|
|
||||||
|
# Original should be unchanged
|
||||||
|
assert not np.array_equal(img.data[0, 0], copied[0, 0])
|
||||||
|
|
||||||
|
def test_resize(self, tmp_path):
|
||||||
|
"""Test image resizing."""
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
test_img = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||||
|
|
||||||
|
test_file = tmp_path / "test_resize.png"
|
||||||
|
cv2.imwrite(str(test_file), test_img)
|
||||||
|
|
||||||
|
img = Image(test_file)
|
||||||
|
resized = img.resize(50, 50)
|
||||||
|
|
||||||
|
assert resized.shape == (50, 50, 3)
|
||||||
|
# Original should be unchanged
|
||||||
|
assert img.width == 100
|
||||||
|
assert img.height == 100
|
||||||
|
|
||||||
|
def test_str_repr(self, tmp_path):
|
||||||
|
"""Test string representation."""
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
test_img = np.zeros((100, 200, 3), dtype=np.uint8)
|
||||||
|
|
||||||
|
test_file = tmp_path / "test_str.jpg"
|
||||||
|
cv2.imwrite(str(test_file), test_img)
|
||||||
|
|
||||||
|
img = Image(test_file)
|
||||||
|
|
||||||
|
str_repr = str(img)
|
||||||
|
assert "test_str.jpg" in str_repr
|
||||||
|
assert "100x200x3" in str_repr
|
||||||
|
assert "jpg" in str_repr
|
||||||
|
|
||||||
|
repr_str = repr(img)
|
||||||
|
assert "Image" in repr_str
|
||||||
|
assert "test_str.jpg" in repr_str
|
||||||
Reference in New Issue
Block a user