71 Commits

Author SHA1 Message Date
d03ffdc4d0 Adding annotation list 2026-01-16 14:42:08 +02:00
8d30e6bb7a Adding auto zoom for annotation image view and changing tab order 2026-01-16 14:20:12 +02:00
f810fec4d8 Adding model deletion feature from database 2026-01-16 14:13:40 +02:00
9c8931e6f3 Finish validation tab 2026-01-16 13:58:02 +02:00
20578c1fdf Adding a file and feature to delete all detections from database 2026-01-16 13:43:05 +02:00
2c494dac49 Adding export for labels in results 2026-01-16 11:15:12 +02:00
506c74e53a Small update 2026-01-16 10:39:46 +02:00
eefda5b878 Adding metdata to tiled images 2026-01-16 10:39:14 +02:00
31cb6a6c8e Using 8bit images 2026-01-16 10:38:34 +02:00
0c19ea2557 Updating 2026-01-16 10:30:13 +02:00
89e47591db Formatting 2026-01-16 10:27:15 +02:00
69cde09e53 Changing alpha value 2026-01-16 10:26:25 +02:00
fcbd5fb16d correcting label writing and formatting code 2026-01-16 10:24:19 +02:00
ca52312925 Adding LIKE option for filtering queries 2026-01-16 10:18:48 +02:00
0a93bf797a Adding auto zoom when result is loaded 2026-01-12 14:15:02 +02:00
d998c65665 Updating image splitter 2026-01-12 13:28:00 +02:00
510eabfa94 Adding splitter method 2026-01-05 13:56:57 +02:00
395d263900 Update 2026-01-05 08:59:36 +02:00
e98d287b8a Updating tiff image patch 2026-01-02 12:44:06 +02:00
d25101de2d adding files 2026-01-02 12:40:44 +02:00
f88beef188 Another test 2025-12-19 13:50:49 +02:00
2fd9a2acf4 RGB 2025-12-19 13:31:24 +02:00
2bcd18cc75 Bug fix 2025-12-19 13:13:12 +02:00
5d25378c46 Testing with uint conversion 2025-12-19 13:10:36 +02:00
2b0b48921e Testing more grayscale 2025-12-19 12:02:11 +02:00
b0c05f0225 testing grayscale 2025-12-19 11:55:38 +02:00
97badaa390 Samll update 2025-12-19 11:31:12 +02:00
8f8132ce61 Testing detect 2025-12-19 10:44:11 +02:00
6ae7481e25 Adding debug messages 2025-12-19 10:15:53 +02:00
061f8b3ca2 Fixing pseudo rgb 2025-12-19 09:56:43 +02:00
a8e5db3135 Small change 2025-12-18 13:03:12 +02:00
268ed5175e Appling pseudo channels for RGB 2025-12-18 12:52:13 +02:00
5e9d3b1dc4 Adding logger 2025-12-18 12:04:41 +02:00
7d83e9b9b1 Adding important file 2025-12-17 00:45:56 +02:00
e364d06217 Implementing uint16 reading with tifffile 2025-12-16 23:02:45 +02:00
e5036c10cf Small fix 2025-12-16 18:03:56 +02:00
c7e388d9ae Updating progressbar 2025-12-16 17:20:25 +02:00
6b995e7325 upate 2025-12-16 13:24:20 +02:00
0e0741d323 Update on convert_grayscale_to_rgb_preserve_range, making it class method 2025-12-16 12:37:34 +02:00
dd99a0677c Updating image converter and aading simple script to visulaize segmentation 2025-12-16 11:27:38 +02:00
9c4c39fb39 Adding image converter 2025-12-12 23:52:34 +02:00
20a87c9040 Updating config 2025-12-12 21:51:12 +02:00
9f7d2be1ac Updating the base model preset 2025-12-11 23:27:02 +02:00
dbde07c0e8 Making training tab scrollable 2025-12-11 23:12:39 +02:00
b3c5a51dbb Using QPolygonF instead of drawLine 2025-12-11 17:14:07 +02:00
9a221acb63 Making image manipulations thru one class 2025-12-11 16:59:56 +02:00
32a6a122bd Fixing circular import 2025-12-11 16:06:39 +02:00
9ba44043ef Defining image extensions only in one place 2025-12-11 15:50:14 +02:00
8eb1cc8c86 Fixing grayscale conversion 2025-12-11 15:15:38 +02:00
e4ce882a18 Grayscale RGB conversion modified 2025-12-11 15:06:59 +02:00
6b6d6fad03 2Stage training fix 2025-12-11 12:50:34 +02:00
c0684a9c14 Implementing 2 stage training 2025-12-11 12:04:08 +02:00
221c80aa8c Small image showing fix 2025-12-11 11:20:20 +02:00
833b222fad Adding result shower 2025-12-10 16:55:28 +02:00
5370d31dce Merge pull request 'Update training' (#2) from training into main
Reviewed-on: #2
2025-12-10 15:47:00 +02:00
5d196c3a4a Update training 2025-12-10 15:46:26 +02:00
f719c7ec40 Merge pull request 'segmentation' (#1) from segmentation into main
Reviewed-on: #1
2025-12-10 12:08:54 +02:00
e6a5e74fa1 Adding feature to remove annotations 2025-12-10 00:19:59 +02:00
35e2398e95 Fixing bounding box drawing 2025-12-09 23:56:29 +02:00
c3d44ac945 Renaming Pen tool to polyline tool 2025-12-09 23:38:23 +02:00
dad5c2bf74 Updating 2025-12-09 22:44:23 +02:00
73cb698488 Saving state before replacing annotation tool 2025-12-09 22:00:56 +02:00
12f2bf94d5 Updating polyline saving and drawing 2025-12-09 15:42:42 +02:00
710b684456 Updating annotations 2025-12-08 23:59:44 +02:00
fc22479621 Adding pen tool for annotation 2025-12-08 23:15:54 +02:00
f84dea0bff Adding splitter and saving layout state when closing the app 2025-12-08 22:40:07 +02:00
bb26d43dd7 Adding image_display widget 2025-12-08 17:33:32 +02:00
4b5d2a7c45 Adding image loading 2025-12-08 16:28:58 +02:00
42fb2b782d Bug fix in installing and lauching the program 2025-12-05 16:18:37 +02:00
310e0b2285 Making it installabel package and switching to segmentation mode 2025-12-05 15:51:16 +02:00
9011276584 Small update 2025-12-05 15:30:19 +02:00
43 changed files with 9304 additions and 353 deletions

View File

@@ -2,11 +2,11 @@
## Project Overview
A desktop application for detecting organelles and membrane branching structures in microscopy images using YOLOv8s, with comprehensive training, validation, and visualization capabilities.
A desktop application for detecting and segmenting organelles and membrane branching structures in microscopy images using YOLOv8s-seg, with comprehensive training, validation, and visualization capabilities including pixel-accurate segmentation masks.
## Technology Stack
- **ML Framework**: Ultralytics YOLOv8 (YOLOv8s.pt model)
- **ML Framework**: Ultralytics YOLOv8 (YOLOv8s-seg.pt segmentation model)
- **GUI Framework**: PySide6 (Qt6 for Python)
- **Visualization**: pyqtgraph
- **Database**: SQLite3
@@ -110,6 +110,7 @@ erDiagram
float x_max
float y_max
float confidence
text segmentation_mask
datetime detected_at
json metadata
}
@@ -122,6 +123,7 @@ erDiagram
float y_min
float x_max
float y_max
text segmentation_mask
string annotator
datetime created_at
boolean verified
@@ -139,7 +141,7 @@ Stores information about trained models and their versions.
| model_name | TEXT | NOT NULL | User-friendly model name |
| model_version | TEXT | NOT NULL | Version string (e.g., "v1.0") |
| model_path | TEXT | NOT NULL | Path to model weights file |
| base_model | TEXT | NOT NULL | Base model used (e.g., "yolov8s.pt") |
| base_model | TEXT | NOT NULL | Base model used (e.g., "yolov8s-seg.pt") |
| created_at | TIMESTAMP | DEFAULT CURRENT_TIMESTAMP | Model creation timestamp |
| training_params | JSON | | Training hyperparameters |
| metrics | JSON | | Validation metrics (mAP, precision, recall) |
@@ -159,7 +161,7 @@ Stores metadata about microscopy images.
| checksum | TEXT | | MD5 hash for integrity verification |
#### **detections** table
Stores object detection results.
Stores object detection results with optional segmentation masks.
| Column | Type | Constraints | Description |
|--------|------|-------------|-------------|
@@ -172,11 +174,12 @@ Stores object detection results.
| x_max | REAL | NOT NULL | Bounding box right coordinate (normalized 0-1) |
| y_max | REAL | NOT NULL | Bounding box bottom coordinate (normalized 0-1) |
| confidence | REAL | NOT NULL | Detection confidence score (0-1) |
| segmentation_mask | TEXT | | JSON array of polygon coordinates [[x1,y1], [x2,y2], ...] (normalized 0-1) |
| detected_at | TIMESTAMP | DEFAULT CURRENT_TIMESTAMP | When detection was performed |
| metadata | JSON | | Additional metadata (processing time, etc.) |
#### **annotations** table
Stores manual annotations for training data (future feature).
Stores manual annotations for training data with optional segmentation masks (future feature).
| Column | Type | Constraints | Description |
|--------|------|-------------|-------------|
@@ -187,6 +190,7 @@ Stores manual annotations for training data (future feature).
| y_min | REAL | NOT NULL | Bounding box top coordinate (normalized) |
| x_max | REAL | NOT NULL | Bounding box right coordinate (normalized) |
| y_max | REAL | NOT NULL | Bounding box bottom coordinate (normalized) |
| segmentation_mask | TEXT | | JSON array of polygon coordinates [[x1,y1], [x2,y2], ...] (normalized 0-1) |
| annotator | TEXT | | Name of person who created annotation |
| created_at | TIMESTAMP | DEFAULT CURRENT_TIMESTAMP | Annotation timestamp |
| verified | BOOLEAN | DEFAULT 0 | Whether annotation is verified |
@@ -245,8 +249,9 @@ graph TB
### Key Components
#### 1. **YOLO Wrapper** ([`src/model/yolo_wrapper.py`](src/model/yolo_wrapper.py))
Encapsulates YOLOv8 operations:
- Load pre-trained YOLOv8s model
Encapsulates YOLOv8-seg operations:
- Load pre-trained YOLOv8s-seg segmentation model
- Extract pixel-accurate segmentation masks
- Fine-tune on custom microscopy dataset
- Export trained models
- Provide training progress callbacks
@@ -255,10 +260,10 @@ Encapsulates YOLOv8 operations:
**Key Methods:**
```python
class YOLOWrapper:
def __init__(self, model_path: str = "yolov8s.pt")
def __init__(self, model_path: str = "yolov8s-seg.pt")
def train(self, data_yaml: str, epochs: int, callbacks: dict)
def validate(self, data_yaml: str) -> dict
def predict(self, image_path: str, conf: float) -> list
def predict(self, image_path: str, conf: float) -> list # Returns detections with segmentation masks
def export_model(self, format: str, output_path: str)
```
@@ -435,7 +440,7 @@ image_repository:
allowed_extensions: [".jpg", ".jpeg", ".png", ".tif", ".tiff"]
models:
default_base_model: "yolov8s.pt"
default_base_model: "yolov8s-seg.pt"
models_directory: "data/models"
training:

178
BUILD.md Normal file
View File

@@ -0,0 +1,178 @@
# Building and Publishing Guide
This guide explains how to build and publish the microscopy-object-detection package.
## Prerequisites
```bash
pip install build twine
```
## Building the Package
### 1. Clean Previous Builds
```bash
rm -rf build/ dist/ *.egg-info
```
### 2. Build Distribution Archives
```bash
python -m build
```
This will create both wheel (`.whl`) and source distribution (`.tar.gz`) in the `dist/` directory.
### 3. Verify the Build
```bash
ls dist/
# Should show:
# microscopy_object_detection-1.0.0-py3-none-any.whl
# microscopy_object_detection-1.0.0.tar.gz
```
## Testing the Package Locally
### Install in Development Mode
```bash
pip install -e .
```
### Install from Built Package
```bash
pip install dist/microscopy_object_detection-1.0.0-py3-none-any.whl
```
### Test the Installation
```bash
# Test CLI
microscopy-detect --version
# Test GUI launcher
microscopy-detect-gui
```
## Publishing to PyPI
### 1. Configure PyPI Credentials
Create or update `~/.pypirc`:
```ini
[pypi]
username = __token__
password = pypi-YOUR-API-TOKEN-HERE
```
### 2. Upload to Test PyPI (Recommended First)
```bash
python -m twine upload --repository testpypi dist/*
```
Then test installation:
```bash
pip install --index-url https://test.pypi.org/simple/ microscopy-object-detection
```
### 3. Upload to PyPI
```bash
python -m twine upload dist/*
```
## Version Management
Update version in multiple files:
- `setup.py`: Update `version` parameter
- `pyproject.toml`: Update `version` field
- `src/__init__.py`: Update `__version__` variable
## Git Tags
After publishing, tag the release:
```bash
git tag -a v1.0.0 -m "Release version 1.0.0"
git push origin v1.0.0
```
## Package Structure
The built package includes:
- All Python source files in `src/`
- Configuration files in `config/`
- Database schema file (`src/database/schema.sql`)
- Documentation files (README.md, LICENSE, etc.)
- Entry points for CLI and GUI
## Troubleshooting
### Import Errors
If you get import errors, ensure:
- All `__init__.py` files are present
- Package structure follows the setup configuration
- Dependencies are listed in `requirements.txt`
### Missing Files
If files are missing in the built package:
- Check `MANIFEST.in` includes the required patterns
- Check `pyproject.toml` package-data configuration
- Rebuild with `python -m build --no-isolation` for debugging
### Version Conflicts
If version conflicts occur:
- Ensure version is consistent across all files
- Clear build artifacts and rebuild
- Check for cached installations: `pip list | grep microscopy`
## CI/CD Integration
### GitHub Actions Example
```yaml
name: Build and Publish
on:
release:
types: [created]
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: '3.8'
- name: Install dependencies
run: |
pip install build twine
- name: Build package
run: python -m build
- name: Publish to PyPI
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
run: twine upload dist/*
```
## Best Practices
1. **Version Bumping**: Use semantic versioning (MAJOR.MINOR.PATCH)
2. **Testing**: Always test on Test PyPI before publishing to PyPI
3. **Documentation**: Update README.md and CHANGELOG.md for each release
4. **Git Tags**: Tag releases in git for easy reference
5. **Dependencies**: Keep requirements.txt updated and specify version ranges
## Resources
- [Python Packaging Guide](https://packaging.python.org/)
- [setuptools Documentation](https://setuptools.pypa.io/)
- [PyPI Publishing Guide](https://packaging.python.org/tutorials/packaging-projects/)

236
INSTALL_TEST.md Normal file
View File

@@ -0,0 +1,236 @@
# Installation Testing Guide
This guide helps you verify that the package installation works correctly.
## Clean Installation Test
### 1. Remove Any Previous Installations
```bash
# Deactivate any active virtual environment
deactivate
# Remove old virtual environment (if exists)
rm -rf venv
# Create fresh virtual environment
python3 -m venv venv
source venv/bin/activate # On Linux/Mac
# or
venv\Scripts\activate # On Windows
```
### 2. Install the Package
#### Option A: Editable/Development Install
```bash
pip install -e .
```
This allows you to modify source code and see changes immediately.
#### Option B: Regular Install
```bash
pip install .
```
This installs the package as if it were from PyPI.
### 3. Verify Installation
```bash
# Check package is installed
pip list | grep microscopy
# Check version
microscopy-detect --version
# Expected output: microscopy-object-detection 1.0.0
# Test Python import
python -c "import src; print(src.__version__)"
# Expected output: 1.0.0
```
### 4. Test Entry Points
```bash
# Test CLI
microscopy-detect --help
# Test GUI launcher (will open window)
microscopy-detect-gui
```
### 5. Verify Package Contents
```python
# Run this in Python shell
import src
import src.database
import src.model
import src.gui
# Check schema file is included
from pathlib import Path
import src.database
db_path = Path(src.database.__file__).parent
schema_file = db_path / 'schema.sql'
print(f"Schema file exists: {schema_file.exists()}")
# Expected: Schema file exists: True
```
## Troubleshooting
### Issue: ModuleNotFoundError
**Error:**
```
ModuleNotFoundError: No module named 'src'
```
**Solution:**
```bash
# Reinstall with verbose output
pip install -e . -v
# Or try regular install
pip install . --force-reinstall
```
### Issue: Entry Points Not Working
**Error:**
```
microscopy-detect: command not found
```
**Solution:**
```bash
# Check if scripts are in PATH
which microscopy-detect
# If not found, check pip install location
pip show microscopy-object-detection
# You might need to add to PATH or use full path
~/.local/bin/microscopy-detect # Linux
```
### Issue: Import Errors for PySide6
**Error:**
```
ImportError: cannot import name 'QApplication' from 'PySide6.QtWidgets'
```
**Solution:**
```bash
# Install Qt dependencies (Linux only)
sudo apt-get install libxcb-xinerama0
# Reinstall PySide6
pip uninstall PySide6
pip install PySide6
```
### Issue: Config Files Not Found
**Error:**
```
FileNotFoundError: config/app_config.yaml
```
**Solution:**
The config file should be created automatically. If not:
```bash
# Create config directory in your home
mkdir -p ~/.microscopy-detect
cp config/app_config.yaml ~/.microscopy-detect/
# Or run from source directory first time
cd /home/martin/code/object_detection
python main.py
```
## Manual Testing Checklist
- [ ] Package installs without errors
- [ ] Version command works (`microscopy-detect --version`)
- [ ] Help command works (`microscopy-detect --help`)
- [ ] GUI launches (`microscopy-detect-gui`)
- [ ] Can import all modules in Python
- [ ] Database schema file is accessible
- [ ] Configuration loads correctly
## Build and Install from Wheel
```bash
# Build the package
python -m build
# Install from wheel
pip install dist/microscopy_object_detection-1.0.0-py3-none-any.whl
# Test
microscopy-detect --version
```
## Uninstall
```bash
pip uninstall microscopy-object-detection
```
## Development Workflow
### After Code Changes
If installed with `-e` (editable mode):
- Python code changes are immediately available
- No need to reinstall
If installed with regular `pip install .`:
- Reinstall after changes: `pip install . --force-reinstall`
### After Adding New Files
```bash
# Reinstall to include new files
pip install -e . --force-reinstall
```
## Expected Installation Output
```
Processing /home/martin/code/object_detection
Installing build dependencies ... done
Getting requirements to build wheel ... done
Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: microscopy-object-detection
Building wheel for microscopy-object-detection (pyproject.toml) ... done
Successfully built microscopy-object-detection
Installing collected packages: microscopy-object-detection
Successfully installed microscopy-object-detection-1.0.0
```
## Success Criteria
Installation is successful when:
1. ✅ No error messages during installation
2.`pip list` shows the package
3.`microscopy-detect --version` returns correct version
4. ✅ GUI launches without errors
5. ✅ All Python modules can be imported
6. ✅ Database operations work
7. ✅ Detection functionality works
## Next Steps After Successful Install
1. Configure image repository path
2. Run first detection
3. Train a custom model
4. Export results
For usage instructions, see [QUICKSTART.md](QUICKSTART.md)

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 Your Name
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

37
MANIFEST.in Normal file
View File

@@ -0,0 +1,37 @@
# Include documentation files
include README.md
include LICENSE
include ARCHITECTURE.md
include IMPLEMENTATION_GUIDE.md
include QUICKSTART.md
include PLAN_SUMMARY.md
# Include requirements
include requirements.txt
# Include configuration files
recursive-include config *.yaml
recursive-include config *.yml
# Include database schema
recursive-include src/database *.sql
# Include tests
recursive-include tests *.py
# Exclude compiled Python files
global-exclude *.pyc
global-exclude *.pyo
global-exclude __pycache__
global-exclude *.so
global-exclude .DS_Store
# Exclude git and IDE files
global-exclude .git*
global-exclude .vscode
global-exclude .idea
# Exclude build artifacts
prune build
prune dist
prune *.egg-info

View File

@@ -38,7 +38,7 @@ This will install:
- OpenCV and Pillow (image processing)
- And other dependencies
**Note:** The first run will automatically download the YOLOv8s.pt model (~22MB).
**Note:** The first run will automatically download the YOLOv8s-seg.pt segmentation model (~23MB).
### 4. Verify Installation
@@ -84,11 +84,11 @@ In the Settings dialog:
### Single Image Detection
1. Go to the **Detection** tab
2. Select a model from the dropdown (default: Base Model yolov8s.pt)
2. Select a model from the dropdown (default: Base Model yolov8s-seg.pt)
3. Adjust confidence threshold with the slider
4. Click "Detect Single Image"
5. Select an image file
6. View results in the results panel
6. View results with segmentation masks overlaid on the image
### Batch Detection
@@ -108,9 +108,18 @@ Detection results include:
- **Class names**: Types of objects detected (e.g., organelle, membrane_branch)
- **Confidence scores**: Detection confidence (0-1)
- **Bounding boxes**: Object locations (stored in database)
- **Segmentation masks**: Pixel-accurate polygon coordinates for each detected object
All results are stored in the SQLite database at [`data/detections.db`](data/detections.db).
### Segmentation Visualization
The application automatically displays segmentation masks when available:
- Semi-transparent colored overlay (30% opacity) showing the exact shape of detected objects
- Polygon contours outlining each segmentation
- Color-coded by object class
- Toggle-able in future versions
## Database
The application uses SQLite to store:
@@ -176,7 +185,7 @@ sudo apt-get install libxcb-xinerama0
### Detection Not Working
**No models available**
- The base YOLOv8s model will be downloaded automatically on first use
- The base YOLOv8s-seg segmentation model will be downloaded automatically on first use
- Make sure you have internet connection for the first run
**Images not found**

View File

@@ -1,6 +1,6 @@
# Microscopy Object Detection Application
A desktop application for detecting organelles and membrane branching structures in microscopy images using YOLOv8, featuring comprehensive training, validation, and visualization capabilities.
A desktop application for detecting and segmenting organelles and membrane branching structures in microscopy images using YOLOv8-seg, featuring comprehensive training, validation, and visualization capabilities with pixel-accurate segmentation masks.
![Python](https://img.shields.io/badge/python-3.8+-blue.svg)
![PySide6](https://img.shields.io/badge/PySide6-6.5+-green.svg)
@@ -8,8 +8,8 @@ A desktop application for detecting organelles and membrane branching structures
## Features
- **🎯 Object Detection**: Real-time and batch detection of microscopy objects
- **🎓 Model Training**: Fine-tune YOLOv8s on custom microscopy datasets
- **🎯 Object Detection & Segmentation**: Real-time and batch detection with pixel-accurate segmentation masks
- **🎓 Model Training**: Fine-tune YOLOv8s-seg on custom microscopy datasets
- **📊 Validation & Metrics**: Comprehensive model validation with visualization
- **💾 Database Storage**: SQLite database for detection results and metadata
- **📈 Visualization**: Interactive plots and charts using pyqtgraph
@@ -34,14 +34,24 @@ A desktop application for detecting organelles and membrane branching structures
## Installation
### 1. Clone the Repository
### Option 1: Install from PyPI (Recommended)
```bash
pip install microscopy-object-detection
```
This will install the package and all its dependencies.
### Option 2: Install from Source
#### 1. Clone the Repository
```bash
git clone <repository-url>
cd object_detection
```
### 2. Create Virtual Environment
#### 2. Create Virtual Environment
```bash
# Linux/Mac
@@ -53,25 +63,44 @@ python -m venv venv
venv\Scripts\activate
```
### 3. Install Dependencies
#### 3. Install in Development Mode
```bash
pip install -r requirements.txt
# Install in editable mode with dev dependencies
pip install -e ".[dev]"
# Or install just the package
pip install .
```
### 4. Download Base Model
The application will automatically download the YOLOv8s.pt model on first use, or you can download it manually:
The application will automatically download the YOLOv8s-seg.pt segmentation model on first use, or you can download it manually:
```bash
# The model will be downloaded automatically by ultralytics
# Or download manually from: https://github.com/ultralytics/assets/releases
```
**Note:** YOLOv8s-seg is a segmentation model that provides pixel-accurate masks for detected objects, enabling more precise analysis than standard bounding box detection.
## Quick Start
### 1. Launch the Application
After installation, you can launch the application in two ways:
**Using the GUI launcher:**
```bash
microscopy-detect-gui
```
**Or using Python directly:**
```bash
python -m microscopy_object_detection
```
**If installed from source:**
```bash
python main.py
```
@@ -85,11 +114,12 @@ python main.py
### 3. Perform Detection
1. Navigate to the **Detection** tab
2. Select a model (default: yolov8s.pt)
2. Select a model (default: yolov8s-seg.pt)
3. Choose an image or folder
4. Set confidence threshold
5. Click **Detect**
6. View results and save to database
6. View results with segmentation masks overlaid
7. Save results to database
### 4. Train Custom Model
@@ -212,8 +242,8 @@ The application uses SQLite with the following main tables:
- **models**: Stores trained model information and metrics
- **images**: Stores image metadata and paths
- **detections**: Stores detection results with bounding boxes
- **annotations**: Stores manual annotations (future feature)
- **detections**: Stores detection results with bounding boxes and segmentation masks (polygon coordinates)
- **annotations**: Stores manual annotations with optional segmentation masks (future feature)
See [`ARCHITECTURE.md`](ARCHITECTURE.md) for detailed schema information.
@@ -230,7 +260,7 @@ image_repository:
allowed_extensions: [".jpg", ".jpeg", ".png", ".tif", ".tiff"]
models:
default_base_model: "yolov8s.pt"
default_base_model: "yolov8s-seg.pt"
models_directory: "data/models"
training:
@@ -258,7 +288,7 @@ visualization:
from src.model.yolo_wrapper import YOLOWrapper
# Initialize wrapper
yolo = YOLOWrapper("yolov8s.pt")
yolo = YOLOWrapper("yolov8s-seg.pt")
# Train model
results = yolo.train(
@@ -393,10 +423,10 @@ make html
**Issue**: Model not found error
**Solution**: Ensure YOLOv8s.pt is downloaded. Run:
**Solution**: Ensure YOLOv8s-seg.pt is downloaded. Run:
```python
from ultralytics import YOLO
model = YOLO('yolov8s.pt') # Will auto-download
model = YOLO('yolov8s-seg.pt') # Will auto-download
```

View File

@@ -1,48 +0,0 @@
database:
path: "data/detections.db"
image_repository:
base_path: "" # Set by user through GUI
allowed_extensions:
- ".jpg"
- ".jpeg"
- ".png"
- ".tif"
- ".tiff"
- ".bmp"
models:
default_base_model: "yolov8s.pt"
models_directory: "data/models"
training:
default_epochs: 100
default_batch_size: 16
default_imgsz: 640
default_patience: 50
default_lr0: 0.01
detection:
default_confidence: 0.25
default_iou: 0.45
max_batch_size: 100
visualization:
bbox_colors:
organelle: "#FF6B6B"
membrane_branch: "#4ECDC4"
default: "#00FF00"
bbox_thickness: 2
font_size: 12
export:
formats:
- csv
- json
- excel
default_format: "csv"
logging:
level: "INFO"
file: "logs/app.log"
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

220
docs/IMAGE_CLASS_USAGE.md Normal file
View File

@@ -0,0 +1,220 @@
# Image Class Usage Guide
The `Image` class provides a convenient way to load and work with images in the microscopy object detection application.
## Supported Formats
The Image class supports the following image formats:
- `.jpg`, `.jpeg` - JPEG images
- `.png` - PNG images
- `.tif`, `.tiff` - TIFF images (commonly used in microscopy)
- `.bmp` - Bitmap images
## Basic Usage
### Loading an Image
```python
from src.utils import Image, ImageLoadError
# Load an image from a file path
try:
img = Image("path/to/image.jpg")
print(f"Loaded image: {img.width}x{img.height} pixels")
except ImageLoadError as e:
print(f"Failed to load image: {e}")
```
### Accessing Image Properties
```python
img = Image("microscopy_image.tif")
# Basic properties
print(f"Width: {img.width} pixels")
print(f"Height: {img.height} pixels")
print(f"Channels: {img.channels}")
print(f"Format: {img.format}")
print(f"Shape: {img.shape}") # (height, width, channels)
# File information
print(f"File size: {img.size_mb:.2f} MB")
print(f"File size: {img.size_bytes} bytes")
# Image type checks
print(f"Is color: {img.is_color()}")
print(f"Is grayscale: {img.is_grayscale()}")
# String representation
print(img) # Shows summary of image properties
```
### Working with Image Data
```python
import numpy as np
img = Image("sample.png")
# Get image data as numpy array (OpenCV format, BGR)
bgr_data = img.data
print(f"Data shape: {bgr_data.shape}")
print(f"Data type: {bgr_data.dtype}")
# Get image as RGB (for display or processing)
rgb_data = img.get_rgb()
# Get grayscale version
gray_data = img.get_grayscale()
# Create a copy (for modifications)
img_copy = img.copy()
img_copy[0, 0] = [255, 255, 255] # Modify copy, original unchanged
# Resize image (returns new array, doesn't modify original)
resized = img.resize(640, 640)
```
### Using PIL Image
```python
img = Image("photo.jpg")
# Access as PIL Image (RGB format)
pil_img = img.pil_image
# Use PIL methods
pil_img.show() # Display image
pil_img.save("output.png") # Save with PIL
```
## Integration with YOLO
```python
from src.utils import Image
from ultralytics import YOLO
# Load model and image
model = YOLO("yolov8n.pt")
img = Image("microscopy/cell_01.tif")
# Run inference (YOLO accepts file paths or numpy arrays)
results = model(img.data)
# Or use the file path directly
results = model(str(img.path))
```
## Error Handling
```python
from src.utils import Image, ImageLoadError
def process_image(image_path):
try:
img = Image(image_path)
# Process the image...
return img
except ImageLoadError as e:
print(f"Cannot load image: {e}")
return None
```
## Advanced Usage
### Batch Processing
```python
from pathlib import Path
from src.utils import Image, ImageLoadError
def process_image_directory(directory):
"""Process all images in a directory."""
image_paths = Path(directory).glob("*.tif")
for path in image_paths:
try:
img = Image(path)
print(f"Processing {img.path.name}: {img.width}x{img.height}")
# Process the image...
except ImageLoadError as e:
print(f"Skipping {path}: {e}")
```
### Using with OpenCV Operations
```python
import cv2
from src.utils import Image
img = Image("input.jpg")
# Apply OpenCV operations on the data
blurred = cv2.GaussianBlur(img.data, (5, 5), 0)
edges = cv2.Canny(img.data, 100, 200)
# Note: These operations don't modify the original img.data
```
### Memory Efficient Processing
```python
from src.utils import Image
# The Image class loads data into memory
img = Image("large_image.tif")
print(f"Image size in memory: {img.data.nbytes / (1024**2):.2f} MB")
# When processing many images, consider loading one at a time
# and releasing memory by deleting the object
del img
```
## Best Practices
1. **Always use try-except** when loading images to handle errors gracefully
2. **Check image properties** before processing to ensure compatibility
3. **Use copy()** when you need to modify image data without affecting the original
4. **Path objects work too** - The class accepts both strings and Path objects
5. **Consider memory usage** when working with large images or batches
## Example: Complete Workflow
```python
from src.utils import Image, ImageLoadError
from src.utils.file_utils import get_image_files
def analyze_microscopy_images(directory):
"""Analyze all microscopy images in a directory."""
# Get all image files
image_files = get_image_files(directory, recursive=True)
results = []
for image_path in image_files:
try:
# Load image
img = Image(image_path)
# Analyze
result = {
'filename': img.path.name,
'width': img.width,
'height': img.height,
'channels': img.channels,
'format': img.format,
'size_mb': img.size_mb,
'is_color': img.is_color()
}
results.append(result)
print(f"✓ Analyzed {img.path.name}")
except ImageLoadError as e:
print(f"✗ Failed to load {image_path}: {e}")
return results
# Run analysis
results = analyze_microscopy_images("data/datasets/cells")
print(f"\nProcessed {len(results)} images")

151
examples/image_demo.py Normal file
View File

@@ -0,0 +1,151 @@
"""
Example script demonstrating the Image class functionality.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.utils import Image, ImageLoadError
def demonstrate_image_loading():
"""Demonstrate basic image loading functionality."""
print("=" * 60)
print("Image Class Demonstration")
print("=" * 60)
# Example 1: Try to load an image (replace with your own path)
example_paths = [
"data/datasets/example.jpg",
"data/datasets/sample.png",
"tests/test_image.jpg",
]
loaded_img = None
for image_path in example_paths:
if Path(image_path).exists():
try:
print(f"\n1. Loading image: {image_path}")
img = Image(image_path)
loaded_img = img
print(f" ✓ Successfully loaded!")
print(f" {img}")
break
except ImageLoadError as e:
print(f" ✗ Failed: {e}")
else:
print(f"\n1. Image not found: {image_path}")
if loaded_img is None:
print("\nNo example images found. Creating a test image...")
create_test_image()
return
# Example 2: Access image properties
print(f"\n2. Image Properties:")
print(f" Width: {loaded_img.width} pixels")
print(f" Height: {loaded_img.height} pixels")
print(f" Channels: {loaded_img.channels}")
print(f" Format: {loaded_img.format.upper()}")
print(f" Shape: {loaded_img.shape}")
print(f" File size: {loaded_img.size_mb:.2f} MB")
print(f" Is color: {loaded_img.is_color()}")
print(f" Is grayscale: {loaded_img.is_grayscale()}")
# Example 3: Get different formats
print(f"\n3. Accessing Image Data:")
print(f" BGR data shape: {loaded_img.data.shape}")
print(f" RGB data shape: {loaded_img.get_rgb().shape}")
print(f" Grayscale shape: {loaded_img.get_grayscale().shape}")
print(f" PIL image mode: {loaded_img.pil_image.mode}")
# Example 4: Resizing
print(f"\n4. Resizing Image:")
resized = loaded_img.resize(320, 320)
print(f" Original size: {loaded_img.width}x{loaded_img.height}")
print(f" Resized to: {resized.shape[1]}x{resized.shape[0]}")
# Example 5: Working with copies
print(f"\n5. Creating Copies:")
copy = loaded_img.copy()
print(f" Created copy with shape: {copy.shape}")
print(f" Original data unchanged: {(loaded_img.data == copy).all()}")
print("\n" + "=" * 60)
print("Demonstration Complete!")
print("=" * 60)
def create_test_image():
"""Create a test image for demonstration purposes."""
import cv2
import numpy as np
print("\nCreating a test image...")
# Create a colorful test image
width, height = 400, 300
test_img = np.zeros((height, width, 3), dtype=np.uint8)
# Add some colors
test_img[:100, :] = [255, 0, 0] # Blue section
test_img[100:200, :] = [0, 255, 0] # Green section
test_img[200:, :] = [0, 0, 255] # Red section
# Save the test image
test_path = Path("test_demo_image.png")
cv2.imwrite(str(test_path), test_img)
print(f"Test image created: {test_path}")
# Now load and demonstrate with it
try:
img = Image(test_path)
print(f"\nLoaded test image: {img}")
print(f"Dimensions: {img.width}x{img.height}")
print(f"Channels: {img.channels}")
print(f"Format: {img.format}")
# Clean up
test_path.unlink()
print(f"\nTest image cleaned up.")
except ImageLoadError as e:
print(f"Error loading test image: {e}")
def demonstrate_error_handling():
"""Demonstrate error handling."""
print("\n" + "=" * 60)
print("Error Handling Demonstration")
print("=" * 60)
# Try to load non-existent file
print("\n1. Loading non-existent file:")
try:
img = Image("nonexistent.jpg")
except ImageLoadError as e:
print(f" ✓ Caught error: {e}")
# Try unsupported format
print("\n2. Loading unsupported format:")
try:
# Create a text file
test_file = Path("test.txt")
test_file.write_text("not an image")
img = Image(test_file)
except ImageLoadError as e:
print(f" ✓ Caught error: {e}")
test_file.unlink() # Clean up
print("\n" + "=" * 60)
if __name__ == "__main__":
print("\n")
demonstrate_image_loading()
print("\n")
demonstrate_error_handling()
print("\n")

View File

@@ -6,12 +6,13 @@ Main entry point for the application.
import sys
from pathlib import Path
# Add src directory to path
# Add src directory to path for development mode
sys.path.insert(0, str(Path(__file__).parent))
from PySide6.QtWidgets import QApplication
from PySide6.QtCore import Qt
from src import __version__
from src.gui.main_window import MainWindow
from src.utils.logger import setup_logging
from src.utils.config_manager import ConfigManager
@@ -37,7 +38,7 @@ def main():
app = QApplication(sys.argv)
app.setApplicationName("Microscopy Object Detection")
app.setOrganizationName("MicroscopyLab")
app.setApplicationVersion("1.0.0")
app.setApplicationVersion(__version__)
# Set application style
app.setStyle("Fusion")

102
pyproject.toml Normal file
View File

@@ -0,0 +1,102 @@
[build-system]
requires = ["setuptools>=45", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "microscopy-object-detection"
version = "1.0.0"
description = "Desktop application for detecting and segmenting organelles in microscopy images using YOLOv8-seg"
readme = "README.md"
requires-python = ">=3.8"
license = { text = "MIT" }
authors = [{ name = "Your Name", email = "your.email@example.com" }]
keywords = [
"microscopy",
"yolov8",
"object-detection",
"segmentation",
"computer-vision",
"deep-learning",
]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering :: Image Recognition",
"Topic :: Scientific/Engineering :: Bio-Informatics",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Operating System :: OS Independent",
]
dependencies = [
"ultralytics>=8.0.0",
"PySide6>=6.5.0",
"pyqtgraph>=0.13.0",
"numpy>=1.24.0",
"opencv-python>=4.8.0",
"Pillow>=10.0.0",
"PyYAML>=6.0",
"pandas>=2.0.0",
"openpyxl>=3.1.0",
]
[project.optional-dependencies]
dev = [
"pytest>=7.0.0",
"pytest-cov>=4.0.0",
"black>=23.0.0",
"pylint>=2.17.0",
"mypy>=1.0.0",
]
[project.urls]
Homepage = "https://github.com/yourusername/object_detection"
Documentation = "https://github.com/yourusername/object_detection/blob/main/README.md"
Repository = "https://github.com/yourusername/object_detection"
"Bug Tracker" = "https://github.com/yourusername/object_detection/issues"
[project.scripts]
microscopy-detect = "src.cli:main"
[project.gui-scripts]
microscopy-detect-gui = "src.gui_launcher:main"
[tool.setuptools]
packages = [
"src",
"src.database",
"src.model",
"src.gui",
"src.gui.tabs",
"src.gui.dialogs",
"src.gui.widgets",
"src.utils",
]
include-package-data = true
[tool.setuptools.package-data]
"src.database" = ["*.sql"]
[tool.black]
line-length = 120
target-version = ['py38', 'py39', 'py310', 'py311']
include = '\.pyi?$'
[tool.pylint.messages_control]
max-line-length = 120
[tool.mypy]
python_version = "3.8"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = false
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
python_functions = ["test_*"]
addopts = "-v --cov=src --cov-report=term-missing"

56
setup.py Normal file
View File

@@ -0,0 +1,56 @@
"""Setup script for Microscopy Object Detection Application."""
from setuptools import setup, find_packages
from pathlib import Path
# Read the contents of README file
this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text(encoding="utf-8")
# Read requirements
requirements = (this_directory / "requirements.txt").read_text().splitlines()
requirements = [
req.strip() for req in requirements if req.strip() and not req.startswith("#")
]
setup(
name="microscopy-object-detection",
version="1.0.0",
author="Your Name",
author_email="your.email@example.com",
description="Desktop application for detecting and segmenting organelles in microscopy images using YOLOv8-seg",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/yourusername/object_detection",
packages=find_packages(exclude=["tests", "tests.*", "docs"]),
include_package_data=True,
install_requires=requirements,
python_requires=">=3.8",
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering :: Image Recognition",
"Topic :: Scientific/Engineering :: Bio-Informatics",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Operating System :: OS Independent",
],
entry_points={
"console_scripts": [
"microscopy-detect=src.cli:main",
],
"gui_scripts": [
"microscopy-detect-gui=src.gui_launcher:main",
],
},
keywords="microscopy yolov8 object-detection segmentation computer-vision deep-learning",
project_urls={
"Bug Reports": "https://github.com/yourusername/object_detection/issues",
"Source": "https://github.com/yourusername/object_detection",
"Documentation": "https://github.com/yourusername/object_detection/blob/main/README.md",
},
)

View File

@@ -0,0 +1,19 @@
"""
Microscopy Object Detection Application
A desktop application for detecting and segmenting organelles and membrane
branching structures in microscopy images using YOLOv8-seg.
"""
__version__ = "1.0.0"
__author__ = "Your Name"
__email__ = "your.email@example.com"
__license__ = "MIT"
# Package metadata
__all__ = [
"__version__",
"__author__",
"__email__",
"__license__",
]

61
src/cli.py Normal file
View File

@@ -0,0 +1,61 @@
"""
Command-line interface for microscopy object detection application.
"""
import sys
import argparse
from pathlib import Path
from src import __version__
def main():
"""Main CLI entry point."""
parser = argparse.ArgumentParser(
description="Microscopy Object Detection Application - CLI Interface",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Launch GUI
microscopy-detect-gui
# Show version
microscopy-detect --version
# Get help
microscopy-detect --help
""",
)
parser.add_argument(
"--version",
action="version",
version=f"microscopy-object-detection {__version__}",
)
parser.add_argument(
"--gui",
action="store_true",
help="Launch the GUI application (same as microscopy-detect-gui)",
)
args = parser.parse_args()
if args.gui:
# Launch GUI
try:
from src.gui_launcher import main as gui_main
gui_main()
except Exception as e:
print(f"Error launching GUI: {e}", file=sys.stderr)
sys.exit(1)
else:
# Show help if no arguments provided
parser.print_help()
print("\nTo launch the GUI, use: microscopy-detect-gui")
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -6,10 +6,18 @@ Handles all database operations including CRUD operations, queries, and exports.
import sqlite3
import json
from datetime import datetime
from typing import List, Dict, Optional, Tuple, Any
from typing import List, Dict, Optional, Tuple, Any, Union
from pathlib import Path
import csv
import hashlib
import yaml
from src.utils.logger import get_logger
from src.utils.image import Image
IMAGE_EXTENSIONS = tuple(Image.SUPPORTED_EXTENSIONS)
logger = get_logger(__name__)
class DatabaseManager:
@@ -30,18 +38,46 @@ class DatabaseManager:
# Create directory if it doesn't exist
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
# Read schema file and execute
schema_path = Path(__file__).parent / "schema.sql"
with open(schema_path, "r") as f:
schema_sql = f.read()
conn = self.get_connection()
try:
# Check if annotations table needs migration
self._migrate_annotations_table(conn)
# Read schema file and execute
schema_path = Path(__file__).parent / "schema.sql"
with open(schema_path, "r") as f:
schema_sql = f.read()
conn.executescript(schema_sql)
conn.commit()
finally:
conn.close()
def _migrate_annotations_table(self, conn: sqlite3.Connection) -> None:
"""
Migrate annotations table from old schema (class_name) to new schema (class_id).
"""
cursor = conn.cursor()
# Check if annotations table exists
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'")
if not cursor.fetchone():
# Table doesn't exist yet, no migration needed
return
# Check if table has old schema (class_name column)
cursor.execute("PRAGMA table_info(annotations)")
columns = {row[1]: row for row in cursor.fetchall()}
if "class_name" in columns and "class_id" not in columns:
# Old schema detected, need to migrate
print("Migrating annotations table to new schema with class_id...")
# Drop old annotations table (assuming no critical data since this is a new feature)
cursor.execute("DROP TABLE IF EXISTS annotations")
conn.commit()
print("Old annotations table dropped, will be recreated with new schema")
def get_connection(self) -> sqlite3.Connection:
"""Get database connection with proper settings."""
conn = sqlite3.connect(self.db_path)
@@ -56,7 +92,7 @@ class DatabaseManager:
model_name: str,
model_version: str,
model_path: str,
base_model: str = "yolov8s.pt",
base_model: str = "yolov8s-seg.pt",
training_params: Optional[Dict] = None,
metrics: Optional[Dict] = None,
) -> int:
@@ -165,6 +201,28 @@ class DatabaseManager:
finally:
conn.close()
def delete_model(self, model_id: int) -> bool:
"""Delete a model from the database.
Note: detections referencing this model are deleted automatically via
the `detections.model_id` foreign key (ON DELETE CASCADE).
Args:
model_id: ID of the model to delete.
Returns:
True if a model row was deleted, False otherwise.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM models WHERE id = ?", (model_id,))
conn.commit()
return cursor.rowcount > 0
finally:
conn.close()
# ==================== Image Operations ====================
def add_image(
@@ -204,9 +262,7 @@ class DatabaseManager:
return cursor.lastrowid
except sqlite3.IntegrityError:
# Image already exists, return its ID
cursor.execute(
"SELECT id FROM images WHERE relative_path = ?", (relative_path,)
)
cursor.execute("SELECT id FROM images WHERE relative_path = ?", (relative_path,))
row = cursor.fetchone()
return row["id"] if row else None
finally:
@@ -217,17 +273,13 @@ class DatabaseManager:
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"SELECT * FROM images WHERE relative_path = ?", (relative_path,)
)
cursor.execute("SELECT * FROM images WHERE relative_path = ?", (relative_path,))
row = cursor.fetchone()
return dict(row) if row else None
finally:
conn.close()
def get_or_create_image(
self, relative_path: str, filename: str, width: int, height: int
) -> int:
def get_or_create_image(self, relative_path: str, filename: str, width: int, height: int) -> int:
"""Get existing image or create new one."""
existing = self.get_image_by_path(relative_path)
if existing:
@@ -243,6 +295,7 @@ class DatabaseManager:
class_name: str,
bbox: Tuple[float, float, float, float], # (x_min, y_min, x_max, y_max)
confidence: float,
segmentation_mask: Optional[List[List[float]]] = None,
metadata: Optional[Dict] = None,
) -> int:
"""
@@ -254,6 +307,7 @@ class DatabaseManager:
class_name: Detected object class
bbox: Bounding box coordinates (normalized 0-1)
confidence: Detection confidence score
segmentation_mask: Polygon coordinates for segmentation [[x1,y1], [x2,y2], ...]
metadata: Additional metadata
Returns:
@@ -265,8 +319,8 @@ class DatabaseManager:
x_min, y_min, x_max, y_max = bbox
cursor.execute(
"""
INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, segmentation_mask, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
image_id,
@@ -277,6 +331,7 @@ class DatabaseManager:
x_max,
y_max,
confidence,
json.dumps(segmentation_mask) if segmentation_mask else None,
json.dumps(metadata) if metadata else None,
),
)
@@ -302,8 +357,8 @@ class DatabaseManager:
bbox = det["bbox"]
cursor.execute(
"""
INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, segmentation_mask, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
det["image_id"],
@@ -314,11 +369,8 @@ class DatabaseManager:
bbox[2],
bbox[3],
det["confidence"],
(
json.dumps(det.get("metadata"))
if det.get("metadata")
else None
),
(json.dumps(det.get("segmentation_mask")) if det.get("segmentation_mask") else None),
(json.dumps(det.get("metadata")) if det.get("metadata") else None),
),
)
conn.commit()
@@ -363,15 +415,16 @@ class DatabaseManager:
if filters:
conditions = []
for key, value in filters.items():
if (
key.startswith("d.")
or key.startswith("i.")
or key.startswith("m.")
):
conditions.append(f"{key} = ?")
if key.startswith("d.") or key.startswith("i.") or key.startswith("m."):
if "like" in value.lower():
conditions.append(f"{key} LIKE ?")
params.append(value.split(" ")[1])
else:
conditions.append(f"{key} = ?")
params.append(value)
else:
conditions.append(f"d.{key} = ?")
params.append(value)
params.append(value)
query += " WHERE " + " AND ".join(conditions)
query += " ORDER BY d.detected_at DESC"
@@ -385,24 +438,41 @@ class DatabaseManager:
detections = []
for row in cursor.fetchall():
det = dict(row)
# Parse JSON metadata
# Parse JSON fields
if det.get("metadata"):
det["metadata"] = json.loads(det["metadata"])
if det.get("segmentation_mask"):
det["segmentation_mask"] = json.loads(det["segmentation_mask"])
detections.append(det)
return detections
finally:
conn.close()
def get_detections_for_image(
self, image_id: int, model_id: Optional[int] = None
) -> List[Dict]:
def get_detections_for_image(self, image_id: int, model_id: Optional[int] = None) -> List[Dict]:
"""Get all detections for a specific image."""
filters = {"image_id": image_id}
if model_id:
filters["model_id"] = model_id
return self.get_detections(filters)
def delete_detections_for_image(self, image_id: int, model_id: Optional[int] = None) -> int:
"""Delete detections tied to a specific image and optional model."""
conn = self.get_connection()
try:
cursor = conn.cursor()
if model_id is not None:
cursor.execute(
"DELETE FROM detections WHERE image_id = ? AND model_id = ?",
(image_id, model_id),
)
else:
cursor.execute("DELETE FROM detections WHERE image_id = ?", (image_id,))
conn.commit()
return cursor.rowcount
finally:
conn.close()
def delete_detections_for_model(self, model_id: int) -> int:
"""Delete all detections for a specific model."""
conn = self.get_connection()
@@ -414,6 +484,22 @@ class DatabaseManager:
finally:
conn.close()
def delete_all_detections(self) -> int:
"""Delete all detections from the database.
Returns:
Number of rows deleted.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM detections")
conn.commit()
return cursor.rowcount
finally:
conn.close()
# ==================== Statistics Operations ====================
def get_detection_statistics(
@@ -457,9 +543,7 @@ class DatabaseManager:
""",
params,
)
class_counts = {
row["class_name"]: row["count"] for row in cursor.fetchall()
}
class_counts = {row["class_name"]: row["count"] for row in cursor.fetchall()}
# Average confidence
cursor.execute(
@@ -516,9 +600,7 @@ class DatabaseManager:
# ==================== Export Operations ====================
def export_detections_to_csv(
self, output_path: str, filters: Optional[Dict] = None
) -> bool:
def export_detections_to_csv(self, output_path: str, filters: Optional[Dict] = None) -> bool:
"""Export detections to CSV file."""
try:
detections = self.get_detections(filters)
@@ -538,6 +620,7 @@ class DatabaseManager:
"x_max",
"y_max",
"confidence",
"segmentation_mask",
"detected_at",
]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
@@ -545,6 +628,9 @@ class DatabaseManager:
for det in detections:
row = {k: det[k] for k in fieldnames if k in det}
# Convert segmentation mask list to JSON string for CSV
if row.get("segmentation_mask") and isinstance(row["segmentation_mask"], list):
row["segmentation_mask"] = json.dumps(row["segmentation_mask"])
writer.writerow(row)
return True
@@ -552,9 +638,7 @@ class DatabaseManager:
print(f"Error exporting to CSV: {e}")
return False
def export_detections_to_json(
self, output_path: str, filters: Optional[Dict] = None
) -> bool:
def export_detections_to_json(self, output_path: str, filters: Optional[Dict] = None) -> bool:
"""Export detections to JSON file."""
try:
detections = self.get_detections(filters)
@@ -574,25 +658,118 @@ class DatabaseManager:
# ==================== Annotation Operations ====================
def get_annotated_images_summary(
self,
name_filter: Optional[str] = None,
order_by: str = "filename",
order_dir: str = "ASC",
limit: Optional[int] = None,
offset: int = 0,
) -> List[Dict]:
"""Return images that have at least one manual annotation.
Args:
name_filter: Optional substring filter applied to filename/relative_path.
order_by: One of: 'filename', 'relative_path', 'annotation_count', 'added_at'.
order_dir: 'ASC' or 'DESC'.
limit: Optional max number of rows.
offset: Pagination offset.
Returns:
List of dicts: {id, relative_path, filename, added_at, annotation_count}
"""
allowed_order_by = {
"filename": "i.filename",
"relative_path": "i.relative_path",
"annotation_count": "annotation_count",
"added_at": "i.added_at",
}
order_expr = allowed_order_by.get(order_by, "i.filename")
dir_norm = str(order_dir).upper().strip()
if dir_norm not in {"ASC", "DESC"}:
dir_norm = "ASC"
conn = self.get_connection()
try:
params: List[Any] = []
where_sql = ""
if name_filter:
# Case-insensitive substring search.
token = f"%{name_filter}%"
where_sql = "WHERE (i.filename LIKE ? OR i.relative_path LIKE ?)"
params.extend([token, token])
limit_sql = ""
if limit is not None:
limit_sql = " LIMIT ? OFFSET ?"
params.extend([int(limit), int(offset)])
query = f"""
SELECT
i.id,
i.relative_path,
i.filename,
i.added_at,
COUNT(a.id) AS annotation_count
FROM images i
JOIN annotations a ON a.image_id = i.id
{where_sql}
GROUP BY i.id
HAVING annotation_count > 0
ORDER BY {order_expr} {dir_norm}
{limit_sql}
"""
cursor = conn.cursor()
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
finally:
conn.close()
def add_annotation(
self,
image_id: int,
class_name: str,
class_id: int,
bbox: Tuple[float, float, float, float],
annotator: str,
segmentation_mask: Optional[List[List[float]]] = None,
verified: bool = False,
) -> int:
"""Add manual annotation."""
"""
Add manual annotation.
Args:
image_id: ID of the image
class_id: ID of the object class (foreign key to object_classes)
bbox: Bounding box coordinates (normalized 0-1)
annotator: Name of person/tool creating annotation
segmentation_mask: Polygon coordinates for segmentation
verified: Whether annotation has been verified
Returns:
ID of the inserted annotation
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
x_min, y_min, x_max, y_max = bbox
cursor.execute(
"""
INSERT INTO annotations (image_id, class_name, x_min, y_min, x_max, y_max, annotator, verified)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO annotations (image_id, class_id, x_min, y_min, x_max, y_max, segmentation_mask, annotator, verified)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(image_id, class_name, x_min, y_min, x_max, y_max, annotator, verified),
(
image_id,
class_id,
x_min,
y_min,
x_max,
y_max,
json.dumps(segmentation_mask) if segmentation_mask else None,
annotator,
verified,
),
)
conn.commit()
return cursor.lastrowid
@@ -600,15 +777,363 @@ class DatabaseManager:
conn.close()
def get_annotations_for_image(self, image_id: int) -> List[Dict]:
"""Get all annotations for an image."""
"""
Get all annotations for an image with class information.
Args:
image_id: ID of the image
Returns:
List of annotation dictionaries with joined class information
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM annotations WHERE image_id = ?", (image_id,))
cursor.execute(
"""
SELECT
a.*,
c.class_name,
c.color as class_color,
c.description as class_description
FROM annotations a
JOIN object_classes c ON a.class_id = c.id
WHERE a.image_id = ?
ORDER BY a.created_at DESC
""",
(image_id,),
)
annotations = []
for row in cursor.fetchall():
ann = dict(row)
if ann.get("segmentation_mask"):
ann["segmentation_mask"] = json.loads(ann["segmentation_mask"])
annotations.append(ann)
return annotations
finally:
conn.close()
def delete_annotation(self, annotation_id: int) -> bool:
"""
Delete a manual annotation by ID.
Args:
annotation_id: ID of the annotation to delete
Returns:
True if an annotation was deleted, False otherwise.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM annotations WHERE id = ?", (annotation_id,))
conn.commit()
return cursor.rowcount > 0
finally:
conn.close()
# ==================== Object Class Operations ====================
def get_object_classes(self) -> List[Dict]:
"""
Get all object classes.
Returns:
List of object class dictionaries
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM object_classes ORDER BY class_name")
return [dict(row) for row in cursor.fetchall()]
finally:
conn.close()
def get_object_class_by_id(self, class_id: int) -> Optional[Dict]:
"""Get object class by ID."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM object_classes WHERE id = ?", (class_id,))
row = cursor.fetchone()
return dict(row) if row else None
finally:
conn.close()
def get_object_class_by_name(self, class_name: str) -> Optional[Dict]:
"""Get object class by name."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM object_classes WHERE class_name = ?", (class_name,))
row = cursor.fetchone()
return dict(row) if row else None
finally:
conn.close()
def add_object_class(self, class_name: str, color: str, description: Optional[str] = None) -> int:
"""
Add a new object class.
Args:
class_name: Name of the object class
color: Hex color code (e.g., '#FF0000')
description: Optional description
Returns:
ID of the inserted object class
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO object_classes (class_name, color, description)
VALUES (?, ?, ?)
""",
(class_name, color, description),
)
conn.commit()
return cursor.lastrowid
except sqlite3.IntegrityError:
# Class already exists
existing = self.get_object_class_by_name(class_name)
return existing["id"] if existing else None
finally:
conn.close()
def update_object_class(
self,
class_id: int,
class_name: Optional[str] = None,
color: Optional[str] = None,
description: Optional[str] = None,
) -> bool:
"""
Update an object class.
Args:
class_id: ID of the class to update
class_name: New class name (optional)
color: New color (optional)
description: New description (optional)
Returns:
True if updated, False otherwise
"""
conn = self.get_connection()
try:
updates = {}
if class_name is not None:
updates["class_name"] = class_name
if color is not None:
updates["color"] = color
if description is not None:
updates["description"] = description
if not updates:
return False
set_clauses = [f"{key} = ?" for key in updates.keys()]
params = list(updates.values()) + [class_id]
query = f"UPDATE object_classes SET {', '.join(set_clauses)} WHERE id = ?"
cursor = conn.cursor()
cursor.execute(query, params)
conn.commit()
return cursor.rowcount > 0
finally:
conn.close()
def delete_object_class(self, class_id: int) -> bool:
"""
Delete an object class.
Args:
class_id: ID of the class to delete
Returns:
True if deleted, False otherwise
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM object_classes WHERE id = ?", (class_id,))
conn.commit()
return cursor.rowcount > 0
finally:
conn.close()
# ==================== Dataset Utilities ====================
def compose_data_yaml(
self,
dataset_root: str,
output_path: Optional[str] = None,
splits: Optional[Dict[str, str]] = None,
) -> str:
"""
Compose a YOLO data.yaml file based on dataset folders and database metadata.
Args:
dataset_root: Base directory containing the dataset structure.
output_path: Optional output path; defaults to <dataset_root>/data.yaml.
splits: Optional mapping overriding train/val/test image directories (relative
to dataset_root or absolute paths).
Returns:
Path to the generated YAML file.
"""
dataset_root_path = Path(dataset_root).expanduser()
if not dataset_root_path.exists():
raise ValueError(f"Dataset root does not exist: {dataset_root_path}")
dataset_root_path = dataset_root_path.resolve()
split_map: Dict[str, str] = {key: "" for key in ("train", "val", "test")}
if splits:
for key, value in splits.items():
if key in split_map and value:
split_map[key] = value
inferred = self._infer_split_dirs(dataset_root_path)
for key in split_map:
if not split_map[key]:
split_map[key] = inferred.get(key, "")
for required in ("train", "val"):
if not split_map[required]:
raise ValueError(
"Unable to determine %s image directory under %s. Provide it "
"explicitly via the 'splits' argument." % (required, dataset_root_path)
)
yaml_splits: Dict[str, str] = {}
for key, value in split_map.items():
if not value:
continue
yaml_splits[key] = self._normalize_split_value(value, dataset_root_path)
class_names = self._fetch_annotation_class_names()
if not class_names:
class_names = [cls["class_name"] for cls in self.get_object_classes()]
if not class_names:
raise ValueError("No object classes available to populate data.yaml")
names_map = {idx: name for idx, name in enumerate(class_names)}
payload: Dict[str, Any] = {
"path": dataset_root_path.as_posix(),
"train": yaml_splits["train"],
"val": yaml_splits["val"],
"names": names_map,
"nc": len(class_names),
}
if yaml_splits.get("test"):
payload["test"] = yaml_splits["test"]
output_path_obj = Path(output_path).expanduser() if output_path else dataset_root_path / "data.yaml"
output_path_obj.parent.mkdir(parents=True, exist_ok=True)
with open(output_path_obj, "w", encoding="utf-8") as handle:
yaml.safe_dump(payload, handle, sort_keys=False)
logger.info(f"Generated data.yaml at {output_path_obj}")
return output_path_obj.as_posix()
def _fetch_annotation_class_names(self) -> List[str]:
"""Return class names referenced by annotations (ordered by class ID)."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"""
SELECT DISTINCT c.id, c.class_name
FROM annotations a
JOIN object_classes c ON a.class_id = c.id
ORDER BY c.id
"""
)
rows = cursor.fetchall()
return [row["class_name"] for row in rows]
finally:
conn.close()
def _infer_split_dirs(self, dataset_root: Path) -> Dict[str, str]:
"""Infer train/val/test image directories relative to dataset_root."""
patterns = {
"train": [
"train/images",
"training/images",
"images/train",
"images/training",
"train",
"training",
],
"val": [
"val/images",
"validation/images",
"images/val",
"images/validation",
"val",
"validation",
],
"test": [
"test/images",
"testing/images",
"images/test",
"images/testing",
"test",
"testing",
],
}
inferred: Dict[str, str] = {key: "" for key in patterns}
for split_name, options in patterns.items():
for relative in options:
candidate = (dataset_root / relative).resolve()
if candidate.exists() and candidate.is_dir() and self._directory_has_images(candidate):
try:
inferred[split_name] = candidate.relative_to(dataset_root).as_posix()
except ValueError:
inferred[split_name] = candidate.as_posix()
break
return inferred
def _normalize_split_value(self, split_value: str, dataset_root: Path) -> str:
"""Validate and normalize a split directory to a YAML-friendly string."""
split_path = Path(split_value).expanduser()
if not split_path.is_absolute():
split_path = (dataset_root / split_path).resolve()
else:
split_path = split_path.resolve()
if not split_path.exists() or not split_path.is_dir():
raise ValueError(f"Split directory not found: {split_path}")
if not self._directory_has_images(split_path):
raise ValueError(f"No images found under {split_path}")
try:
return split_path.relative_to(dataset_root).as_posix()
except ValueError:
return split_path.as_posix()
@staticmethod
def _directory_has_images(directory: Path, max_checks: int = 2000) -> bool:
"""Return True if directory tree contains at least one image file."""
checked = 0
try:
for file_path in directory.rglob("*"):
if not file_path.is_file():
continue
if file_path.suffix.lower() in IMAGE_EXTENSIONS:
return True
checked += 1
if checked >= max_checks:
break
except Exception:
return False
return False
@staticmethod
def calculate_checksum(file_path: str) -> str:
"""Calculate MD5 checksum of a file."""

View File

@@ -5,7 +5,7 @@ These dataclasses represent the database entities.
from dataclasses import dataclass
from datetime import datetime
from typing import Optional, Dict, Tuple
from typing import Optional, Dict, Tuple, List
@dataclass
@@ -46,6 +46,9 @@ class Detection:
class_name: str
bbox: Tuple[float, float, float, float] # (x_min, y_min, x_max, y_max)
confidence: float
segmentation_mask: Optional[
List[List[float]]
] # List of polygon coordinates [[x1,y1], [x2,y2], ...]
detected_at: datetime
metadata: Optional[Dict]
@@ -58,6 +61,9 @@ class Annotation:
image_id: int
class_name: str
bbox: Tuple[float, float, float, float] # (x_min, y_min, x_max, y_max)
segmentation_mask: Optional[
List[List[float]]
] # List of polygon coordinates [[x1,y1], [x2,y2], ...]
annotator: str
created_at: datetime
verified: bool

View File

@@ -37,25 +37,41 @@ CREATE TABLE IF NOT EXISTS detections (
x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1),
y_max REAL NOT NULL CHECK(y_max >= 0 AND y_max <= 1),
confidence REAL NOT NULL CHECK(confidence >= 0 AND confidence <= 1),
segmentation_mask TEXT, -- JSON string of polygon coordinates [[x1,y1], [x2,y2], ...]
detected_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
metadata TEXT, -- JSON string for additional metadata
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
);
-- Annotations table: stores manual annotations (future feature)
-- Object classes table: stores annotation class definitions with colors
CREATE TABLE IF NOT EXISTS object_classes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
class_name TEXT NOT NULL UNIQUE,
color TEXT NOT NULL, -- Hex color code (e.g., '#FF0000')
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
description TEXT
);
-- Insert default object classes
INSERT OR IGNORE INTO object_classes (class_name, color, description) VALUES
('terminal', '#FFFF00', 'Axion terminal');
-- Annotations table: stores manual annotations
CREATE TABLE IF NOT EXISTS annotations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
image_id INTEGER NOT NULL,
class_name TEXT NOT NULL,
class_id INTEGER NOT NULL,
x_min REAL NOT NULL CHECK(x_min >= 0 AND x_min <= 1),
y_min REAL NOT NULL CHECK(y_min >= 0 AND y_min <= 1),
x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1),
y_max REAL NOT NULL CHECK(y_max >= 0 AND y_max <= 1),
segmentation_mask TEXT, -- JSON string of polygon coordinates [[x1,y1], [x2,y2], ...]
annotator TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
verified BOOLEAN DEFAULT 0,
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
FOREIGN KEY (class_id) REFERENCES object_classes (id) ON DELETE CASCADE
);
-- Create indexes for performance optimization
@@ -67,4 +83,6 @@ CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence);
CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path);
CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at);
CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id);
CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at);
CREATE INDEX IF NOT EXISTS idx_annotations_class_id ON annotations(class_id);
CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at);
CREATE INDEX IF NOT EXISTS idx_object_classes_class_name ON object_classes(class_name);

View File

@@ -121,7 +121,7 @@ class ConfigDialog(QDialog):
models_layout.addRow("Models Directory:", self.models_dir_edit)
self.base_model_edit = QLineEdit()
self.base_model_edit.setPlaceholderText("yolov8s.pt")
self.base_model_edit.setPlaceholderText("yolov8s-seg.pt")
models_layout.addRow("Default Base Model:", self.base_model_edit)
models_group.setLayout(models_layout)
@@ -232,7 +232,7 @@ class ConfigDialog(QDialog):
self.config_manager.get("models.models_directory", "data/models")
)
self.base_model_edit.setText(
self.config_manager.get("models.default_base_model", "yolov8s.pt")
self.config_manager.get("models.default_base_model", "yolov8s-seg.pt")
)
# Training settings

View File

@@ -1,6 +1,7 @@
"""
Main window for the microscopy object detection application.
"""
"""Main window for the microscopy object detection application."""
import shutil
from pathlib import Path
from PySide6.QtWidgets import (
QMainWindow,
@@ -13,13 +14,14 @@ from PySide6.QtWidgets import (
QVBoxLayout,
QLabel,
)
from PySide6.QtCore import Qt, QTimer
from PySide6.QtCore import Qt, QTimer, QSettings
from PySide6.QtGui import QAction, QKeySequence
from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
from src.gui.dialogs.config_dialog import ConfigDialog
from src.gui.dialogs.delete_model_dialog import DeleteModelDialog
from src.gui.tabs.detection_tab import DetectionTab
from src.gui.tabs.training_tab import TrainingTab
from src.gui.tabs.validation_tab import ValidationTab
@@ -52,8 +54,8 @@ class MainWindow(QMainWindow):
self._create_tab_widget()
self._create_status_bar()
# Center window on screen
self._center_window()
# Restore window geometry or center window on screen
self._restore_window_state()
logger.info("Main window initialized")
@@ -91,6 +93,12 @@ class MainWindow(QMainWindow):
db_stats_action.triggered.connect(self._show_database_stats)
tools_menu.addAction(db_stats_action)
tools_menu.addSeparator()
delete_model_action = QAction("Delete &Model…", self)
delete_model_action.triggered.connect(self._show_delete_model_dialog)
tools_menu.addAction(delete_model_action)
# Help menu
help_menu = menubar.addMenu("&Help")
@@ -117,10 +125,10 @@ class MainWindow(QMainWindow):
# Add tabs to widget
self.tab_widget.addTab(self.detection_tab, "Detection")
self.tab_widget.addTab(self.results_tab, "Results")
self.tab_widget.addTab(self.annotation_tab, "Annotation")
self.tab_widget.addTab(self.training_tab, "Training")
self.tab_widget.addTab(self.validation_tab, "Validation")
self.tab_widget.addTab(self.results_tab, "Results")
self.tab_widget.addTab(self.annotation_tab, "Annotation (Future)")
# Connect tab change signal
self.tab_widget.currentChanged.connect(self._on_tab_changed)
@@ -152,9 +160,25 @@ class MainWindow(QMainWindow):
"""Center window on screen."""
screen = self.screen().geometry()
size = self.geometry()
self.move(
(screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2
)
self.move((screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2)
def _restore_window_state(self):
"""Restore window geometry from settings or center window."""
settings = QSettings("microscopy_app", "object_detection")
geometry = settings.value("main_window/geometry")
if geometry:
self.restoreGeometry(geometry)
logger.debug("Restored window geometry from settings")
else:
self._center_window()
logger.debug("Centered window on screen")
def _save_window_state(self):
"""Save window geometry to settings."""
settings = QSettings("microscopy_app", "object_detection")
settings.setValue("main_window/geometry", self.saveGeometry())
logger.debug("Saved window geometry to settings")
def _show_settings(self):
"""Show settings dialog."""
@@ -175,6 +199,10 @@ class MainWindow(QMainWindow):
self.training_tab.refresh()
if hasattr(self, "results_tab"):
self.results_tab.refresh()
if hasattr(self, "annotation_tab"):
self.annotation_tab.refresh()
if hasattr(self, "validation_tab"):
self.validation_tab.refresh()
except Exception as e:
logger.error(f"Error applying settings: {e}")
@@ -191,6 +219,14 @@ class MainWindow(QMainWindow):
logger.debug(f"Switched to tab: {tab_name}")
self._update_status(f"Viewing: {tab_name}")
# Ensure the Annotation tab always shows up-to-date DB-backed lists.
try:
current_widget = self.tab_widget.widget(index)
if hasattr(self, "annotation_tab") and current_widget is self.annotation_tab:
self.annotation_tab.refresh()
except Exception as exc:
logger.debug(f"Failed to refresh annotation tab on selection: {exc}")
def _show_database_stats(self):
"""Show database statistics dialog."""
try:
@@ -213,9 +249,229 @@ class MainWindow(QMainWindow):
except Exception as e:
logger.error(f"Error getting database stats: {e}")
QMessageBox.warning(
self, "Error", f"Failed to get database statistics:\n{str(e)}"
)
QMessageBox.warning(self, "Error", f"Failed to get database statistics:\n{str(e)}")
def _show_delete_model_dialog(self) -> None:
"""Open the model deletion dialog."""
dialog = DeleteModelDialog(self.db_manager, self)
if not dialog.exec():
return
model_ids = dialog.selected_model_ids
if not model_ids:
return
self._delete_models(model_ids)
def _delete_models(self, model_ids: list[int]) -> None:
"""Delete one or more models from the database and remove artifacts from disk."""
deleted_count = 0
removed_paths: list[str] = []
remove_errors: list[str] = []
for model_id in model_ids:
model = None
try:
model = self.db_manager.get_model_by_id(int(model_id))
except Exception as exc:
logger.error(f"Failed to load model {model_id} before deletion: {exc}")
if not model:
remove_errors.append(f"Model id {model_id} not found in database.")
continue
try:
deleted = self.db_manager.delete_model(int(model_id))
except Exception as exc:
logger.error(f"Failed to delete model {model_id}: {exc}")
remove_errors.append(f"Failed to delete model id {model_id} from DB: {exc}")
continue
if not deleted:
remove_errors.append(f"Model id {model_id} was not deleted (already removed?).")
continue
deleted_count += 1
removed, errors = self._delete_model_artifacts_from_disk(model)
removed_paths.extend(removed)
remove_errors.extend(errors)
# Refresh tabs to reflect the deletion(s).
try:
if hasattr(self, "detection_tab"):
self.detection_tab.refresh()
if hasattr(self, "results_tab"):
self.results_tab.refresh()
if hasattr(self, "validation_tab"):
self.validation_tab.refresh()
if hasattr(self, "training_tab"):
self.training_tab.refresh()
except Exception as exc:
logger.warning(f"Failed to refresh tabs after model deletion: {exc}")
details: list[str] = []
if removed_paths:
details.append("Removed from disk:\n" + "\n".join(removed_paths))
if remove_errors:
details.append("\nDisk cleanup warnings:\n" + "\n".join(remove_errors))
QMessageBox.information(
self,
"Delete Model",
f"Deleted {deleted_count} model(s) from database." + ("\n\n" + "\n".join(details) if details else ""),
)
def _delete_model(self, model_id: int) -> None:
"""Delete a model from the database and remove its artifacts from disk."""
model = None
try:
model = self.db_manager.get_model_by_id(model_id)
except Exception as exc:
logger.error(f"Failed to load model {model_id} before deletion: {exc}")
if not model:
QMessageBox.warning(self, "Delete Model", "Selected model was not found in the database.")
return
model_path = str(model.get("model_path") or "")
try:
deleted = self.db_manager.delete_model(model_id)
except Exception as exc:
logger.error(f"Failed to delete model {model_id}: {exc}")
QMessageBox.critical(self, "Delete Model", f"Failed to delete model from database:\n{exc}")
return
if not deleted:
QMessageBox.warning(self, "Delete Model", "No model was deleted (it may have already been removed).")
return
removed_paths, remove_errors = self._delete_model_artifacts_from_disk(model)
# Refresh tabs to reflect the deletion.
try:
if hasattr(self, "detection_tab"):
self.detection_tab.refresh()
if hasattr(self, "results_tab"):
self.results_tab.refresh()
if hasattr(self, "validation_tab"):
self.validation_tab.refresh()
if hasattr(self, "training_tab"):
self.training_tab.refresh()
except Exception as exc:
logger.warning(f"Failed to refresh tabs after model deletion: {exc}")
details = []
if model_path:
details.append(f"Deleted model record for: {model_path}")
if removed_paths:
details.append("\nRemoved from disk:\n" + "\n".join(removed_paths))
if remove_errors:
details.append("\nDisk cleanup warnings:\n" + "\n".join(remove_errors))
QMessageBox.information(
self,
"Delete Model",
"Model deleted from database." + ("\n\n" + "\n".join(details) if details else ""),
)
def _delete_model_artifacts_from_disk(self, model: dict) -> tuple[list[str], list[str]]:
"""Best-effort removal of model artifacts on disk.
Strategy:
- Remove run directories inferred from:
- model.model_path (…/<run>/weights/*.pt => <run>)
- training_params.stage_results[].results.save_dir
but only if they are under the configured models directory.
- If the weights file itself exists and is outside the models directory, delete only the file.
Returns:
(removed_paths, errors)
"""
removed: list[str] = []
errors: list[str] = []
models_root = Path(self.config_manager.get_models_directory() or "data/models").expanduser()
try:
models_root_resolved = models_root.resolve()
except Exception:
models_root_resolved = models_root
inferred_dirs: list[Path] = []
# 1) From model_path
model_path_value = model.get("model_path")
if model_path_value:
try:
p = Path(str(model_path_value)).expanduser()
p_resolved = p.resolve() if p.exists() else p
if p_resolved.is_file():
if p_resolved.parent.name == "weights" and p_resolved.parent.parent.exists():
inferred_dirs.append(p_resolved.parent.parent)
elif p_resolved.parent.exists():
inferred_dirs.append(p_resolved.parent)
except Exception:
pass
# 2) From training_params.stage_results[].results.save_dir
training_params = model.get("training_params") or {}
if isinstance(training_params, dict):
stage_results = training_params.get("stage_results")
if isinstance(stage_results, list):
for stage in stage_results:
results = (stage or {}).get("results")
save_dir = (results or {}).get("save_dir") if isinstance(results, dict) else None
if not save_dir:
continue
try:
d = Path(str(save_dir)).expanduser()
if d.exists() and d.is_dir():
inferred_dirs.append(d)
except Exception:
continue
# Deduplicate inferred_dirs
unique_dirs: list[Path] = []
seen: set[str] = set()
for d in inferred_dirs:
try:
key = str(d.resolve())
except Exception:
key = str(d)
if key in seen:
continue
seen.add(key)
unique_dirs.append(d)
# Delete directories under models_root
for d in unique_dirs:
try:
d_resolved = d.resolve()
except Exception:
d_resolved = d
try:
if d_resolved.exists() and d_resolved.is_dir() and d_resolved.is_relative_to(models_root_resolved):
shutil.rmtree(d_resolved)
removed.append(str(d_resolved))
except Exception as exc:
errors.append(f"Failed to remove directory {d_resolved}: {exc}")
# If nothing matched (e.g., model_path outside models_root), delete just the file.
if model_path_value:
try:
p = Path(str(model_path_value)).expanduser()
if p.exists() and p.is_file():
p_resolved = p.resolve()
if not p_resolved.is_relative_to(models_root_resolved):
p_resolved.unlink()
removed.append(str(p_resolved))
except Exception as exc:
errors.append(f"Failed to remove model file {model_path_value}: {exc}")
return removed, errors
def _show_about(self):
"""Show about dialog."""
@@ -276,6 +532,20 @@ class MainWindow(QMainWindow):
)
if reply == QMessageBox.Yes:
# Save window state before closing
self._save_window_state()
# Persist tab state and stop background work before exit
if hasattr(self, "training_tab"):
self.training_tab.shutdown()
if hasattr(self, "annotation_tab"):
# Best-effort refresh so DB-backed UI state is consistent at shutdown.
try:
self.annotation_tab.refresh()
except Exception:
pass
self.annotation_tab.save_state()
logger.info("Application closing")
event.accept()
else:

View File

@@ -1,23 +1,49 @@
"""
Annotation tab for the microscopy object detection application.
Future feature for manual annotation.
Manual annotation with pen tool and object class management.
"""
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
from PySide6.QtWidgets import (
QWidget,
QVBoxLayout,
QHBoxLayout,
QLabel,
QGroupBox,
QPushButton,
QFileDialog,
QMessageBox,
QSplitter,
QLineEdit,
QTableWidget,
QTableWidgetItem,
QHeaderView,
QAbstractItemView,
)
from PySide6.QtCore import Qt, QSettings
from pathlib import Path
from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.image import Image, ImageLoadError
from src.utils.logger import get_logger
from src.gui.widgets import AnnotationCanvasWidget, AnnotationToolsWidget
logger = get_logger(__name__)
class AnnotationTab(QWidget):
"""Annotation tab placeholder (future feature)."""
"""Annotation tab for manual image annotation."""
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
super().__init__(parent)
self.db_manager = db_manager
self.config_manager = config_manager
self.current_image = None
self.current_image_path = None
self.current_image_id = None
self.current_annotations = []
# IDs of annotations currently selected on the canvas (multi-select)
self.selected_annotation_ids = []
self._setup_ui()
@@ -25,24 +51,692 @@ class AnnotationTab(QWidget):
"""Setup user interface."""
layout = QVBoxLayout()
group = QGroupBox("Annotation Tool (Future Feature)")
group_layout = QVBoxLayout()
label = QLabel(
"Annotation functionality will be implemented in future version.\n\n"
"Planned Features:\n"
"- Image browser\n"
"- Drawing tools for bounding boxes\n"
"- Class label assignment\n"
"- Export annotations to YOLO format\n"
"- Annotation verification"
)
group_layout.addWidget(label)
group.setLayout(group_layout)
# Main horizontal splitter to divide left (image) and right (controls)
self.main_splitter = QSplitter(Qt.Horizontal)
self.main_splitter.setHandleWidth(10)
layout.addWidget(group)
layout.addStretch()
# { Left-most pane: annotated images list
annotated_group = QGroupBox("Annotated Images")
annotated_layout = QVBoxLayout()
filter_row = QHBoxLayout()
filter_row.addWidget(QLabel("Filter:"))
self.annotated_filter_edit = QLineEdit()
self.annotated_filter_edit.setPlaceholderText("Type to filter by image name…")
self.annotated_filter_edit.textChanged.connect(self._refresh_annotated_images_list)
filter_row.addWidget(self.annotated_filter_edit, 1)
annotated_layout.addLayout(filter_row)
self.annotated_images_table = QTableWidget(0, 2)
self.annotated_images_table.setHorizontalHeaderLabels(["Image", "Annotations"])
self.annotated_images_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
self.annotated_images_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeToContents)
self.annotated_images_table.setSelectionBehavior(QAbstractItemView.SelectRows)
self.annotated_images_table.setSelectionMode(QAbstractItemView.SingleSelection)
self.annotated_images_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
self.annotated_images_table.setSortingEnabled(True)
self.annotated_images_table.itemSelectionChanged.connect(self._on_annotated_image_selected)
annotated_layout.addWidget(self.annotated_images_table, 1)
annotated_group.setLayout(annotated_layout)
# }
# { Left splitter for image display and zoom info
self.left_splitter = QSplitter(Qt.Vertical)
self.left_splitter.setHandleWidth(10)
# Annotation canvas section
canvas_group = QGroupBox("Annotation Canvas")
canvas_layout = QVBoxLayout()
# Use the AnnotationCanvasWidget
self.annotation_canvas = AnnotationCanvasWidget()
# Auto-zoom so newly loaded images fill the available canvas viewport.
# (Matches the behavior used in ResultsTab.)
self.annotation_canvas.set_auto_fit_to_view(True)
self.annotation_canvas.zoom_changed.connect(self._on_zoom_changed)
self.annotation_canvas.annotation_drawn.connect(self._on_annotation_drawn)
# Selection of existing polylines (when tool is not in drawing mode)
self.annotation_canvas.annotation_selected.connect(self._on_annotation_selected)
canvas_layout.addWidget(self.annotation_canvas)
canvas_group.setLayout(canvas_layout)
self.left_splitter.addWidget(canvas_group)
# Controls info
controls_info = QLabel("Zoom: Mouse wheel or +/- keys | Drawing: Enable pen and drag mouse")
controls_info.setStyleSheet("QLabel { color: #888; font-style: italic; }")
self.left_splitter.addWidget(controls_info)
# }
# { Right splitter for annotation tools and controls
self.right_splitter = QSplitter(Qt.Vertical)
self.right_splitter.setHandleWidth(10)
# Annotation tools section
self.annotation_tools = AnnotationToolsWidget(self.db_manager)
self.annotation_tools.polyline_enabled_changed.connect(self.annotation_canvas.set_polyline_enabled)
self.annotation_tools.polyline_pen_color_changed.connect(self.annotation_canvas.set_polyline_pen_color)
self.annotation_tools.polyline_pen_width_changed.connect(self.annotation_canvas.set_polyline_pen_width)
# Show / hide bounding boxes
self.annotation_tools.show_bboxes_changed.connect(self.annotation_canvas.set_show_bboxes)
# RDP simplification controls
self.annotation_tools.simplify_on_finish_changed.connect(self._on_simplify_on_finish_changed)
self.annotation_tools.simplify_epsilon_changed.connect(self._on_simplify_epsilon_changed)
# Class selection and class-color changes
self.annotation_tools.class_selected.connect(self._on_class_selected)
self.annotation_tools.class_color_changed.connect(self._on_class_color_changed)
self.annotation_tools.clear_annotations_requested.connect(self._on_clear_annotations)
# Delete selected annotation on canvas
self.annotation_tools.delete_selected_annotation_requested.connect(self._on_delete_selected_annotation)
self.right_splitter.addWidget(self.annotation_tools)
# Image loading section
load_group = QGroupBox("Image Loading")
load_layout = QVBoxLayout()
# Load image button
button_layout = QHBoxLayout()
self.load_image_btn = QPushButton("Load Image")
self.load_image_btn.clicked.connect(self._load_image)
button_layout.addWidget(self.load_image_btn)
button_layout.addStretch()
load_layout.addLayout(button_layout)
# Image info label
self.image_info_label = QLabel("No image loaded")
load_layout.addWidget(self.image_info_label)
load_group.setLayout(load_layout)
self.right_splitter.addWidget(load_group)
# }
# Add list + both splitters to the main horizontal splitter
self.main_splitter.addWidget(annotated_group)
self.main_splitter.addWidget(self.left_splitter)
self.main_splitter.addWidget(self.right_splitter)
# Set initial sizes: list (left), canvas (middle), controls (right)
self.main_splitter.setSizes([320, 650, 280])
layout.addWidget(self.main_splitter)
self.setLayout(layout)
# Restore splitter positions from settings
self._restore_state()
# Populate list on startup.
self._refresh_annotated_images_list()
def _load_image(self):
"""Load and display an image file."""
# Get last opened directory from QSettings
settings = QSettings("microscopy_app", "object_detection")
last_dir = settings.value("annotation_tab/last_directory", None)
# Fallback to image repository path or home directory
if last_dir and Path(last_dir).exists():
start_dir = last_dir
else:
repo_path = self.config_manager.get_image_repository_path()
start_dir = repo_path if repo_path else str(Path.home())
# Open file dialog
file_path, _ = QFileDialog.getOpenFileName(
self,
"Select Image",
start_dir,
"Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
)
if not file_path:
return
try:
# Load image using Image class
self.current_image = Image(file_path)
self.current_image_path = file_path
# Store the directory for next time
settings.setValue("annotation_tab/last_directory", str(Path(file_path).parent))
# Get or create image in database
repo_root = self.config_manager.get_image_repository_path()
relative_path: str
try:
if repo_root:
repo_root_path = Path(repo_root).expanduser().resolve()
file_resolved = Path(file_path).expanduser().resolve()
if file_resolved.is_relative_to(repo_root_path):
relative_path = file_resolved.relative_to(repo_root_path).as_posix()
else:
# Fallback: store filename only to avoid leaking absolute paths.
relative_path = file_resolved.name
else:
relative_path = str(Path(file_path).name)
except Exception:
relative_path = str(Path(file_path).name)
self.current_image_id = self.db_manager.get_or_create_image(
relative_path,
Path(file_path).name,
self.current_image.width,
self.current_image.height,
)
# Display image using the AnnotationCanvasWidget
self.annotation_canvas.load_image(self.current_image)
# Load and display any existing annotations for this image
self._load_annotations_for_current_image()
# Update annotated images list (newly annotated image added/selected).
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
# Update info label
self._update_image_info()
logger.info(f"Loaded image: {file_path} (DB ID: {self.current_image_id})")
except ImageLoadError as e:
logger.error(f"Failed to load image: {e}")
QMessageBox.critical(self, "Error Loading Image", f"Failed to load image:\n{str(e)}")
except Exception as e:
logger.error(f"Unexpected error loading image: {e}")
QMessageBox.critical(self, "Error", f"Unexpected error:\n{str(e)}")
def _update_image_info(self):
"""Update the image info label with current image details."""
if self.current_image is None:
self.image_info_label.setText("No image loaded")
return
zoom_percentage = self.annotation_canvas.get_zoom_percentage()
info_text = (
f"File: {Path(self.current_image_path).name}\n"
f"Size: {self.current_image.width}x{self.current_image.height} pixels\n"
f"Channels: {self.current_image.channels}\n"
f"Data type: {self.current_image.dtype}\n"
f"Format: {self.current_image.format.upper()}\n"
f"File size: {self.current_image.size_mb:.2f} MB\n"
f"Zoom: {zoom_percentage}%"
)
self.image_info_label.setText(info_text)
def _on_zoom_changed(self, zoom_scale: float):
"""Handle zoom level changes from the annotation canvas."""
self._update_image_info()
def _on_annotation_drawn(self, points: list):
"""
Handle when an annotation stroke is drawn.
Saves the new annotation directly to the database and refreshes the
on-canvas display of annotations for the current image.
"""
# Ensure we have an image loaded and in the DB
if not self.current_image or not self.current_image_id:
logger.warning("Annotation drawn but no image loaded")
QMessageBox.warning(
self,
"No Image",
"Please load an image before drawing annotations.",
)
return
current_class = self.annotation_tools.get_current_class()
if not current_class:
logger.warning("Annotation drawn but no object class selected")
QMessageBox.warning(
self,
"No Class Selected",
"Please select an object class before drawing annotations.",
)
return
if not points:
logger.warning("Annotation drawn with no points, ignoring")
return
# points are [(x_norm, y_norm), ...]
xs = [p[0] for p in points]
ys = [p[1] for p in points]
x_min, x_max = min(xs), max(xs)
y_min, y_max = min(ys), max(ys)
# Store segmentation mask in [y_norm, x_norm] format to match DB
db_polyline = [[float(y), float(x)] for (x, y) in points]
try:
annotation_id = self.db_manager.add_annotation(
image_id=self.current_image_id,
class_id=current_class["id"],
bbox=(x_min, y_min, x_max, y_max),
annotator="manual",
segmentation_mask=db_polyline,
verified=False,
)
logger.info(
f"Saved annotation (ID: {annotation_id}) for class "
f"'{current_class['class_name']}' "
f"Bounding box: ({x_min:.3f}, {y_min:.3f}) to ({x_max:.3f}, {y_max:.3f})\n"
f"with {len(points)} polyline points"
)
# Reload annotations from DB and redraw (respecting current class filter)
self._load_annotations_for_current_image()
# Update list counts.
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
except Exception as e:
logger.error(f"Failed to save annotation: {e}")
QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}")
def _on_annotation_selected(self, annotation_ids):
"""
Handle selection of existing annotations on the canvas.
Args:
annotation_ids: List of selected annotation IDs, or None/empty if cleared.
"""
if not annotation_ids:
self.selected_annotation_ids = []
self.annotation_tools.set_has_selected_annotation(False)
logger.debug("Annotation selection cleared on canvas")
return
# Normalize to a unique, sorted list of integer IDs
ids = sorted({int(aid) for aid in annotation_ids if isinstance(aid, int)})
self.selected_annotation_ids = ids
self.annotation_tools.set_has_selected_annotation(bool(ids))
logger.debug(f"Annotations selected on canvas: IDs={ids}")
def _on_simplify_on_finish_changed(self, enabled: bool):
"""Update canvas simplify-on-finish flag from tools widget."""
self.annotation_canvas.simplify_on_finish = enabled
logger.debug(f"Annotation simplification on finish set to {enabled}")
def _on_simplify_epsilon_changed(self, epsilon: float):
"""Update canvas RDP epsilon from tools widget."""
self.annotation_canvas.simplify_epsilon = float(epsilon)
logger.debug(f"Annotation simplification epsilon set to {epsilon}")
def _on_class_color_changed(self):
"""
Handle changes to the selected object's class color.
When the user updates a class color in the tools widget, reload the
annotations for the current image so that all polylines are redrawn
using the updated per-class colors.
"""
if not self.current_image_id:
return
logger.debug(f"Class color changed; reloading annotations for image ID {self.current_image_id}")
self._load_annotations_for_current_image()
def _on_class_selected(self, class_data):
"""
Handle when an object class is selected or cleared.
When a specific class is selected, only annotations of that class are drawn.
When the selection is cleared ("-- Select Class --"), all annotations are shown.
"""
if class_data:
logger.debug(f"Object class selected: {class_data['class_name']}")
else:
logger.debug('No class selected ("-- Select Class --"), showing all annotations')
# Changing the class filter invalidates any previous selection
self.selected_annotation_ids = []
self.annotation_tools.set_has_selected_annotation(False)
# Whenever the selection changes, update which annotations are visible
self._redraw_annotations_for_current_filter()
def _on_clear_annotations(self):
"""Handle clearing all annotations."""
self.annotation_canvas.clear_annotations()
# Clear in-memory state and selection, but keep DB entries unchanged
self.current_annotations = []
self.selected_annotation_ids = []
self.annotation_tools.set_has_selected_annotation(False)
logger.info("Cleared all annotations")
def _on_delete_selected_annotation(self):
"""Handle deleting the currently selected annotation(s) (if any)."""
if not self.selected_annotation_ids:
QMessageBox.information(
self,
"No Selection",
"No annotation is currently selected.",
)
return
count = len(self.selected_annotation_ids)
if count == 1:
question = "Are you sure you want to delete the selected annotation?"
title = "Delete Annotation"
else:
question = f"Are you sure you want to delete the {count} selected annotations?"
title = "Delete Annotations"
reply = QMessageBox.question(
self,
title,
question,
QMessageBox.Yes | QMessageBox.No,
QMessageBox.No,
)
if reply != QMessageBox.Yes:
return
failed_ids = []
try:
for ann_id in self.selected_annotation_ids:
try:
deleted = self.db_manager.delete_annotation(ann_id)
if not deleted:
failed_ids.append(ann_id)
except Exception as e:
logger.error(f"Failed to delete annotation ID {ann_id}: {e}")
failed_ids.append(ann_id)
if failed_ids:
QMessageBox.warning(
self,
"Partial Failure",
"Some annotations could not be deleted:\n" + ", ".join(str(a) for a in failed_ids),
)
else:
logger.info(
f"Deleted {count} annotation(s): " + ", ".join(str(a) for a in self.selected_annotation_ids)
)
# Clear selection and reload annotations for the current image from DB
self.selected_annotation_ids = []
self.annotation_tools.set_has_selected_annotation(False)
self._load_annotations_for_current_image()
# Update list counts.
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
except Exception as e:
logger.error(f"Failed to delete annotations: {e}")
QMessageBox.critical(
self,
"Error",
f"Failed to delete annotations:\n{str(e)}",
)
def _load_annotations_for_current_image(self):
"""
Load all annotations for the current image from the database and
redraw them on the canvas, honoring the currently selected class
filter (if any).
"""
if not self.current_image_id:
self.current_annotations = []
self.annotation_canvas.clear_annotations()
self.selected_annotation_ids = []
self.annotation_tools.set_has_selected_annotation(False)
return
try:
self.current_annotations = self.db_manager.get_annotations_for_image(self.current_image_id)
# New annotations loaded; reset any selection
self.selected_annotation_ids = []
self.annotation_tools.set_has_selected_annotation(False)
self._redraw_annotations_for_current_filter()
except Exception as e:
logger.error(f"Failed to load annotations for image {self.current_image_id}: {e}")
QMessageBox.critical(
self,
"Error",
f"Failed to load annotations for this image:\n{str(e)}",
)
def _redraw_annotations_for_current_filter(self):
"""
Redraw annotations for the current image, optionally filtered by the
currently selected object class.
"""
# Clear current on-canvas annotations but keep the image
self.annotation_canvas.clear_annotations()
if not self.current_annotations:
return
current_class = self.annotation_tools.get_current_class()
selected_class_id = current_class["id"] if current_class else None
drawn_count = 0
for ann in self.current_annotations:
# Filter by class if one is selected
if selected_class_id is not None and ann.get("class_id") != selected_class_id:
continue
if ann.get("segmentation_mask"):
polyline = ann["segmentation_mask"]
color = ann.get("class_color", "#FF0000")
self.annotation_canvas.draw_saved_polyline(
polyline,
color,
width=3,
annotation_id=ann["id"],
)
self.annotation_canvas.draw_saved_bbox(
[ann["x_min"], ann["y_min"], ann["x_max"], ann["y_max"]],
color,
width=3,
)
drawn_count += 1
logger.info(
f"Displayed {drawn_count} annotation(s) for current image with "
f"{'no class filter' if selected_class_id is None else f'class_id={selected_class_id}'}"
)
def _restore_state(self):
"""Restore splitter positions from settings."""
settings = QSettings("microscopy_app", "object_detection")
# Restore main splitter state
main_state = settings.value("annotation_tab/main_splitter_state")
if main_state:
self.main_splitter.restoreState(main_state)
logger.debug("Restored main splitter state")
# Restore left splitter state
left_state = settings.value("annotation_tab/left_splitter_state")
if left_state:
self.left_splitter.restoreState(left_state)
logger.debug("Restored left splitter state")
# Restore right splitter state
right_state = settings.value("annotation_tab/right_splitter_state")
if right_state:
self.right_splitter.restoreState(right_state)
logger.debug("Restored right splitter state")
def save_state(self):
"""Save splitter positions to settings."""
settings = QSettings("microscopy_app", "object_detection")
# Save main splitter state
settings.setValue("annotation_tab/main_splitter_state", self.main_splitter.saveState())
# Save left splitter state
settings.setValue("annotation_tab/left_splitter_state", self.left_splitter.saveState())
# Save right splitter state
settings.setValue("annotation_tab/right_splitter_state", self.right_splitter.saveState())
logger.debug("Saved annotation tab splitter states")
def refresh(self):
"""Refresh the tab."""
pass
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
# ==================== Annotated images list ====================
def _refresh_annotated_images_list(self, select_image_id: int | None = None) -> None:
"""Reload annotated-images list from the database."""
if not hasattr(self, "annotated_images_table"):
return
# Preserve selection if possible
desired_id = select_image_id if select_image_id is not None else self.current_image_id
name_filter = ""
if hasattr(self, "annotated_filter_edit"):
name_filter = self.annotated_filter_edit.text().strip()
try:
rows = self.db_manager.get_annotated_images_summary(name_filter=name_filter)
except Exception as exc:
logger.error(f"Failed to load annotated images summary: {exc}")
rows = []
sorting_enabled = self.annotated_images_table.isSortingEnabled()
self.annotated_images_table.setSortingEnabled(False)
self.annotated_images_table.blockSignals(True)
try:
self.annotated_images_table.setRowCount(len(rows))
for r, entry in enumerate(rows):
image_name = str(entry.get("filename") or "")
count = int(entry.get("annotation_count") or 0)
rel_path = str(entry.get("relative_path") or "")
name_item = QTableWidgetItem(image_name)
# Tooltip shows full path of the image (best-effort: repository_root + relative_path)
full_path = rel_path
repo_root = self.config_manager.get_image_repository_path()
if repo_root and rel_path and not Path(rel_path).is_absolute():
try:
full_path = str((Path(repo_root) / rel_path).resolve())
except Exception:
full_path = str(Path(repo_root) / rel_path)
name_item.setToolTip(full_path)
name_item.setData(Qt.UserRole, int(entry.get("id")))
name_item.setData(Qt.UserRole + 1, rel_path)
count_item = QTableWidgetItem()
# Use EditRole to ensure numeric sorting.
count_item.setData(Qt.EditRole, count)
count_item.setData(Qt.UserRole, int(entry.get("id")))
count_item.setData(Qt.UserRole + 1, rel_path)
self.annotated_images_table.setItem(r, 0, name_item)
self.annotated_images_table.setItem(r, 1, count_item)
# Re-select desired row
if desired_id is not None:
for r in range(self.annotated_images_table.rowCount()):
item = self.annotated_images_table.item(r, 0)
if item and item.data(Qt.UserRole) == desired_id:
self.annotated_images_table.selectRow(r)
break
finally:
self.annotated_images_table.blockSignals(False)
self.annotated_images_table.setSortingEnabled(sorting_enabled)
def _on_annotated_image_selected(self) -> None:
"""When user clicks an item in the list, load that image in the annotation canvas."""
selected = self.annotated_images_table.selectedItems()
if not selected:
return
# Row selection -> take the first column item
row = self.annotated_images_table.currentRow()
item = self.annotated_images_table.item(row, 0)
if not item:
return
image_id = item.data(Qt.UserRole)
rel_path = item.data(Qt.UserRole + 1) or ""
if not image_id:
return
image_path = self._resolve_image_path_for_relative_path(rel_path)
if not image_path:
QMessageBox.warning(
self,
"Image Not Found",
"Unable to locate image on disk for:\n"
f"{rel_path}\n\n"
"Tip: set Settings → Image repository path to the folder containing your images.",
)
return
try:
self.current_image = Image(image_path)
self.current_image_path = image_path
self.current_image_id = int(image_id)
self.annotation_canvas.load_image(self.current_image)
self._load_annotations_for_current_image()
self._update_image_info()
except ImageLoadError as exc:
logger.error(f"Failed to load image '{image_path}': {exc}")
QMessageBox.critical(self, "Error Loading Image", f"Failed to load image:\n{exc}")
except Exception as exc:
logger.error(f"Unexpected error loading image '{image_path}': {exc}")
QMessageBox.critical(self, "Error", f"Unexpected error:\n{exc}")
def _resolve_image_path_for_relative_path(self, relative_path: str) -> str | None:
"""Best-effort conversion from a DB relative_path to an on-disk file path."""
rel = (relative_path or "").strip()
if not rel:
return None
candidates: list[Path] = []
# 1) Repository root + relative
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
if repo_root:
candidates.append(Path(repo_root) / rel)
# 2) If the DB path is absolute, try it directly.
candidates.append(Path(rel))
# 3) Try the directory of the currently loaded image (helps when DB stores only filenames)
if self.current_image_path:
try:
candidates.append(Path(self.current_image_path).expanduser().resolve().parent / Path(rel).name)
except Exception:
pass
# 4) Try the last directory used by the annotation file picker
try:
settings = QSettings("microscopy_app", "object_detection")
last_dir = settings.value("annotation_tab/last_directory", None)
if last_dir:
candidates.append(Path(str(last_dir)) / Path(rel).name)
except Exception:
pass
for p in candidates:
try:
expanded = p.expanduser()
if expanded.exists() and expanded.is_file():
return str(expanded.resolve())
except Exception:
continue
# 5) Fallback: search by filename within repository root.
filename = Path(rel).name
if repo_root and filename:
root = Path(repo_root).expanduser()
try:
if root.exists():
for match in root.rglob(filename):
if match.is_file():
return str(match.resolve())
except Exception as exc:
logger.debug(f"Search for {filename} under {root} failed: {exc}")
return None

View File

@@ -20,12 +20,14 @@ from PySide6.QtWidgets import (
)
from PySide6.QtCore import Qt, QThread, Signal
from pathlib import Path
from typing import Optional
from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
from src.utils.file_utils import get_image_files
from src.model.inference import InferenceEngine
from src.utils.image import Image
logger = get_logger(__name__)
@@ -147,30 +149,66 @@ class DetectionTab(QWidget):
self.model_combo.currentIndexChanged.connect(self._on_model_changed)
def _load_models(self):
"""Load available models from database."""
"""Load available models from database and local storage."""
try:
models = self.db_manager.get_models()
self.model_combo.clear()
models = self.db_manager.get_models()
has_models = False
if not models:
self.model_combo.addItem("No models available", None)
self._set_buttons_enabled(False)
return
known_paths = set()
# Add base model option
# Add base model option first (always available)
base_model = self.config_manager.get(
"models.default_base_model", "yolov8s.pt"
)
self.model_combo.addItem(
f"Base Model ({base_model})", {"id": 0, "path": base_model}
"models.default_base_model", "yolov8s-seg.pt"
)
if base_model:
base_data = {
"id": 0,
"path": base_model,
"model_name": Path(base_model).stem or "Base Model",
"model_version": "pretrained",
"base_model": base_model,
"source": "base",
}
self.model_combo.addItem(f"Base Model ({base_model})", base_data)
known_paths.add(self._normalize_model_path(base_model))
has_models = True
# Add trained models
# Add trained models from database
for model in models:
display_name = f"{model['model_name']} v{model['model_version']}"
self.model_combo.addItem(display_name, model)
model_data = {**model, "path": model.get("model_path")}
normalized = self._normalize_model_path(model_data.get("path"))
if normalized:
known_paths.add(normalized)
self.model_combo.addItem(display_name, model_data)
has_models = True
self._set_buttons_enabled(True)
# Discover local model files not yet in the database
local_models = self._discover_local_models()
for model_path in local_models:
normalized = self._normalize_model_path(model_path)
if normalized in known_paths:
continue
display_name = f"Local Model ({Path(model_path).stem})"
model_data = {
"id": None,
"path": str(model_path),
"model_name": Path(model_path).stem,
"model_version": "local",
"base_model": Path(model_path).stem,
"source": "local",
}
self.model_combo.addItem(display_name, model_data)
known_paths.add(normalized)
has_models = True
if not has_models:
self.model_combo.addItem("No models available", None)
self._set_buttons_enabled(False)
else:
self._set_buttons_enabled(True)
except Exception as e:
logger.error(f"Error loading models: {e}")
@@ -199,7 +237,7 @@ class DetectionTab(QWidget):
self,
"Select Image",
start_dir,
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
"Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
)
if not file_path:
@@ -249,25 +287,39 @@ class DetectionTab(QWidget):
QMessageBox.warning(self, "No Model", "Please select a model first.")
return
model_path = model_data["path"]
model_id = model_data["id"]
model_path = model_data.get("path")
if not model_path:
QMessageBox.warning(
self, "Invalid Model", "Selected model is missing a file path."
)
return
# Ensure we have a valid model ID (create entry for base model if needed)
if model_id == 0:
# Create database entry for base model
base_model = self.config_manager.get(
"models.default_base_model", "yolov8s.pt"
)
model_id = self.db_manager.add_model(
model_name="Base Model",
model_version="pretrained",
model_path=base_model,
base_model=base_model,
if not Path(model_path).exists():
QMessageBox.critical(
self,
"Model Not Found",
f"The selected model file could not be found:\n{model_path}",
)
return
model_id = model_data.get("id")
# Ensure we have a database entry for the selected model
if model_id in (None, 0):
model_id = self._ensure_model_record(model_data)
if not model_id:
QMessageBox.critical(
self,
"Model Registration Failed",
"Unable to register the selected model in the database.",
)
return
normalized_model_path = self._normalize_model_path(model_path) or model_path
# Create inference engine
self.inference_engine = InferenceEngine(
model_path, self.db_manager, model_id
normalized_model_path, self.db_manager, model_id
)
# Get confidence threshold
@@ -338,6 +390,76 @@ class DetectionTab(QWidget):
self.batch_btn.setEnabled(enabled)
self.model_combo.setEnabled(enabled)
def _discover_local_models(self) -> list:
"""Scan the models directory for standalone .pt files."""
models_dir = self.config_manager.get_models_directory()
if not models_dir:
return []
models_path = Path(models_dir)
if not models_path.exists():
return []
try:
return sorted(
[p for p in models_path.rglob("*.pt") if p.is_file()],
key=lambda p: str(p).lower(),
)
except Exception as e:
logger.warning(f"Error discovering local models: {e}")
return []
def _normalize_model_path(self, path_value) -> str:
"""Return a normalized absolute path string for comparison."""
if not path_value:
return ""
try:
return str(Path(path_value).resolve())
except Exception:
return str(path_value)
def _ensure_model_record(self, model_data: dict) -> Optional[int]:
"""Ensure a database record exists for the selected model."""
model_path = model_data.get("path")
if not model_path:
return None
normalized_target = self._normalize_model_path(model_path)
try:
existing_models = self.db_manager.get_models()
for model in existing_models:
existing_path = model.get("model_path")
if not existing_path:
continue
normalized_existing = self._normalize_model_path(existing_path)
if (
normalized_existing == normalized_target
or existing_path == model_path
):
return model["id"]
model_name = (
model_data.get("model_name") or Path(model_path).stem or "Custom Model"
)
model_version = (
model_data.get("model_version") or model_data.get("source") or "local"
)
base_model = model_data.get(
"base_model",
self.config_manager.get("models.default_base_model", "yolov8s-seg.pt"),
)
return self.db_manager.add_model(
model_name=model_name,
model_version=model_version,
model_path=normalized_target,
base_model=base_model,
)
except Exception as e:
logger.error(f"Failed to ensure model record for {model_path}: {e}")
return None
def refresh(self):
"""Refresh the tab."""
self._load_models()

View File

@@ -1,46 +1,699 @@
"""
Results tab for the microscopy object detection application.
Results tab for browsing stored detections and visualizing overlays.
"""
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from PySide6.QtWidgets import (
QWidget,
QVBoxLayout,
QHBoxLayout,
QLabel,
QGroupBox,
QPushButton,
QSplitter,
QTableWidget,
QTableWidgetItem,
QHeaderView,
QAbstractItemView,
QMessageBox,
QCheckBox,
)
from PySide6.QtCore import Qt
from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
from src.utils.image import Image, ImageLoadError
from src.gui.widgets import AnnotationCanvasWidget
logger = get_logger(__name__)
class ResultsTab(QWidget):
"""Results tab placeholder."""
"""Results tab showing detection history and preview overlays."""
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
super().__init__(parent)
self.db_manager = db_manager
self.config_manager = config_manager
self.detection_summary: List[Dict] = []
self.current_selection: Optional[Dict] = None
self.current_image: Optional[Image] = None
self.current_detections: List[Dict] = []
self._image_path_cache: Dict[str, str] = {}
self._setup_ui()
self.refresh()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
group = QGroupBox("Results")
group_layout = QVBoxLayout()
label = QLabel(
"Results viewer will be implemented here.\n\n"
"Features:\n"
"- Detection history browser\n"
"- Advanced filtering\n"
"- Statistics dashboard\n"
"- Export functionality"
)
group_layout.addWidget(label)
group.setLayout(group_layout)
# Splitter for list + preview
splitter = QSplitter(Qt.Horizontal)
layout.addWidget(group)
layout.addStretch()
# Left pane: detection list
left_container = QWidget()
left_layout = QVBoxLayout()
left_layout.setContentsMargins(0, 0, 0, 0)
controls_layout = QHBoxLayout()
self.refresh_btn = QPushButton("Refresh")
self.refresh_btn.clicked.connect(self.refresh)
controls_layout.addWidget(self.refresh_btn)
self.delete_all_btn = QPushButton("Delete All Detections")
self.delete_all_btn.setToolTip(
"Permanently delete ALL detections from the database.\n" "This cannot be undone."
)
self.delete_all_btn.clicked.connect(self._delete_all_detections)
controls_layout.addWidget(self.delete_all_btn)
self.export_labels_btn = QPushButton("Export Labels")
self.export_labels_btn.setToolTip(
"Export YOLO .txt labels for the selected image/model run.\n"
"Output path is inferred from the image path (images/ -> labels/)."
)
self.export_labels_btn.clicked.connect(self._export_labels_for_current_selection)
controls_layout.addWidget(self.export_labels_btn)
controls_layout.addStretch()
left_layout.addLayout(controls_layout)
self.results_table = QTableWidget(0, 5)
self.results_table.setHorizontalHeaderLabels(["Image", "Model", "Detections", "Classes", "Last Updated"])
self.results_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
self.results_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.Stretch)
self.results_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeToContents)
self.results_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.Stretch)
self.results_table.horizontalHeader().setSectionResizeMode(4, QHeaderView.ResizeToContents)
self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows)
self.results_table.setSelectionMode(QAbstractItemView.SingleSelection)
self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
self.results_table.itemSelectionChanged.connect(self._on_result_selected)
left_layout.addWidget(self.results_table)
left_container.setLayout(left_layout)
# Right pane: preview canvas and controls
right_container = QWidget()
right_layout = QVBoxLayout()
right_layout.setContentsMargins(0, 0, 0, 0)
preview_group = QGroupBox("Detection Preview")
preview_layout = QVBoxLayout()
self.preview_canvas = AnnotationCanvasWidget()
# Auto-zoom so newly loaded images fill the available preview viewport.
self.preview_canvas.set_auto_fit_to_view(True)
self.preview_canvas.set_polyline_enabled(False)
self.preview_canvas.set_show_bboxes(True)
preview_layout.addWidget(self.preview_canvas)
toggles_layout = QHBoxLayout()
self.show_masks_checkbox = QCheckBox("Show Masks")
self.show_masks_checkbox.setChecked(False)
self.show_masks_checkbox.stateChanged.connect(self._apply_detection_overlays)
self.show_bboxes_checkbox = QCheckBox("Show Bounding Boxes")
self.show_bboxes_checkbox.setChecked(True)
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
self.show_confidence_checkbox = QCheckBox("Show Confidence")
self.show_confidence_checkbox.setChecked(False)
self.show_confidence_checkbox.stateChanged.connect(self._apply_detection_overlays)
toggles_layout.addWidget(self.show_masks_checkbox)
toggles_layout.addWidget(self.show_bboxes_checkbox)
toggles_layout.addWidget(self.show_confidence_checkbox)
toggles_layout.addStretch()
preview_layout.addLayout(toggles_layout)
self.summary_label = QLabel("Select a detection result to preview.")
self.summary_label.setWordWrap(True)
preview_layout.addWidget(self.summary_label)
preview_group.setLayout(preview_layout)
right_layout.addWidget(preview_group)
right_container.setLayout(right_layout)
splitter.addWidget(left_container)
splitter.addWidget(right_container)
splitter.setStretchFactor(0, 1)
splitter.setStretchFactor(1, 2)
layout.addWidget(splitter)
self.setLayout(layout)
def _delete_all_detections(self):
"""Delete all detections from the database after user confirmation."""
confirm = QMessageBox.warning(
self,
"Delete All Detections",
"This will permanently delete ALL detections from the database.\n\n"
"This action cannot be undone.\n\n"
"Do you want to continue?",
QMessageBox.Yes | QMessageBox.No,
QMessageBox.No,
)
if confirm != QMessageBox.Yes:
return
try:
deleted = self.db_manager.delete_all_detections()
except Exception as exc:
logger.error(f"Failed to delete all detections: {exc}")
QMessageBox.critical(
self,
"Error",
f"Failed to delete detections:\n{exc}",
)
return
QMessageBox.information(
self,
"Delete All Detections",
f"Deleted {deleted} detection(s) from the database.",
)
# Reset UI state.
self.refresh()
def refresh(self):
"""Refresh the tab."""
pass
"""Refresh the detection list and preview."""
self._load_detection_summary()
self._populate_results_table()
self.current_selection = None
self.current_image = None
self.current_detections = []
self.preview_canvas.clear()
self.summary_label.setText("Select a detection result to preview.")
if hasattr(self, "export_labels_btn"):
self.export_labels_btn.setEnabled(False)
def _load_detection_summary(self):
"""Load latest detection summaries grouped by image + model."""
try:
detections = self.db_manager.get_detections(limit=500)
summary_map: Dict[tuple, Dict] = {}
for det in detections:
key = (det["image_id"], det["model_id"])
metadata = det.get("metadata") or {}
entry = summary_map.setdefault(
key,
{
"image_id": det["image_id"],
"model_id": det["model_id"],
"image_path": det.get("image_path"),
"image_filename": det.get("image_filename") or det.get("image_path"),
"model_name": det.get("model_name", ""),
"model_version": det.get("model_version", ""),
"last_detected": det.get("detected_at"),
"count": 0,
"classes": set(),
"source_path": metadata.get("source_path"),
"repository_root": metadata.get("repository_root"),
},
)
entry["count"] += 1
if det.get("detected_at") and (
not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
):
entry["last_detected"] = det.get("detected_at")
if det.get("class_name"):
entry["classes"].add(det["class_name"])
if metadata.get("source_path") and not entry.get("source_path"):
entry["source_path"] = metadata.get("source_path")
if metadata.get("repository_root") and not entry.get("repository_root"):
entry["repository_root"] = metadata.get("repository_root")
self.detection_summary = sorted(
summary_map.values(),
key=lambda x: str(x.get("last_detected") or ""),
reverse=True,
)
except Exception as e:
logger.error(f"Failed to load detection summary: {e}")
QMessageBox.critical(
self,
"Error",
f"Failed to load detection results:\n{str(e)}",
)
self.detection_summary = []
def _populate_results_table(self):
"""Populate the table widget with detection summaries."""
self.results_table.setRowCount(len(self.detection_summary))
for row, entry in enumerate(self.detection_summary):
model_label = f"{entry['model_name']} {entry['model_version']}".strip()
class_list = ", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
items = [
QTableWidgetItem(entry.get("image_filename", "")),
QTableWidgetItem(model_label),
QTableWidgetItem(str(entry.get("count", 0))),
QTableWidgetItem(class_list),
QTableWidgetItem(str(entry.get("last_detected") or "")),
]
for col, item in enumerate(items):
item.setData(Qt.UserRole, row)
self.results_table.setItem(row, col, item)
self.results_table.clearSelection()
def _on_result_selected(self):
"""Handle selection changes in the detection table."""
selected_items = self.results_table.selectedItems()
if not selected_items:
return
row = selected_items[0].data(Qt.UserRole)
if row is None or row >= len(self.detection_summary):
return
entry = self.detection_summary[row]
if (
self.current_selection
and self.current_selection.get("image_id") == entry["image_id"]
and self.current_selection.get("model_id") == entry["model_id"]
):
return
self.current_selection = entry
image_path = self._resolve_image_path(entry)
if not image_path:
QMessageBox.warning(
self,
"Image Not Found",
"Unable to locate the image file for this detection.",
)
return
try:
self.current_image = Image(image_path)
self.preview_canvas.load_image(self.current_image)
except ImageLoadError as e:
logger.error(f"Failed to load image '{image_path}': {e}")
QMessageBox.critical(
self,
"Image Error",
f"Failed to load image for preview:\n{str(e)}",
)
return
self._load_detections_for_selection(entry)
self._apply_detection_overlays()
self._update_summary_label(entry)
if hasattr(self, "export_labels_btn"):
self.export_labels_btn.setEnabled(True)
def _export_labels_for_current_selection(self):
"""Export YOLO label file(s) for the currently selected image/model."""
if not self.current_selection:
QMessageBox.information(self, "Export Labels", "Select a detection result first.")
return
entry = self.current_selection
image_path_str = self._resolve_image_path(entry)
if not image_path_str:
QMessageBox.warning(
self,
"Export Labels",
"Unable to locate the image file for this detection; cannot infer labels path.",
)
return
# Ensure we have the detections for the selection.
if not self.current_detections:
self._load_detections_for_selection(entry)
if not self.current_detections:
QMessageBox.information(
self,
"Export Labels",
"No detections found for this image/model selection.",
)
return
image_path = Path(image_path_str)
try:
label_path = self._infer_yolo_label_path(image_path)
except Exception as exc:
logger.error(f"Failed to infer label path for {image_path}: {exc}")
QMessageBox.critical(
self,
"Export Labels",
f"Failed to infer export path for labels:\n{exc}",
)
return
class_map = self._build_detection_class_index_map(self.current_detections)
if not class_map:
QMessageBox.warning(
self,
"Export Labels",
"Unable to build class->index mapping (missing class names).",
)
return
lines_written = 0
skipped = 0
label_path.parent.mkdir(parents=True, exist_ok=True)
try:
with open(label_path, "w", encoding="utf-8") as handle:
print("writing to", label_path)
for det in self.current_detections:
yolo_line = self._format_detection_as_yolo_line(det, class_map)
if not yolo_line:
skipped += 1
continue
handle.write(yolo_line + "\n")
lines_written += 1
except OSError as exc:
logger.error(f"Failed to write labels file {label_path}: {exc}")
QMessageBox.critical(
self,
"Export Labels",
f"Failed to write label file:\n{label_path}\n\n{exc}",
)
return
return
# Optional: write a classes.txt next to the labels root to make the mapping discoverable.
# This is not required by Ultralytics (data.yaml usually holds class names), but helps reuse.
try:
classes_txt = label_path.parent.parent / "classes.txt"
classes_txt.parent.mkdir(parents=True, exist_ok=True)
inv = {idx: name for name, idx in class_map.items()}
with open(classes_txt, "w", encoding="utf-8") as handle:
for idx in range(len(inv)):
handle.write(f"{inv[idx]}\n")
except Exception:
# Non-fatal
pass
QMessageBox.information(
self,
"Export Labels",
f"Exported {lines_written} label line(s) to:\n{label_path}\n\nSkipped {skipped} invalid detection(s).",
)
def _infer_yolo_label_path(self, image_path: Path) -> Path:
"""Infer a YOLO label path from an image path.
If the image lives under an `images/` directory (anywhere in the path), we mirror the
subpath under a sibling `labels/` directory at the same level.
Example:
/dataset/train/images/sub/img.jpg -> /dataset/train/labels/sub/img.txt
"""
resolved = image_path.expanduser().resolve()
# Find the nearest ancestor directory named 'images'
images_dir: Optional[Path] = None
for parent in [resolved.parent, *resolved.parents]:
if parent.name.lower() == "images":
images_dir = parent
break
if images_dir is not None:
rel = resolved.relative_to(images_dir)
labels_dir = images_dir.parent / "labels"
return (labels_dir / rel).with_suffix(".txt")
# Fallback: create a local sibling labels folder next to the image.
return (resolved.parent / "labels" / resolved.name).with_suffix(".txt")
def _build_detection_class_index_map(self, detections: List[Dict]) -> Dict[str, int]:
"""Build a stable class_name -> YOLO class index mapping.
Preference order:
1) Database object_classes table (alphabetical class_name order)
2) Fallback to class_name values present in the detections (alphabetical)
"""
names: List[str] = []
try:
db_classes = self.db_manager.get_object_classes() or []
names = [str(row.get("class_name")) for row in db_classes if row.get("class_name")]
except Exception:
names = []
if not names:
observed = sorted({str(det.get("class_name")) for det in detections if det.get("class_name")})
names = list(observed)
return {name: idx for idx, name in enumerate(names)}
def _format_detection_as_yolo_line(self, det: Dict, class_map: Dict[str, int]) -> Optional[str]:
"""Convert a detection row to a YOLO label line.
- If segmentation_mask is present, exports segmentation polygon format:
class x1 y1 x2 y2 ...
(normalized coordinates)
- Otherwise exports bbox format:
class x_center y_center width height
(normalized coordinates)
"""
class_name = det.get("class_name")
if not class_name or class_name not in class_map:
return None
class_idx = class_map[class_name]
mask = det.get("segmentation_mask")
polygon = self._convert_segmentation_mask_to_polygon(mask)
if polygon:
coords = " ".join(f"{value:.6f}" for value in polygon)
return f"{class_idx} {coords}".strip()
bbox = self._convert_bbox_to_yolo_xywh(det)
if bbox is None:
return None
x_center, y_center, width, height = bbox
return f"{class_idx} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
def _convert_bbox_to_yolo_xywh(self, det: Dict) -> Optional[Tuple[float, float, float, float]]:
"""Convert stored xyxy (normalized) bbox to YOLO xywh (normalized)."""
x_min = det.get("x_min")
y_min = det.get("y_min")
x_max = det.get("x_max")
y_max = det.get("y_max")
if any(v is None for v in (x_min, y_min, x_max, y_max)):
return None
try:
x_min_f = self._clamp01(float(x_min))
y_min_f = self._clamp01(float(y_min))
x_max_f = self._clamp01(float(x_max))
y_max_f = self._clamp01(float(y_max))
except (TypeError, ValueError):
return None
width = max(0.0, x_max_f - x_min_f)
height = max(0.0, y_max_f - y_min_f)
if width <= 0.0 or height <= 0.0:
return None
x_center = x_min_f + width / 2.0
y_center = y_min_f + height / 2.0
return x_center, y_center, width, height
def _convert_segmentation_mask_to_polygon(self, mask_data) -> List[float]:
"""Convert stored segmentation_mask [[x,y], ...] to YOLO polygon coords [x1,y1,...]."""
if not isinstance(mask_data, list):
return []
coords: List[float] = []
for point in mask_data:
if not isinstance(point, (list, tuple)) or len(point) < 2:
continue
try:
x = self._clamp01(float(point[0]))
y = self._clamp01(float(point[1]))
except (TypeError, ValueError):
continue
coords.extend([x, y])
# Need at least 3 points => 6 values.
return coords if len(coords) >= 6 else []
@staticmethod
def _clamp01(value: float) -> float:
if value < 0.0:
return 0.0
if value > 1.0:
return 1.0
return value
def _load_detections_for_selection(self, entry: Dict):
"""Load detection records for the selected image/model pair."""
self.current_detections = []
if not entry:
return
try:
filters = {"image_id": entry["image_id"], "model_id": entry["model_id"]}
self.current_detections = self.db_manager.get_detections(filters)
except Exception as e:
logger.error(f"Failed to load detections for preview: {e}")
QMessageBox.critical(
self,
"Error",
f"Failed to load detections for this image:\n{str(e)}",
)
self.current_detections = []
def _apply_detection_overlays(self):
"""Draw detections onto the preview canvas based on current toggles."""
self.preview_canvas.clear_annotations()
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
if not self.current_detections or not self.current_image:
return
for det in self.current_detections:
color = self._get_class_color(det.get("class_name"))
if self.show_masks_checkbox.isChecked() and det.get("segmentation_mask"):
mask_points = self._convert_mask(det["segmentation_mask"])
if mask_points:
self.preview_canvas.draw_saved_polyline(mask_points, color)
bbox = [
det.get("x_min"),
det.get("y_min"),
det.get("x_max"),
det.get("y_max"),
]
if all(v is not None for v in bbox):
label = None
if self.show_confidence_checkbox.isChecked():
confidence = det.get("confidence")
if confidence is not None:
label = f"{confidence:.2f}"
self.preview_canvas.draw_saved_bbox(bbox, color, label=label)
def _convert_mask(self, mask_points: List[List[float]]) -> List[List[float]]:
"""Convert stored [x, y] masks to [y, x] format for the canvas."""
converted = []
for point in mask_points:
if len(point) >= 2:
x, y = point[0], point[1]
converted.append([y, x])
return converted
def _toggle_bboxes(self):
"""Update bounding box visibility on the canvas."""
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
# Re-render to respect show/hide when toggled
self._apply_detection_overlays()
def _update_summary_label(self, entry: Dict):
"""Display textual summary for the selected detection run."""
classes = ", ".join(sorted(entry.get("classes", []))) or "-"
summary_text = (
f"Image: {entry.get('image_filename', 'unknown')}\n"
f"Model: {entry.get('model_name', '')} {entry.get('model_version', '')}\n"
f"Detections: {entry.get('count', 0)}\n"
f"Classes: {classes}\n"
f"Last Updated: {entry.get('last_detected', 'n/a')}"
)
self.summary_label.setText(summary_text)
def _resolve_image_path(self, entry: Dict) -> Optional[str]:
"""Resolve an image path using metadata, cache, and repository hints."""
relative_path = entry.get("image_path") if entry else None
cache_key = relative_path or entry.get("source_path")
if cache_key and cache_key in self._image_path_cache:
cached = Path(self._image_path_cache[cache_key])
if cached.exists():
return self._image_path_cache[cache_key]
del self._image_path_cache[cache_key]
candidates = []
source_path = entry.get("source_path") if entry else None
if source_path:
candidates.append(Path(source_path))
repo_roots = []
if entry.get("repository_root"):
repo_roots.append(entry["repository_root"])
config_repo = self.config_manager.get_image_repository_path()
if config_repo:
repo_roots.append(config_repo)
for root in repo_roots:
if relative_path:
candidates.append(Path(root) / relative_path)
if relative_path:
candidates.append(Path(relative_path))
for candidate in candidates:
try:
if candidate and candidate.exists():
resolved = str(candidate.resolve())
if cache_key:
self._image_path_cache[cache_key] = resolved
return resolved
except Exception:
continue
# Fallback: search by filename in known roots
filename = Path(relative_path).name if relative_path else None
if filename:
search_roots = [Path(root) for root in repo_roots if root]
if not search_roots:
search_roots = [Path("data")]
match = self._search_in_roots(filename, search_roots)
if match:
resolved = str(match.resolve())
if cache_key:
self._image_path_cache[cache_key] = resolved
return resolved
return None
def _search_in_roots(self, filename: str, roots: List[Path]) -> Optional[Path]:
"""Search for a file name within a list of root directories."""
for root in roots:
try:
if not root.exists():
continue
for candidate in root.rglob(filename):
return candidate
except Exception as e:
logger.debug(f"Error searching for {filename} in {root}: {e}")
return None
def _get_class_color(self, class_name: Optional[str]) -> str:
"""Return consistent color hex for a class name."""
if not class_name:
return "#FF6B6B"
color_map = self.config_manager.get_bbox_colors()
if class_name in color_map:
return color_map[class_name]
# Deterministic fallback color based on hash
palette = [
"#FF6B6B",
"#4ECDC4",
"#FFD166",
"#1D3557",
"#F4A261",
"#E76F51",
]
return palette[hash(class_name) % len(palette)]

File diff suppressed because it is too large Load Diff

View File

@@ -2,45 +2,554 @@
Validation tab for the microscopy object detection application.
"""
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple
from PySide6.QtCore import Qt, QSize
from PySide6.QtGui import QPainter, QPixmap
from PySide6.QtWidgets import (
QWidget,
QVBoxLayout,
QLabel,
QGroupBox,
QHBoxLayout,
QPushButton,
QComboBox,
QFormLayout,
QScrollArea,
QGridLayout,
QFrame,
QTableWidget,
QTableWidgetItem,
QHeaderView,
QSplitter,
QListWidget,
QListWidgetItem,
QAbstractItemView,
QGraphicsView,
QGraphicsScene,
QGraphicsPixmapItem,
)
from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
logger = get_logger(__name__)
@dataclass(frozen=True)
class _PlotItem:
label: str
path: Path
class _ZoomableImageView(QGraphicsView):
"""Zoomable image viewer.
- Mouse wheel: zoom in/out
- Left mouse drag: pan (ScrollHandDrag)
"""
def __init__(self, parent: Optional[QWidget] = None):
super().__init__(parent)
self._scene = QGraphicsScene(self)
self.setScene(self._scene)
self._pixmap_item = QGraphicsPixmapItem()
self._scene.addItem(self._pixmap_item)
# QGraphicsView render hints are QPainter.RenderHints.
self.setRenderHints(self.renderHints() | QPainter.RenderHint.SmoothPixmapTransform)
self.setDragMode(QGraphicsView.DragMode.ScrollHandDrag)
self.setTransformationAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
self.setResizeAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
self._has_pixmap = False
def clear(self) -> None:
self._pixmap_item.setPixmap(QPixmap())
self._scene.setSceneRect(0, 0, 1, 1)
self.resetTransform()
self._has_pixmap = False
def set_pixmap(self, pixmap: QPixmap, *, fit: bool = True) -> None:
self._pixmap_item.setPixmap(pixmap)
self._scene.setSceneRect(pixmap.rect())
self._has_pixmap = not pixmap.isNull()
self.resetTransform()
if fit and self._has_pixmap:
self.fitInView(self._pixmap_item, Qt.AspectRatioMode.KeepAspectRatio)
def wheelEvent(self, event) -> None: # type: ignore[override]
if not self._has_pixmap:
return
zoom_in_factor = 1.25
zoom_out_factor = 1.0 / zoom_in_factor
factor = zoom_in_factor if event.angleDelta().y() > 0 else zoom_out_factor
self.scale(factor, factor)
class ValidationTab(QWidget):
"""Validation tab placeholder."""
"""Validation tab that shows stored validation metrics + plots for a selected model."""
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
super().__init__(parent)
self.db_manager = db_manager
self.config_manager = config_manager
self._models: List[Dict[str, Any]] = []
self._selected_model_id: Optional[int] = None
self._plot_widgets: List[QWidget] = []
self._plot_items: List[_PlotItem] = []
self._setup_ui()
self.refresh()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
layout = QVBoxLayout(self)
group = QGroupBox("Validation")
group_layout = QVBoxLayout()
label = QLabel(
"Validation functionality will be implemented here.\n\n"
"Features:\n"
"- Model validation\n"
"- Metrics visualization\n"
"- Confusion matrix\n"
"- Precision-Recall curves"
)
group_layout.addWidget(label)
group.setLayout(group_layout)
# ===== Header controls =====
header = QGroupBox("Validation")
header_layout = QVBoxLayout()
header_row = QHBoxLayout()
layout.addWidget(group)
layout.addStretch()
self.setLayout(layout)
header_row.addWidget(QLabel("Select model:"))
self.model_combo = QComboBox()
self.model_combo.setMinimumWidth(420)
self.model_combo.currentIndexChanged.connect(self._on_model_selected)
header_row.addWidget(self.model_combo, 1)
self.refresh_btn = QPushButton("Refresh")
self.refresh_btn.clicked.connect(self.refresh)
header_row.addWidget(self.refresh_btn)
header_row.addStretch()
header_layout.addLayout(header_row)
self.header_status = QLabel("No models loaded.")
self.header_status.setWordWrap(True)
header_layout.addWidget(self.header_status)
header.setLayout(header_layout)
layout.addWidget(header)
# ===== Metrics =====
metrics_group = QGroupBox("Validation Metrics")
metrics_layout = QVBoxLayout()
self.metrics_form = QFormLayout()
self.metric_labels: Dict[str, QLabel] = {}
for key in ("mAP50", "mAP50-95", "precision", "recall", "fitness"):
value_label = QLabel("")
value_label.setTextInteractionFlags(Qt.TextSelectableByMouse)
self.metric_labels[key] = value_label
self.metrics_form.addRow(f"{key}:", value_label)
metrics_layout.addLayout(self.metrics_form)
self.per_class_table = QTableWidget(0, 3)
self.per_class_table.setHorizontalHeaderLabels(["Class", "AP", "AP50"])
self.per_class_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
self.per_class_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeToContents)
self.per_class_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeToContents)
self.per_class_table.setEditTriggers(QTableWidget.NoEditTriggers)
self.per_class_table.setMinimumHeight(160)
metrics_layout.addWidget(QLabel("Per-class metrics (if available):"))
metrics_layout.addWidget(self.per_class_table)
metrics_group.setLayout(metrics_layout)
layout.addWidget(metrics_group)
# ===== Plots =====
plots_group = QGroupBox("Validation Plots")
plots_layout = QVBoxLayout()
self.plots_status = QLabel("Select a model to see validation plots.")
self.plots_status.setWordWrap(True)
plots_layout.addWidget(self.plots_status)
self.plots_splitter = QSplitter(Qt.Orientation.Horizontal)
# Left: selected image viewer
left_widget = QWidget()
left_layout = QVBoxLayout(left_widget)
left_layout.setContentsMargins(0, 0, 0, 0)
self.selected_plot_title = QLabel("No image selected.")
self.selected_plot_title.setWordWrap(True)
self.selected_plot_title.setTextInteractionFlags(Qt.TextSelectableByMouse)
left_layout.addWidget(self.selected_plot_title)
self.plot_view = _ZoomableImageView()
self.plot_view.setMinimumHeight(360)
left_layout.addWidget(self.plot_view, 1)
self.selected_plot_path = QLabel("")
self.selected_plot_path.setWordWrap(True)
self.selected_plot_path.setStyleSheet("color: #888;")
self.selected_plot_path.setTextInteractionFlags(Qt.TextSelectableByMouse)
left_layout.addWidget(self.selected_plot_path)
# Right: scrollable list
right_widget = QWidget()
right_layout = QVBoxLayout(right_widget)
right_layout.setContentsMargins(0, 0, 0, 0)
right_layout.addWidget(QLabel("Images:"))
self.plots_list = QListWidget()
self.plots_list.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection)
self.plots_list.setIconSize(QSize(160, 160))
self.plots_list.itemSelectionChanged.connect(self._on_plot_item_selected)
right_layout.addWidget(self.plots_list, 1)
self.plots_splitter.addWidget(left_widget)
self.plots_splitter.addWidget(right_widget)
self.plots_splitter.setStretchFactor(0, 3)
self.plots_splitter.setStretchFactor(1, 1)
plots_layout.addWidget(self.plots_splitter, 1)
plots_group.setLayout(plots_layout)
layout.addWidget(plots_group, 1)
layout.addStretch(0)
self._clear_metrics()
self._clear_plots()
# ==================== Public API ====================
def refresh(self):
"""Refresh the tab."""
pass
self._load_models()
self._populate_model_combo()
self._restore_or_select_default_model()
# ==================== Internal: models ====================
def _load_models(self) -> None:
try:
self._models = self.db_manager.get_models() or []
except Exception as exc:
logger.error("Failed to load models: %s", exc)
self._models = []
def _populate_model_combo(self) -> None:
self.model_combo.blockSignals(True)
self.model_combo.clear()
self.model_combo.addItem("Select a model…", None)
for model in self._models:
model_id = model.get("id")
name = (model.get("model_name") or "").strip()
version = (model.get("model_version") or "").strip()
created_at = model.get("created_at")
label = f"{name} {version}".strip()
if created_at:
label = f"{label} ({created_at})"
self.model_combo.addItem(label, model_id)
self.model_combo.blockSignals(False)
if self._models:
self.header_status.setText(f"Loaded {len(self._models)} model(s).")
else:
self.header_status.setText("No models found. Train a model first.")
def _restore_or_select_default_model(self) -> None:
if not self._models:
self._selected_model_id = None
self._clear_metrics()
self._clear_plots()
return
# Keep selection if still present.
if self._selected_model_id is not None:
for idx in range(1, self.model_combo.count()):
if self.model_combo.itemData(idx) == self._selected_model_id:
self.model_combo.setCurrentIndex(idx)
return
# Otherwise select the newest model (top of get_models ORDER BY created_at DESC).
first_model_id = self.model_combo.itemData(1) if self.model_combo.count() > 1 else None
if first_model_id is not None:
self.model_combo.setCurrentIndex(1)
def _on_model_selected(self, index: int) -> None:
model_id = self.model_combo.itemData(index)
if not model_id:
self._selected_model_id = None
self._clear_metrics()
self._clear_plots()
self.plots_status.setText("Select a model to see validation plots.")
return
self._selected_model_id = int(model_id)
model = self._get_model_by_id(self._selected_model_id)
if not model:
self._clear_metrics()
self._clear_plots()
self.plots_status.setText("Selected model not found.")
return
self._render_metrics(model)
self._render_plots(model)
def _get_model_by_id(self, model_id: int) -> Optional[Dict[str, Any]]:
for model in self._models:
if model.get("id") == model_id:
return model
try:
return self.db_manager.get_model_by_id(model_id)
except Exception:
return None
# ==================== Internal: metrics ====================
def _clear_metrics(self) -> None:
for label in self.metric_labels.values():
label.setText("")
self.per_class_table.setRowCount(0)
def _render_metrics(self, model: Dict[str, Any]) -> None:
self._clear_metrics()
metrics: Dict[str, Any] = model.get("metrics") or {}
# Training tab stores metrics under results['metrics'] in training results payload.
if isinstance(metrics, dict) and "metrics" in metrics and isinstance(metrics.get("metrics"), dict):
metrics = metrics.get("metrics") or {}
def set_metric(key: str, value: Any) -> None:
if key not in self.metric_labels:
return
if value is None:
self.metric_labels[key].setText("")
return
try:
self.metric_labels[key].setText(f"{float(value):.4f}")
except Exception:
self.metric_labels[key].setText(str(value))
set_metric("mAP50", metrics.get("mAP50"))
set_metric("mAP50-95", metrics.get("mAP50-95") or metrics.get("mAP50_95") or metrics.get("mAP50-95"))
set_metric("precision", metrics.get("precision"))
set_metric("recall", metrics.get("recall"))
set_metric("fitness", metrics.get("fitness"))
# Optional per-class metrics
class_metrics = metrics.get("class_metrics") if isinstance(metrics, dict) else None
if isinstance(class_metrics, dict) and class_metrics:
items = sorted(class_metrics.items(), key=lambda kv: str(kv[0]))
self.per_class_table.setRowCount(len(items))
for row, (cls_name, cls_stats) in enumerate(items):
ap = (cls_stats or {}).get("ap")
ap50 = (cls_stats or {}).get("ap50")
self.per_class_table.setItem(row, 0, QTableWidgetItem(str(cls_name)))
self.per_class_table.setItem(row, 1, QTableWidgetItem(self._format_float(ap)))
self.per_class_table.setItem(row, 2, QTableWidgetItem(self._format_float(ap50)))
else:
self.per_class_table.setRowCount(0)
@staticmethod
def _format_float(value: Any) -> str:
if value is None:
return ""
try:
return f"{float(value):.4f}"
except Exception:
return str(value)
# ==================== Internal: plots ====================
def _clear_plots(self) -> None:
# Remove legacy grid widgets (from the initial implementation).
for widget in self._plot_widgets:
widget.setParent(None)
widget.deleteLater()
self._plot_widgets = []
self._plot_items = []
if hasattr(self, "plots_list"):
self.plots_list.blockSignals(True)
self.plots_list.clear()
self.plots_list.blockSignals(False)
if hasattr(self, "plot_view"):
self.plot_view.clear()
if hasattr(self, "selected_plot_title"):
self.selected_plot_title.setText("No image selected.")
if hasattr(self, "selected_plot_path"):
self.selected_plot_path.setText("")
def _render_plots(self, model: Dict[str, Any]) -> None:
self._clear_plots()
plot_dirs = self._infer_run_directories(model)
plot_items = self._discover_plot_items(plot_dirs)
if not plot_items:
dirs_text = "\n".join(str(p) for p in plot_dirs if p)
self.plots_status.setText(
"No validation plot images found for this model.\n\n"
"Searched directories:\n" + (dirs_text or "(none)")
)
return
self._plot_items = list(plot_items)
self.plots_status.setText(f"Found {len(plot_items)} plot image(s). Select one to view/zoom.")
self.plots_list.blockSignals(True)
self.plots_list.clear()
for idx, item in enumerate(self._plot_items):
qitem = QListWidgetItem(item.label)
qitem.setData(Qt.ItemDataRole.UserRole, idx)
pix = QPixmap(str(item.path))
if not pix.isNull():
thumb = pix.scaled(
self.plots_list.iconSize(),
Qt.AspectRatioMode.KeepAspectRatio,
Qt.TransformationMode.SmoothTransformation,
)
qitem.setIcon(thumb)
self.plots_list.addItem(qitem)
self.plots_list.blockSignals(False)
if self.plots_list.count() > 0:
self.plots_list.setCurrentRow(0)
def _on_plot_item_selected(self) -> None:
if not self._plot_items:
return
selected = self.plots_list.selectedItems()
if not selected:
return
idx = selected[0].data(Qt.ItemDataRole.UserRole)
try:
idx_int = int(idx)
except Exception:
return
if idx_int < 0 or idx_int >= len(self._plot_items):
return
plot = self._plot_items[idx_int]
self.selected_plot_title.setText(plot.label)
self.selected_plot_path.setText(str(plot.path))
pix = QPixmap(str(plot.path))
if pix.isNull():
self.plot_view.clear()
return
self.plot_view.set_pixmap(pix, fit=True)
def _infer_run_directories(self, model: Dict[str, Any]) -> List[Path]:
dirs: List[Path] = []
# 1) Infer from model_path: .../<run>/weights/best.pt -> <run>
model_path = model.get("model_path")
if model_path:
try:
p = Path(str(model_path)).expanduser()
if p.name.lower().endswith(".pt"):
# If it lives under weights/, use parent.parent.
if p.parent.name == "weights" and p.parent.parent.exists():
dirs.append(p.parent.parent)
elif p.parent.exists():
dirs.append(p.parent)
except Exception:
pass
# 2) Look at training_params.stage_results[].results.save_dir
training_params = model.get("training_params") or {}
stage_results = None
if isinstance(training_params, dict):
stage_results = training_params.get("stage_results")
if isinstance(stage_results, list):
for stage in stage_results:
results = (stage or {}).get("results")
save_dir = (results or {}).get("save_dir") if isinstance(results, dict) else None
if save_dir:
try:
save_path = Path(str(save_dir)).expanduser()
if save_path.exists():
dirs.append(save_path)
except Exception:
continue
# Deduplicate while preserving order.
unique: List[Path] = []
seen: set[str] = set()
for d in dirs:
try:
resolved = str(d.resolve())
except Exception:
resolved = str(d)
if resolved not in seen and d.exists() and d.is_dir():
seen.add(resolved)
unique.append(d)
return unique
def _discover_plot_items(self, directories: Sequence[Path]) -> List[_PlotItem]:
# Prefer canonical Ultralytics filenames first, then fall back to any png/jpg.
preferred_names = [
"results.png",
"results.jpg",
"confusion_matrix.png",
"confusion_matrix_normalized.png",
"labels.jpg",
"labels.png",
"BoxPR_curve.png",
"BoxP_curve.png",
"BoxR_curve.png",
"BoxF1_curve.png",
"MaskPR_curve.png",
"MaskP_curve.png",
"MaskR_curve.png",
"MaskF1_curve.png",
"val_batch0_pred.jpg",
"val_batch0_labels.jpg",
]
found: List[_PlotItem] = []
seen: set[str] = set()
for d in directories:
# 1) Preferred
for name in preferred_names:
p = d / name
if p.exists() and p.is_file():
key = str(p)
if key in seen:
continue
seen.add(key)
found.append(_PlotItem(label=f"{name} (from {d.name})", path=p))
# 2) Curated globs
for pattern in ("train_batch*.jpg", "val_batch*.jpg", "*curve*.png"):
for p in sorted(d.glob(pattern)):
if not p.is_file():
continue
key = str(p)
if key in seen:
continue
seen.add(key)
found.append(_PlotItem(label=f"{p.name} (from {d.name})", path=p))
# 3) Fallback: any top-level png/jpg (excluding weights dir contents)
for ext in ("*.png", "*.jpg", "*.jpeg", "*.webp"):
for p in sorted(d.glob(ext)):
if not p.is_file():
continue
key = str(p)
if key in seen:
continue
seen.add(key)
found.append(_PlotItem(label=f"{p.name} (from {d.name})", path=p))
# Keep list bounded to avoid UI overload for huge runs.
return found[:60]

View File

@@ -0,0 +1,7 @@
"""GUI widgets for the microscopy object detection application."""
from src.gui.widgets.image_display_widget import ImageDisplayWidget
from src.gui.widgets.annotation_canvas_widget import AnnotationCanvasWidget
from src.gui.widgets.annotation_tools_widget import AnnotationToolsWidget
__all__ = ["ImageDisplayWidget", "AnnotationCanvasWidget", "AnnotationToolsWidget"]

View File

@@ -0,0 +1,930 @@
"""
Annotation canvas widget for drawing annotations on images.
Currently supports polyline drawing tool with color selection for manual annotation.
"""
import numpy as np
import math
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea
from PySide6.QtGui import (
QPixmap,
QImage,
QPainter,
QPen,
QColor,
QKeyEvent,
QMouseEvent,
QPaintEvent,
QPolygonF,
)
from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect, QTimer
from typing import Any, Dict, List, Optional, Tuple
from src.utils.image import Image, ImageLoadError
from src.utils.logger import get_logger
logger = get_logger(__name__)
def perpendicular_distance(
point: Tuple[float, float],
start: Tuple[float, float],
end: Tuple[float, float],
) -> float:
"""Perpendicular distance from `point` to the line defined by `start`->`end`."""
(x, y), (x1, y1), (x2, y2) = point, start, end
dx = x2 - x1
dy = y2 - y1
if dx == 0.0 and dy == 0.0:
return math.hypot(x - x1, y - y1)
num = abs(dy * x - dx * y + x2 * y1 - y2 * x1)
den = math.hypot(dx, dy)
return num / den
def rdp(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float, float]]:
"""
Recursive Ramer-Douglas-Peucker (RDP) polyline simplification.
Args:
points: List of (x, y) points.
epsilon: Maximum allowed perpendicular distance in pixels.
Returns:
Simplified list of (x, y) points including first and last.
"""
if len(points) <= 2:
return list(points)
start = points[0]
end = points[-1]
max_dist = -1.0
index = -1
for i in range(1, len(points) - 1):
d = perpendicular_distance(points[i], start, end)
if d > max_dist:
max_dist = d
index = i
if max_dist > epsilon:
# Recursive split
left = rdp(points[: index + 1], epsilon)
right = rdp(points[index:], epsilon)
# Concatenate but avoid duplicate at split point
return left[:-1] + right
# Keep only start and end
return [start, end]
def simplify_polyline(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float, float]]:
"""
Simplify a polyline with RDP while preserving closure semantics.
If the polyline is closed (first == last), the duplicate last point is removed
before simplification and then re-added after simplification.
"""
if not points:
return []
pts = [(float(x), float(y)) for x, y in points]
closed = False
if len(pts) >= 2 and pts[0] == pts[-1]:
closed = True
pts = pts[:-1] # remove duplicate last for simplification
if len(pts) <= 2:
simplified = list(pts)
else:
simplified = rdp(pts, epsilon)
if closed and simplified:
if simplified[0] != simplified[-1]:
simplified.append(simplified[0])
return simplified
class AnnotationCanvasWidget(QWidget):
"""
Widget for displaying images and drawing annotations with zoom and drawing tools.
Features:
- Display images with zoom functionality
- Polyline tool for drawing annotations
- Configurable pen color and width
- Mouse-based drawing interface
- Zoom in/out with mouse wheel and keyboard
Signals:
zoom_changed: Emitted when zoom level changes (float zoom_scale)
annotation_drawn: Emitted when a new stroke is completed (list of points)
"""
zoom_changed = Signal(float)
annotation_drawn = Signal(list) # List of (x, y) points in normalized coordinates
# Emitted when the user selects an existing polyline on the canvas.
# Carries the associated annotation_id (int) or None if selection is cleared
annotation_selected = Signal(object)
def __init__(self, parent=None):
"""Initialize the annotation canvas widget."""
super().__init__(parent)
self.current_image = None
self.original_pixmap = None
self.annotation_pixmap = None # Overlay for annotations
self.zoom_scale = 1.0
self.zoom_min = 0.1
self.zoom_max = 10.0
self.zoom_step = 0.1
self.zoom_wheel_step = 0.15
# Auto-fit behavior (opt-in): when enabled, newly loaded images (and resizes)
# will scale to fill the available viewport while preserving aspect ratio.
self._auto_fit_to_view: bool = False
# Drawing / interaction state
self.is_drawing = False
self.polyline_enabled = False
self.polyline_pen_color = QColor(255, 0, 0, 128) # Default red with 50% alpha
self.polyline_pen_width = 3
self.show_bboxes: bool = True # Control visibility of bounding boxes
# Current stroke and stored polylines (in image coordinates, pixel units)
self.current_stroke: List[Tuple[float, float]] = []
self.polylines: List[List[Tuple[float, float]]] = []
self.stroke_meta: List[Dict[str, Any]] = [] # per-polyline style (color, width)
# Optional DB annotation_id for each stored polyline (None for temporary / unsaved)
self.polyline_annotation_ids: List[Optional[int]] = []
# Indices in self.polylines of the currently selected polylines (multi-select)
self.selected_polyline_indices: List[int] = []
# Stored bounding boxes in normalized coordinates (x_min, y_min, x_max, y_max)
self.bboxes: List[List[float]] = []
self.bbox_meta: List[Dict[str, Any]] = [] # per-bbox style (color, width)
# Legacy collection of strokes in normalized coordinates (kept for API compatibility)
self.all_strokes: List[dict] = []
# RDP simplification parameters (in pixels)
self.simplify_on_finish: bool = True
self.simplify_epsilon: float = 2.0
self.sample_threshold: float = 2.0 # minimum movement to sample a new point
self._setup_ui()
def set_auto_fit_to_view(self, enabled: bool):
"""Enable/disable automatic zoom-to-fit behavior."""
self._auto_fit_to_view = bool(enabled)
if self._auto_fit_to_view and self.original_pixmap is not None:
QTimer.singleShot(0, self.fit_to_view)
def fit_to_view(self, padding_px: int = 6):
"""Zoom the image so it fits the scroll area's viewport (aspect preserved)."""
if self.original_pixmap is None:
return
viewport = self.scroll_area.viewport().size()
available_w = max(1, int(viewport.width()) - int(padding_px))
available_h = max(1, int(viewport.height()) - int(padding_px))
img_w = max(1, int(self.original_pixmap.width()))
img_h = max(1, int(self.original_pixmap.height()))
scale_w = available_w / img_w
scale_h = available_h / img_h
new_scale = min(scale_w, scale_h)
new_scale = max(self.zoom_min, min(self.zoom_max, float(new_scale)))
if abs(new_scale - self.zoom_scale) < 1e-4:
return
self.zoom_scale = new_scale
self._apply_zoom()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
layout.setContentsMargins(0, 0, 0, 0)
# Scroll area for canvas
self.scroll_area = QScrollArea()
self.scroll_area.setWidgetResizable(True)
self.scroll_area.setMinimumHeight(400)
self.canvas_label = QLabel("No image loaded")
self.canvas_label.setAlignment(Qt.AlignCenter)
self.canvas_label.setStyleSheet("QLabel { background-color: #2b2b2b; color: #888; }")
self.canvas_label.setScaledContents(False)
self.canvas_label.setMouseTracking(True)
self.scroll_area.setWidget(self.canvas_label)
self.scroll_area.viewport().installEventFilter(self)
layout.addWidget(self.scroll_area)
self.setLayout(layout)
self.setFocusPolicy(Qt.StrongFocus)
def load_image(self, image: Image):
"""
Load and display an image.
Args:
image: Image object to display
"""
self.current_image = image
self.zoom_scale = 1.0
self.clear_annotations()
self._display_image()
# Defer fit-to-view until the widget has a valid viewport size.
if self._auto_fit_to_view:
QTimer.singleShot(0, self.fit_to_view)
logger.debug(f"Loaded image into annotation canvas: {image.width}x{image.height}")
def resizeEvent(self, event):
"""Optionally keep the image fitted when the widget is resized."""
super().resizeEvent(event)
if self._auto_fit_to_view and self.original_pixmap is not None:
QTimer.singleShot(0, self.fit_to_view)
def clear(self):
"""Clear the displayed image and all annotations."""
self.current_image = None
self.original_pixmap = None
self.annotation_pixmap = None
self.zoom_scale = 1.0
self.clear_annotations()
self.canvas_label.setText("No image loaded")
self.canvas_label.setPixmap(QPixmap())
def clear_annotations(self):
"""Clear all drawn annotations."""
self.all_strokes = []
self.current_stroke = []
self.polylines = []
self.stroke_meta = []
self.polyline_annotation_ids = []
self.selected_polyline_indices = []
self.bboxes = []
self.bbox_meta = []
self.is_drawing = False
if self.annotation_pixmap:
self.annotation_pixmap.fill(Qt.transparent)
self._update_display()
def _display_image(self):
"""Display the current image in the canvas."""
if self.current_image is None:
return
try:
# Get image data in a format compatible with Qt
if self.current_image.channels in (3, 4):
image_data = self.current_image.get_rgb()
else:
image_data = self.current_image.get_qt_rgb()
height, width = image_data.shape[:2]
bytes_per_line = image_data.strides[0]
qimage = QImage(
image_data.data,
width,
height,
bytes_per_line,
QImage.Format_RGBX32FPx4, # self.current_image.qtimage_format,
).copy() # Copy so Qt owns the buffer even after numpy array goes out of scope
self.original_pixmap = QPixmap.fromImage(qimage)
# Create transparent overlay for annotations
self.annotation_pixmap = QPixmap(self.original_pixmap.size())
self.annotation_pixmap.fill(Qt.transparent)
self._apply_zoom()
except Exception as e:
logger.error(f"Error displaying image: {e}")
raise ImageLoadError(f"Failed to display image: {str(e)}")
def _apply_zoom(self):
"""Apply current zoom level to the displayed image."""
if self.original_pixmap is None:
return
scaled_width = int(self.original_pixmap.width() * self.zoom_scale)
scaled_height = int(self.original_pixmap.height() * self.zoom_scale)
# Scale both image and annotations
scaled_image = self.original_pixmap.scaled(
scaled_width,
scaled_height,
Qt.KeepAspectRatio,
(Qt.SmoothTransformation if self.zoom_scale >= 1.0 else Qt.FastTransformation),
)
scaled_annotations = self.annotation_pixmap.scaled(
scaled_width,
scaled_height,
Qt.KeepAspectRatio,
(Qt.SmoothTransformation if self.zoom_scale >= 1.0 else Qt.FastTransformation),
)
# Composite image and annotations
combined = QPixmap(scaled_image.size())
painter = QPainter(combined)
painter.drawPixmap(0, 0, scaled_image)
painter.drawPixmap(0, 0, scaled_annotations)
painter.end()
self.canvas_label.setPixmap(combined)
self.canvas_label.setScaledContents(False)
self.canvas_label.adjustSize()
self.zoom_changed.emit(self.zoom_scale)
def _update_display(self):
"""Update display after drawing."""
self._apply_zoom()
def set_polyline_enabled(self, enabled: bool):
"""Enable or disable polyline tool."""
self.polyline_enabled = enabled
if enabled:
self.canvas_label.setCursor(Qt.CrossCursor)
else:
self.canvas_label.setCursor(Qt.ArrowCursor)
def set_polyline_pen_color(self, color: QColor):
"""Set polyline pen color."""
self.polyline_pen_color = color
def set_polyline_pen_width(self, width: int):
"""Set polyline pen width."""
self.polyline_pen_width = max(1, width)
def get_zoom_percentage(self) -> int:
"""Get current zoom level as percentage."""
return int(self.zoom_scale * 100)
def zoom_in(self):
"""Zoom in on the image."""
if self.original_pixmap is None:
return
new_scale = self.zoom_scale + self.zoom_step
if new_scale <= self.zoom_max:
self.zoom_scale = new_scale
self._apply_zoom()
def zoom_out(self):
"""Zoom out from the image."""
if self.original_pixmap is None:
return
new_scale = self.zoom_scale - self.zoom_step
if new_scale >= self.zoom_min:
self.zoom_scale = new_scale
self._apply_zoom()
def reset_zoom(self):
"""Reset zoom to 100%."""
if self.original_pixmap is None:
return
self.zoom_scale = 1.0
self._apply_zoom()
def _canvas_to_image_coords(self, pos: QPoint) -> Optional[Tuple[int, int]]:
"""Convert canvas coordinates to image coordinates, accounting for zoom and centering."""
if self.original_pixmap is None or self.canvas_label.pixmap() is None:
return None
# Get the displayed pixmap size (after zoom)
displayed_pixmap = self.canvas_label.pixmap()
displayed_width = displayed_pixmap.width()
displayed_height = displayed_pixmap.height()
# Calculate offset due to label centering (label might be larger than pixmap)
label_width = self.canvas_label.width()
label_height = self.canvas_label.height()
offset_x = max(0, (label_width - displayed_width) // 2)
offset_y = max(0, (label_height - displayed_height) // 2)
# Adjust position for offset and convert to image coordinates
x = (pos.x() - offset_x) / self.zoom_scale
y = (pos.y() - offset_y) / self.zoom_scale
# Check bounds
if 0 <= x < self.original_pixmap.width() and 0 <= y < self.original_pixmap.height():
return (int(x), int(y))
return None
def _find_polyline_at(self, img_x: float, img_y: float, threshold_px: float = 5.0) -> Optional[int]:
"""
Find index of polyline whose geometry is within threshold_px of (img_x, img_y).
Returns the index in self.polylines, or None if none is close enough.
"""
best_index: Optional[int] = None
best_dist: float = float("inf")
for idx, polyline in enumerate(self.polylines):
if len(polyline) < 2:
continue
# Quick bounding-box check to skip obviously distant polylines
xs = [p[0] for p in polyline]
ys = [p[1] for p in polyline]
if img_x < min(xs) - threshold_px or img_x > max(xs) + threshold_px:
continue
if img_y < min(ys) - threshold_px or img_y > max(ys) + threshold_px:
continue
# Precise distance to all segments
for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]):
d = perpendicular_distance((img_x, img_y), (float(x1), float(y1)), (float(x2), float(y2)))
if d < best_dist:
best_dist = d
best_index = idx
if best_index is not None and best_dist <= threshold_px:
return best_index
return None
def _image_to_normalized_coords(self, x: int, y: int) -> Tuple[float, float]:
"""Convert image coordinates to normalized coordinates (0-1)."""
if self.original_pixmap is None:
return (0.0, 0.0)
norm_x = x / self.original_pixmap.width()
norm_y = y / self.original_pixmap.height()
return (norm_x, norm_y)
def _add_polyline(
self,
img_points: List[Tuple[float, float]],
color: QColor,
width: int,
annotation_id: Optional[int] = None,
):
"""Store a polyline in image coordinates and redraw annotations."""
if not img_points or len(img_points) < 2:
return
# Ensure all points are tuples of floats
normalized_points = [(float(x), float(y)) for x, y in img_points]
self.polylines.append(normalized_points)
self.stroke_meta.append({"color": QColor(color), "width": int(width)})
self.polyline_annotation_ids.append(annotation_id)
self._redraw_annotations()
def _redraw_annotations(self):
"""Redraw all stored polylines and (optionally) bounding boxes onto the annotation pixmap."""
if self.annotation_pixmap is None:
return
# Clear existing overlay
self.annotation_pixmap.fill(Qt.transparent)
painter = QPainter(self.annotation_pixmap)
# Draw polylines
for idx, (polyline, meta) in enumerate(zip(self.polylines, self.stroke_meta)):
pen_color: QColor = meta.get("color", self.polyline_pen_color)
width: int = meta.get("width", self.polyline_pen_width)
if idx in self.selected_polyline_indices:
# Highlight selected polylines in a distinct color / width
highlight_color = QColor(255, 255, 0, 200) # yellow, semi-opaque
pen = QPen(
highlight_color,
width + 1,
Qt.SolidLine,
Qt.RoundCap,
Qt.RoundJoin,
)
else:
pen = QPen(
pen_color,
width,
Qt.SolidLine,
Qt.RoundCap,
Qt.RoundJoin,
)
painter.setPen(pen)
# Use QPolygonF for efficient polygon rendering (single call vs N-1 calls)
# drawPolygon() automatically closes the shape, ensuring proper visual closure
polygon = QPolygonF([QPointF(x, y) for x, y in polyline])
painter.drawPolygon(polygon)
# Draw bounding boxes (dashed) if enabled
if self.show_bboxes and self.original_pixmap is not None and self.bboxes:
img_width = float(self.original_pixmap.width())
img_height = float(self.original_pixmap.height())
for bbox, meta in zip(self.bboxes, self.bbox_meta):
if len(bbox) != 4:
continue
x_min_norm, y_min_norm, x_max_norm, y_max_norm = bbox
x_min = int(x_min_norm * img_width)
y_min = int(y_min_norm * img_height)
x_max = int(x_max_norm * img_width)
y_max = int(y_max_norm * img_height)
rect_width = x_max - x_min
rect_height = y_max - y_min
pen_color: QColor = meta.get("color", QColor(255, 0, 0, 128))
width: int = meta.get("width", self.polyline_pen_width)
pen = QPen(
pen_color,
width,
Qt.DashLine,
Qt.SquareCap,
Qt.MiterJoin,
)
painter.setPen(pen)
painter.drawRect(x_min, y_min, rect_width, rect_height)
label_text = meta.get("label")
if label_text:
painter.save()
font = painter.font()
font.setPointSizeF(max(10.0, width + 4))
painter.setFont(font)
metrics = painter.fontMetrics()
text_width = metrics.horizontalAdvance(label_text)
text_height = metrics.height()
padding = 4
bg_width = text_width + padding * 2
bg_height = text_height + padding * 2
canvas_width = self.original_pixmap.width()
canvas_height = self.original_pixmap.height()
bg_x = max(0, min(x_min, canvas_width - bg_width))
bg_y = y_min - bg_height
if bg_y < 0:
bg_y = min(y_min, canvas_height - bg_height)
bg_y = max(0, bg_y)
background_rect = QRect(bg_x, bg_y, bg_width, bg_height)
background_color = QColor(pen_color)
background_color.setAlpha(220)
painter.fillRect(background_rect, background_color)
text_color = QColor(0, 0, 0)
if background_color.lightness() < 128:
text_color = QColor(255, 255, 255)
painter.setPen(text_color)
painter.drawText(
background_rect.adjusted(padding, padding, -padding, -padding),
Qt.AlignLeft | Qt.AlignVCenter,
label_text,
)
painter.restore()
painter.end()
self._update_display()
def mousePressEvent(self, event: QMouseEvent):
"""Handle mouse press events for drawing and selecting polylines."""
if self.annotation_pixmap is None:
super().mousePressEvent(event)
return
# Map click to image coordinates
label_pos = self.canvas_label.mapFromGlobal(event.globalPos())
img_coords = self._canvas_to_image_coords(label_pos)
# Left button + drawing tool enabled -> start a new stroke
if event.button() == Qt.LeftButton and self.polyline_enabled:
if img_coords:
self.is_drawing = True
self.current_stroke = [(float(img_coords[0]), float(img_coords[1]))]
return
# Left button + drawing tool disabled -> attempt selection of existing polyline
if event.button() == Qt.LeftButton and not self.polyline_enabled:
if img_coords:
idx = self._find_polyline_at(float(img_coords[0]), float(img_coords[1]))
if idx is not None:
if event.modifiers() & Qt.ShiftModifier:
# Multi-select mode: add to current selection (if not already selected)
if idx not in self.selected_polyline_indices:
self.selected_polyline_indices.append(idx)
else:
# Single-select mode: replace current selection
self.selected_polyline_indices = [idx]
# Build list of selected annotation IDs (ignore None entries)
selected_ids: List[int] = []
for sel_idx in self.selected_polyline_indices:
if 0 <= sel_idx < len(self.polyline_annotation_ids):
ann_id = self.polyline_annotation_ids[sel_idx]
if isinstance(ann_id, int):
selected_ids.append(ann_id)
if selected_ids:
self.annotation_selected.emit(selected_ids)
else:
# No valid DB-backed annotations in selection
self.annotation_selected.emit(None)
else:
# Clicked on empty space -> clear selection
self.selected_polyline_indices = []
self.annotation_selected.emit(None)
self._redraw_annotations()
return
# Fallback for other buttons / cases
super().mousePressEvent(event)
def mouseMoveEvent(self, event: QMouseEvent):
"""Handle mouse move events for drawing."""
if not self.is_drawing or not self.polyline_enabled or self.annotation_pixmap is None:
super().mouseMoveEvent(event)
return
# Get accurate position using global coordinates
label_pos = self.canvas_label.mapFromGlobal(event.globalPos())
img_coords = self._canvas_to_image_coords(label_pos)
if img_coords and len(self.current_stroke) > 0:
last_point = self.current_stroke[-1]
dx = img_coords[0] - last_point[0]
dy = img_coords[1] - last_point[1]
# Only sample a new point if we moved enough pixels
if math.hypot(dx, dy) < self.sample_threshold:
return
# Draw line from last point to current point for interactive feedback
painter = QPainter(self.annotation_pixmap)
pen = QPen(
self.polyline_pen_color,
self.polyline_pen_width,
Qt.SolidLine,
Qt.RoundCap,
Qt.RoundJoin,
)
painter.setPen(pen)
painter.drawLine(
int(last_point[0]),
int(last_point[1]),
int(img_coords[0]),
int(img_coords[1]),
)
painter.end()
self.current_stroke.append((float(img_coords[0]), float(img_coords[1])))
self._update_display()
def mouseReleaseEvent(self, event: QMouseEvent):
"""Handle mouse release events to complete a stroke."""
if not self.is_drawing or event.button() != Qt.LeftButton:
super().mouseReleaseEvent(event)
return
self.is_drawing = False
if len(self.current_stroke) > 1 and self.original_pixmap is not None:
# Ensure the stroke is closed by connecting end -> start
raw_points = list(self.current_stroke)
if raw_points[0] != raw_points[-1]:
raw_points.append(raw_points[0])
# Optional RDP simplification (in image pixel space)
if self.simplify_on_finish:
simplified = simplify_polyline(raw_points, self.simplify_epsilon)
else:
simplified = raw_points
if len(simplified) >= 2:
# Store polyline and redraw all annotations
self._add_polyline(simplified, self.polyline_pen_color, self.polyline_pen_width)
# Convert to normalized coordinates for metadata + signal
normalized_stroke = [self._image_to_normalized_coords(int(x), int(y)) for (x, y) in simplified]
self.all_strokes.append(
{
"points": normalized_stroke,
"color": self.polyline_pen_color.name(),
"alpha": self.polyline_pen_color.alpha(),
"width": self.polyline_pen_width,
}
)
# Emit signal with normalized coordinates
self.annotation_drawn.emit(normalized_stroke)
logger.debug(
f"Completed stroke with {len(simplified)} points " f"(normalized len={len(normalized_stroke)})"
)
self.current_stroke = []
def get_all_strokes(self) -> List[dict]:
"""Get all drawn strokes with metadata."""
return self.all_strokes
def get_annotation_parameters(self) -> Optional[List[Dict[str, Any]]]:
"""
Get all annotation parameters including bounding box and polyline.
Returns:
List of dictionaries, each containing:
- 'bbox': [x_min, y_min, x_max, y_max] in normalized image coordinates
- 'polyline': List of [y_norm, x_norm] points describing the polygon
"""
if self.original_pixmap is None or not self.polylines:
return None
img_width = float(self.original_pixmap.width())
img_height = float(self.original_pixmap.height())
params: List[Dict[str, Any]] = []
for idx, polyline in enumerate(self.polylines):
if len(polyline) < 2:
continue
xs = [p[0] for p in polyline]
ys = [p[1] for p in polyline]
x_min_norm = min(xs) / img_width
x_max_norm = max(xs) / img_width
y_min_norm = min(ys) / img_height
y_max_norm = max(ys) / img_height
# Store polyline as [y_norm, x_norm] to match DB convention and
# the expectations of draw_saved_polyline().
normalized_polyline = [[y / img_height, x / img_width] for (x, y) in polyline]
logger.debug(
f"Polyline {idx}: {len(polyline)} points, "
f"bbox=({x_min_norm:.3f}, {y_min_norm:.3f})-({x_max_norm:.3f}, {y_max_norm:.3f})"
)
params.append(
{
"bbox": [x_min_norm, y_min_norm, x_max_norm, y_max_norm],
"polyline": normalized_polyline,
}
)
return params or None
def draw_saved_polyline(
self,
polyline: List[List[float]],
color: str,
width: int = 1,
annotation_id: Optional[int] = None,
):
"""
Draw a polyline from database coordinates onto the annotation canvas.
Args:
polyline: List of [x, y] coordinate pairs in normalized coordinates (0-1)
color: Color hex string (e.g., '#FF0000')
width: Line width in pixels
"""
if not self.annotation_pixmap or not self.original_pixmap:
logger.warning("Cannot draw polyline: no image loaded")
return
if len(polyline) < 2:
logger.warning("Polyline has less than 2 points, cannot draw")
return
# Convert normalized coordinates to image coordinates
# Polyline is stored as [[y_norm, x_norm], ...] (row_norm, col_norm format)
img_width = self.original_pixmap.width()
img_height = self.original_pixmap.height()
logger.debug(f"Loading polyline with {len(polyline)} points")
logger.debug(f" Image size: {img_width}x{img_height}")
logger.debug(f" First 3 normalized points from DB: {polyline[:3]}")
img_coords: List[Tuple[float, float]] = []
for y_norm, x_norm in polyline:
x = float(x_norm * img_width)
y = float(y_norm * img_height)
img_coords.append((x, y))
logger.debug(f" First 3 pixel coords: {img_coords[:3]}")
# Store and redraw using common pipeline
pen_color = QColor(color)
pen_color.setAlpha(255) # Add semi-transparency
self._add_polyline(img_coords, pen_color, width, annotation_id=annotation_id)
# Store in all_strokes for consistency (uses normalized coordinates)
self.all_strokes.append({"points": polyline, "color": color, "alpha": 255, "width": width})
logger.debug(f"Drew saved polyline with {len(polyline)} points in color {color}")
def draw_saved_bbox(
self,
bbox: List[float],
color: str,
width: int = 3,
label: Optional[str] = None,
):
"""
Draw a bounding box from database coordinates onto the annotation canvas.
Args:
bbox: Bounding box as [x_min_norm, y_min_norm, x_max_norm, y_max_norm]
in normalized coordinates (0-1)
color: Color hex string (e.g., '#FF0000')
width: Line width in pixels
label: Optional text label to render near the bounding box
"""
if not self.annotation_pixmap or not self.original_pixmap:
logger.warning("Cannot draw bounding box: no image loaded")
return
if len(bbox) != 4:
logger.warning(f"Invalid bounding box format: expected 4 values, got {len(bbox)}")
return
# Convert normalized coordinates to image coordinates (for logging/debug)
img_width = self.original_pixmap.width()
img_height = self.original_pixmap.height()
x_min_norm, y_min_norm, x_max_norm, y_max_norm = bbox
x_min = int(x_min_norm * img_width)
y_min = int(y_min_norm * img_height)
x_max = int(x_max_norm * img_width)
y_max = int(y_max_norm * img_height)
logger.debug(f"Drawing bounding box: {bbox}")
logger.debug(f" Image size: {img_width}x{img_height}")
logger.debug(f" Pixel coords: ({x_min}, {y_min}) to ({x_max}, {y_max})")
# Store bounding box (normalized) and its style; actual drawing happens
# in _redraw_annotations() together with all polylines.
pen_color = QColor(color)
pen_color.setAlpha(128) # Add semi-transparency
self.bboxes.append([float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)])
self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label})
# Store in all_strokes for consistency
self.all_strokes.append({"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label})
# Redraw overlay (polylines + all bounding boxes)
self._redraw_annotations()
logger.debug(f"Drew saved bounding box in color {color}")
def set_show_bboxes(self, show: bool):
"""
Enable or disable drawing of bounding boxes.
Args:
show: If True, draw bounding boxes; if False, hide them.
"""
self.show_bboxes = bool(show)
logger.debug(f"Set show_bboxes to {self.show_bboxes}")
self._redraw_annotations()
def keyPressEvent(self, event: QKeyEvent):
"""Handle keyboard events for zooming."""
if event.key() in (Qt.Key_Plus, Qt.Key_Equal):
self.zoom_in()
event.accept()
elif event.key() == Qt.Key_Minus:
self.zoom_out()
event.accept()
elif event.key() == Qt.Key_0 and event.modifiers() == Qt.ControlModifier:
self.reset_zoom()
event.accept()
else:
super().keyPressEvent(event)
def eventFilter(self, obj, event: QEvent) -> bool:
"""Event filter to capture wheel events for zooming."""
if event.type() == QEvent.Wheel:
wheel_event = event
if self.original_pixmap is not None:
delta = wheel_event.angleDelta().y()
if delta > 0:
new_scale = self.zoom_scale + self.zoom_wheel_step
if new_scale <= self.zoom_max:
self.zoom_scale = new_scale
self._apply_zoom()
else:
new_scale = self.zoom_scale - self.zoom_wheel_step
if new_scale >= self.zoom_min:
self.zoom_scale = new_scale
self._apply_zoom()
return True
return super().eventFilter(obj, event)

View File

@@ -0,0 +1,478 @@
"""
Annotation tools widget for controlling annotation parameters.
Includes polyline tool, color picker, class selection, and annotation management.
"""
from PySide6.QtWidgets import (
QWidget,
QVBoxLayout,
QHBoxLayout,
QLabel,
QGroupBox,
QPushButton,
QComboBox,
QSpinBox,
QDoubleSpinBox,
QCheckBox,
QColorDialog,
QInputDialog,
QMessageBox,
)
from PySide6.QtGui import QColor, QIcon, QPixmap, QPainter
from PySide6.QtCore import Qt, Signal
from typing import Optional, Dict
from src.database.db_manager import DatabaseManager
from src.utils.logger import get_logger
logger = get_logger(__name__)
class AnnotationToolsWidget(QWidget):
"""
Widget for annotation tool controls.
Features:
- Enable/disable polyline tool
- Color selection for polyline pen
- Object class selection
- Add new object classes
- Pen width control
- Clear annotations
Signals:
polyline_enabled_changed: Emitted when polyline tool is enabled/disabled (bool)
polyline_pen_color_changed: Emitted when polyline pen color changes (QColor)
polyline_pen_width_changed: Emitted when polyline pen width changes (int)
class_selected: Emitted when object class is selected (dict)
clear_annotations_requested: Emitted when clear button is pressed
"""
polyline_enabled_changed = Signal(bool)
polyline_pen_color_changed = Signal(QColor)
polyline_pen_width_changed = Signal(int)
simplify_on_finish_changed = Signal(bool)
simplify_epsilon_changed = Signal(float)
# Toggle visibility of bounding boxes on the canvas
show_bboxes_changed = Signal(bool)
class_selected = Signal(dict)
class_color_changed = Signal()
clear_annotations_requested = Signal()
# Request deletion of the currently selected annotation on the canvas
delete_selected_annotation_requested = Signal()
def __init__(self, db_manager: DatabaseManager, parent=None):
"""
Initialize annotation tools widget.
Args:
db_manager: Database manager instance
parent: Parent widget
"""
super().__init__(parent)
self.db_manager = db_manager
self.polyline_enabled = False
self.current_color = QColor(255, 0, 0, 128) # Red with 50% alpha
self.current_class = None
self._setup_ui()
self._load_object_classes()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
# Polyline Tool Group
polyline_group = QGroupBox("Polyline Tool")
polyline_layout = QVBoxLayout()
# Enable/Disable polyline tool
button_layout = QHBoxLayout()
self.polyline_toggle_btn = QPushButton("Start Drawing Polyline")
self.polyline_toggle_btn.setCheckable(True)
self.polyline_toggle_btn.clicked.connect(self._on_polyline_toggle)
button_layout.addWidget(self.polyline_toggle_btn)
polyline_layout.addLayout(button_layout)
# Polyline pen width control
width_layout = QHBoxLayout()
width_layout.addWidget(QLabel("Pen Width:"))
self.polyline_pen_width_spin = QSpinBox()
self.polyline_pen_width_spin.setMinimum(1)
self.polyline_pen_width_spin.setMaximum(20)
self.polyline_pen_width_spin.setValue(3)
self.polyline_pen_width_spin.valueChanged.connect(
self._on_polyline_pen_width_changed
)
width_layout.addWidget(self.polyline_pen_width_spin)
width_layout.addStretch()
polyline_layout.addLayout(width_layout)
# Simplification controls (RDP)
simplify_layout = QHBoxLayout()
self.simplify_checkbox = QCheckBox("Simplify on finish")
self.simplify_checkbox.setChecked(True)
self.simplify_checkbox.stateChanged.connect(self._on_simplify_toggle)
simplify_layout.addWidget(self.simplify_checkbox)
simplify_layout.addWidget(QLabel("epsilon (px):"))
self.eps_spin = QDoubleSpinBox()
self.eps_spin.setRange(0.0, 1000.0)
self.eps_spin.setSingleStep(0.5)
self.eps_spin.setValue(2.0)
self.eps_spin.valueChanged.connect(self._on_eps_change)
simplify_layout.addWidget(self.eps_spin)
simplify_layout.addStretch()
polyline_layout.addLayout(simplify_layout)
polyline_group.setLayout(polyline_layout)
layout.addWidget(polyline_group)
# Object Class Group
class_group = QGroupBox("Object Class")
class_layout = QVBoxLayout()
# Class selection dropdown
self.class_combo = QComboBox()
self.class_combo.currentIndexChanged.connect(self._on_class_selected)
class_layout.addWidget(self.class_combo)
# Add / manage classes
class_button_layout = QHBoxLayout()
self.add_class_btn = QPushButton("Add New Class")
self.add_class_btn.clicked.connect(self._on_add_class)
class_button_layout.addWidget(self.add_class_btn)
self.refresh_classes_btn = QPushButton("Refresh")
self.refresh_classes_btn.clicked.connect(self._load_object_classes)
class_button_layout.addWidget(self.refresh_classes_btn)
class_layout.addLayout(class_button_layout)
# Class color (associated with selected object class)
color_layout = QHBoxLayout()
color_layout.addWidget(QLabel("Class Color:"))
self.color_btn = QPushButton()
self.color_btn.setFixedSize(40, 30)
self.color_btn.clicked.connect(self._on_color_picker)
self._update_color_button()
color_layout.addWidget(self.color_btn)
color_layout.addStretch()
class_layout.addLayout(color_layout)
# Selected class info
self.class_info_label = QLabel("No class selected")
self.class_info_label.setWordWrap(True)
self.class_info_label.setStyleSheet(
"QLabel { color: #888; font-style: italic; }"
)
class_layout.addWidget(self.class_info_label)
class_group.setLayout(class_layout)
layout.addWidget(class_group)
# Actions Group
actions_group = QGroupBox("Actions")
actions_layout = QVBoxLayout()
# Show / hide bounding boxes
self.show_bboxes_checkbox = QCheckBox("Show bounding boxes")
self.show_bboxes_checkbox.setChecked(True)
self.show_bboxes_checkbox.stateChanged.connect(self._on_show_bboxes_toggle)
actions_layout.addWidget(self.show_bboxes_checkbox)
self.clear_btn = QPushButton("Clear All Annotations")
self.clear_btn.clicked.connect(self._on_clear_annotations)
actions_layout.addWidget(self.clear_btn)
# Delete currently selected annotation (enabled when a selection exists)
self.delete_selected_btn = QPushButton("Delete Selected Annotation")
self.delete_selected_btn.clicked.connect(self._on_delete_selected_annotation)
self.delete_selected_btn.setEnabled(False)
actions_layout.addWidget(self.delete_selected_btn)
actions_group.setLayout(actions_layout)
layout.addWidget(actions_group)
layout.addStretch()
self.setLayout(layout)
def _update_color_button(self):
"""Update the color button appearance with current color."""
pixmap = QPixmap(40, 30)
pixmap.fill(self.current_color)
# Add border
painter = QPainter(pixmap)
painter.setPen(Qt.black)
painter.drawRect(0, 0, pixmap.width() - 1, pixmap.height() - 1)
painter.end()
self.color_btn.setIcon(QIcon(pixmap))
self.color_btn.setStyleSheet(f"background-color: {self.current_color.name()};")
def _load_object_classes(self):
"""Load object classes from database and populate combo box."""
try:
classes = self.db_manager.get_object_classes()
# Clear and repopulate combo box
self.class_combo.clear()
self.class_combo.addItem("-- Select Class / Show All --", None)
for cls in classes:
self.class_combo.addItem(cls["class_name"], cls)
logger.debug(f"Loaded {len(classes)} object classes")
except Exception as e:
logger.error(f"Error loading object classes: {e}")
QMessageBox.warning(
self, "Error", f"Failed to load object classes:\n{str(e)}"
)
def _on_polyline_toggle(self, checked: bool):
"""Handle polyline tool enable/disable."""
self.polyline_enabled = checked
if checked:
self.polyline_toggle_btn.setText("Stop Drawing Polyline")
self.polyline_toggle_btn.setStyleSheet(
"QPushButton { background-color: #4CAF50; }"
)
else:
self.polyline_toggle_btn.setText("Start Drawing Polyline")
self.polyline_toggle_btn.setStyleSheet("")
self.polyline_enabled_changed.emit(self.polyline_enabled)
logger.debug(f"Polyline tool {'enabled' if checked else 'disabled'}")
def _on_polyline_pen_width_changed(self, width: int):
"""Handle polyline pen width changes."""
self.polyline_pen_width_changed.emit(width)
logger.debug(f"Polyline pen width changed to {width}")
def _on_simplify_toggle(self, state: int):
"""Handle simplify-on-finish checkbox toggle."""
enabled = bool(state)
self.simplify_on_finish_changed.emit(enabled)
logger.debug(f"Simplify on finish set to {enabled}")
def _on_eps_change(self, val: float):
"""Handle epsilon (RDP tolerance) value changes."""
epsilon = float(val)
self.simplify_epsilon_changed.emit(epsilon)
logger.debug(f"Simplification epsilon changed to {epsilon}")
def _on_show_bboxes_toggle(self, state: int):
"""Handle 'Show bounding boxes' checkbox toggle."""
show = bool(state)
self.show_bboxes_changed.emit(show)
logger.debug(f"Show bounding boxes set to {show}")
def _on_color_picker(self):
"""Open color picker dialog and update the selected object's class color."""
if not self.current_class:
QMessageBox.warning(
self,
"No Class Selected",
"Please select an object class before changing its color.",
)
return
# Use current class color (without alpha) as the base
base_color = QColor(self.current_class.get("color", self.current_color.name()))
color = QColorDialog.getColor(
base_color,
self,
"Select Class Color",
QColorDialog.ShowAlphaChannel, # Allow alpha in UI, but store RGB in DB
)
if not color.isValid():
return
# Normalize to opaque RGB for storage
new_color = QColor(color)
new_color.setAlpha(255)
hex_color = new_color.name()
try:
# Update in database
self.db_manager.update_object_class(
class_id=self.current_class["id"], color=hex_color
)
except Exception as e:
logger.error(f"Failed to update class color in database: {e}")
QMessageBox.critical(
self,
"Error",
f"Failed to update class color in database:\n{str(e)}",
)
return
# Update local class data and combo box item data
self.current_class["color"] = hex_color
current_index = self.class_combo.currentIndex()
if current_index >= 0:
self.class_combo.setItemData(current_index, dict(self.current_class))
# Update info label text
info_text = f"Class: {self.current_class['class_name']}\nColor: {hex_color}"
if self.current_class.get("description"):
info_text += f"\nDescription: {self.current_class['description']}"
self.class_info_label.setText(info_text)
# Use semi-transparent version for polyline pen / button preview
class_color = QColor(hex_color)
class_color.setAlpha(128)
self.current_color = class_color
self._update_color_button()
self.polyline_pen_color_changed.emit(class_color)
logger.debug(
f"Updated class '{self.current_class['class_name']}' color to "
f"{hex_color} (polyline pen alpha={class_color.alpha()})"
)
# Notify listeners (e.g., AnnotationTab) so they can reload/redraw
self.class_color_changed.emit()
def _on_class_selected(self, index: int):
"""Handle object class selection (including '-- Select Class --')."""
class_data = self.class_combo.currentData()
if class_data:
self.current_class = class_data
# Update info label
info_text = (
f"Class: {class_data['class_name']}\n" f"Color: {class_data['color']}"
)
if class_data.get("description"):
info_text += f"\nDescription: {class_data['description']}"
self.class_info_label.setText(info_text)
# Update polyline pen color to match class color with semi-transparency
class_color = QColor(class_data["color"])
if class_color.isValid():
# Add 50% alpha for semi-transparency
class_color.setAlpha(128)
self.current_color = class_color
self._update_color_button()
self.polyline_pen_color_changed.emit(class_color)
self.class_selected.emit(class_data)
logger.debug(f"Selected class: {class_data['class_name']}")
else:
# "-- Select Class --" chosen: clear current class and show all annotations
self.current_class = None
self.class_info_label.setText("No class selected")
self.class_selected.emit(None)
logger.debug("Class selection cleared: showing annotations for all classes")
def _on_add_class(self):
"""Handle adding a new object class."""
# Get class name
class_name, ok = QInputDialog.getText(
self, "Add Object Class", "Enter class name:"
)
if not ok or not class_name.strip():
return
class_name = class_name.strip()
# Check if class already exists
existing = self.db_manager.get_object_class_by_name(class_name)
if existing:
QMessageBox.warning(
self, "Class Exists", f"A class named '{class_name}' already exists."
)
return
# Get color
color = QColorDialog.getColor(self.current_color, self, "Select Class Color")
if not color.isValid():
return
# Get optional description
description, ok = QInputDialog.getText(
self, "Class Description", "Enter class description (optional):"
)
if not ok:
description = None
# Add to database
try:
class_id = self.db_manager.add_object_class(
class_name, color.name(), description.strip() if description else None
)
logger.info(f"Added new object class: {class_name} (ID: {class_id})")
# Reload classes and select the new one
self._load_object_classes()
# Find and select the newly added class
for i in range(self.class_combo.count()):
class_data = self.class_combo.itemData(i)
if class_data and class_data.get("id") == class_id:
self.class_combo.setCurrentIndex(i)
break
QMessageBox.information(
self, "Success", f"Class '{class_name}' added successfully!"
)
except Exception as e:
logger.error(f"Error adding object class: {e}")
QMessageBox.critical(
self, "Error", f"Failed to add object class:\n{str(e)}"
)
def _on_clear_annotations(self):
"""Handle clear annotations button."""
reply = QMessageBox.question(
self,
"Clear Annotations",
"Are you sure you want to clear all annotations?",
QMessageBox.Yes | QMessageBox.No,
QMessageBox.No,
)
if reply == QMessageBox.Yes:
self.clear_annotations_requested.emit()
logger.debug("Clear annotations requested")
def _on_delete_selected_annotation(self):
"""Handle delete selected annotation button."""
self.delete_selected_annotation_requested.emit()
logger.debug("Delete selected annotation requested")
def set_has_selected_annotation(self, has_selection: bool):
"""
Enable/disable actions that require a selected annotation.
Args:
has_selection: True if an annotation is currently selected on the canvas.
"""
self.delete_selected_btn.setEnabled(bool(has_selection))
def get_current_class(self) -> Optional[Dict]:
"""Get currently selected object class."""
return self.current_class
def get_polyline_pen_color(self) -> QColor:
"""Get current polyline pen color."""
return self.current_color
def get_polyline_pen_width(self) -> int:
"""Get current polyline pen width."""
return self.polyline_pen_width_spin.value()
def is_polyline_enabled(self) -> bool:
"""Check if polyline tool is enabled."""
return self.polyline_enabled

View File

@@ -0,0 +1,282 @@
"""
Image display widget with zoom functionality for the microscopy object detection application.
Reusable widget for displaying images with zoom controls.
"""
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea
from PySide6.QtGui import QPixmap, QImage, QKeyEvent
from PySide6.QtCore import Qt, QEvent, Signal
from pathlib import Path
import numpy as np
from src.utils.image import Image, ImageLoadError
from src.utils.logger import get_logger
logger = get_logger(__name__)
class ImageDisplayWidget(QWidget):
"""
Reusable widget for displaying images with zoom functionality.
Features:
- Display images from Image objects
- Zoom in/out with mouse wheel
- Zoom in/out with +/- keyboard keys
- Reset zoom with Ctrl+0
- Scroll area for large images
Signals:
zoom_changed: Emitted when zoom level changes (float zoom_scale)
"""
zoom_changed = Signal(float) # Emitted when zoom level changes
def __init__(self, parent=None):
"""
Initialize the image display widget.
Args:
parent: Parent widget
"""
super().__init__(parent)
self.current_image = None
self.original_pixmap = None # Store original pixmap for zoom
self.zoom_scale = 1.0 # Current zoom scale
self.zoom_min = 0.1 # Minimum zoom (10%)
self.zoom_max = 10.0 # Maximum zoom (1000%)
self.zoom_step = 0.1 # Zoom step for +/- keys
self.zoom_wheel_step = 0.15 # Zoom step for mouse wheel
self._setup_ui()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
layout.setContentsMargins(0, 0, 0, 0)
# Scroll area for image
self.scroll_area = QScrollArea()
self.scroll_area.setWidgetResizable(True)
self.scroll_area.setMinimumHeight(400)
self.image_label = QLabel("No image loaded")
self.image_label.setAlignment(Qt.AlignCenter)
self.image_label.setStyleSheet(
"QLabel { background-color: #2b2b2b; color: #888; }"
)
self.image_label.setScaledContents(False)
# Enable mouse tracking for wheel events
self.image_label.setMouseTracking(True)
self.scroll_area.setWidget(self.image_label)
# Install event filter to capture wheel events on scroll area
self.scroll_area.viewport().installEventFilter(self)
layout.addWidget(self.scroll_area)
self.setLayout(layout)
# Set focus policy to receive keyboard events
self.setFocusPolicy(Qt.StrongFocus)
def load_image(self, image: Image):
"""
Load and display an image.
Args:
image: Image object to display
Raises:
ImageLoadError: If image cannot be displayed
"""
self.current_image = image
# Reset zoom when loading new image
self.zoom_scale = 1.0
# Convert to QPixmap and display
self._display_image()
logger.debug(f"Loaded image into display widget: {image.width}x{image.height}")
def clear(self):
"""Clear the displayed image."""
self.current_image = None
self.original_pixmap = None
self.zoom_scale = 1.0
self.image_label.setText("No image loaded")
self.image_label.setPixmap(QPixmap())
logger.debug("Cleared image display")
def _display_image(self):
"""Display the current image in the image label."""
if self.current_image is None:
return
try:
# Get RGB image data
if self.current_image.channels == 3:
image_data = self.current_image.get_rgb()
height, width, channels = image_data.shape
else:
image_data = self.current_image.get_grayscale()
height, width = image_data.shape
channels = 1
# Ensure data is contiguous for proper QImage display
image_data = np.ascontiguousarray(image_data)
# Use actual stride from numpy array for correct display
bytes_per_line = image_data.strides[0]
qimage = QImage(
image_data.data,
width,
height,
bytes_per_line,
self.current_image.qtimage_format,
).copy() # Copy to ensure Qt owns its memory after this scope
# Convert to pixmap
pixmap = QPixmap.fromImage(qimage)
# Store original pixmap for zooming
self.original_pixmap = pixmap
# Apply zoom and display
self._apply_zoom()
except Exception as e:
logger.error(f"Error displaying image: {e}")
raise ImageLoadError(f"Failed to display image: {str(e)}")
def _apply_zoom(self):
"""Apply current zoom level to the displayed image."""
if self.original_pixmap is None:
return
# Calculate scaled size
scaled_width = int(self.original_pixmap.width() * self.zoom_scale)
scaled_height = int(self.original_pixmap.height() * self.zoom_scale)
# Scale pixmap
scaled_pixmap = self.original_pixmap.scaled(
scaled_width,
scaled_height,
Qt.KeepAspectRatio,
(
Qt.SmoothTransformation
if self.zoom_scale >= 1.0
else Qt.FastTransformation
),
)
# Display in label
self.image_label.setPixmap(scaled_pixmap)
self.image_label.setScaledContents(False)
self.image_label.adjustSize()
# Emit zoom changed signal
self.zoom_changed.emit(self.zoom_scale)
def zoom_in(self):
"""Zoom in on the image."""
if self.original_pixmap is None:
return
new_scale = self.zoom_scale + self.zoom_step
if new_scale <= self.zoom_max:
self.zoom_scale = new_scale
self._apply_zoom()
logger.debug(f"Zoomed in to {int(self.zoom_scale * 100)}%")
def zoom_out(self):
"""Zoom out from the image."""
if self.original_pixmap is None:
return
new_scale = self.zoom_scale - self.zoom_step
if new_scale >= self.zoom_min:
self.zoom_scale = new_scale
self._apply_zoom()
logger.debug(f"Zoomed out to {int(self.zoom_scale * 100)}%")
def reset_zoom(self):
"""Reset zoom to 100%."""
if self.original_pixmap is None:
return
self.zoom_scale = 1.0
self._apply_zoom()
logger.debug("Reset zoom to 100%")
def set_zoom(self, scale: float):
"""
Set zoom to a specific scale.
Args:
scale: Zoom scale (1.0 = 100%)
"""
if self.original_pixmap is None:
return
# Clamp to min/max
scale = max(self.zoom_min, min(self.zoom_max, scale))
self.zoom_scale = scale
self._apply_zoom()
logger.debug(f"Set zoom to {int(self.zoom_scale * 100)}%")
def get_zoom_percentage(self) -> int:
"""
Get current zoom level as percentage.
Returns:
Zoom level as integer percentage (e.g., 100 for 100%)
"""
return int(self.zoom_scale * 100)
def keyPressEvent(self, event: QKeyEvent):
"""Handle keyboard events for zooming."""
if event.key() in (Qt.Key_Plus, Qt.Key_Equal):
# + or = key (= is the unshifted + on many keyboards)
self.zoom_in()
event.accept()
elif event.key() == Qt.Key_Minus:
# - key
self.zoom_out()
event.accept()
elif event.key() == Qt.Key_0 and event.modifiers() == Qt.ControlModifier:
# Ctrl+0 to reset zoom
self.reset_zoom()
event.accept()
else:
super().keyPressEvent(event)
def eventFilter(self, obj, event: QEvent) -> bool:
"""Event filter to capture wheel events for zooming."""
if event.type() == QEvent.Wheel:
wheel_event = event
if self.original_pixmap is not None:
# Get wheel angle delta
delta = wheel_event.angleDelta().y()
# Zoom in/out based on wheel direction
if delta > 0:
# Scroll up = zoom in
new_scale = self.zoom_scale + self.zoom_wheel_step
if new_scale <= self.zoom_max:
self.zoom_scale = new_scale
self._apply_zoom()
else:
# Scroll down = zoom out
new_scale = self.zoom_scale - self.zoom_wheel_step
if new_scale >= self.zoom_min:
self.zoom_scale = new_scale
self._apply_zoom()
return True # Event handled
return super().eventFilter(obj, event)

49
src/gui_launcher.py Normal file
View File

@@ -0,0 +1,49 @@
"""GUI launcher module for microscopy object detection application."""
import sys
from pathlib import Path
from PySide6.QtWidgets import QApplication
from PySide6.QtCore import Qt
from src import __version__
from src.gui.main_window import MainWindow
from src.utils.logger import setup_logging
from src.utils.config_manager import ConfigManager
def main():
"""Launch the GUI application."""
# Setup logging
config_manager = ConfigManager()
log_config = config_manager.get_section("logging")
setup_logging(
log_file=log_config.get("file", "logs/app.log"),
level=log_config.get("level", "INFO"),
log_format=log_config.get("format"),
)
# Enable High DPI scaling
QApplication.setHighDpiScaleFactorRoundingPolicy(
Qt.HighDpiScaleFactorRoundingPolicy.PassThrough
)
# Create Qt application
app = QApplication(sys.argv)
app.setApplicationName("Microscopy Object Detection")
app.setOrganizationName("MicroscopyLab")
app.setApplicationVersion(__version__)
# Set application style
app.setStyle("Fusion")
# Create and show main window
window = MainWindow()
window.show()
# Run application
sys.exit(app.exec())
if __name__ == "__main__":
main()

View File

@@ -5,12 +5,12 @@ Handles detection inference and result storage.
from typing import List, Dict, Optional, Callable
from pathlib import Path
from PIL import Image
import cv2
import numpy as np
from src.model.yolo_wrapper import YOLOWrapper
from src.database.db_manager import DatabaseManager
from src.utils.image import Image
from src.utils.logger import get_logger
from src.utils.file_utils import get_relative_path
@@ -42,6 +42,7 @@ class InferenceEngine:
relative_path: str,
conf: float = 0.25,
save_to_db: bool = True,
repository_root: Optional[str] = None,
) -> Dict:
"""
Detect objects in a single image.
@@ -51,48 +52,79 @@ class InferenceEngine:
relative_path: Relative path from repository root
conf: Confidence threshold
save_to_db: Whether to save results to database
repository_root: Base directory used to compute relative_path (if known)
Returns:
Dictionary with detection results
"""
try:
# Normalize storage path (fall back to absolute path when repo root is unknown)
stored_relative_path = relative_path
if not repository_root:
stored_relative_path = str(Path(image_path).resolve())
# Get image dimensions
img = Image.open(image_path)
width, height = img.size
img.close()
img = Image(image_path)
width = img.width
height = img.height
# Perform detection
detections = self.yolo.predict(image_path, conf=conf)
# Add/get image in database
image_id = self.db_manager.get_or_create_image(
relative_path=relative_path,
relative_path=stored_relative_path,
filename=Path(image_path).name,
width=width,
height=height,
)
# Save detections to database
if save_to_db and detections:
detection_records = []
for det in detections:
# Use normalized bbox from detection
bbox_normalized = det[
"bbox_normalized"
] # [x_min, y_min, x_max, y_max]
inserted_count = 0
deleted_count = 0
record = {
"image_id": image_id,
"model_id": self.model_id,
"class_name": det["class_name"],
"bbox": tuple(bbox_normalized),
"confidence": det["confidence"],
"metadata": {"class_id": det["class_id"]},
}
detection_records.append(record)
# Save detections to database, replacing any previous results for this image/model
if save_to_db:
deleted_count = self.db_manager.delete_detections_for_image(
image_id, self.model_id
)
if detections:
detection_records = []
for det in detections:
# Use normalized bbox from detection
bbox_normalized = det[
"bbox_normalized"
] # [x_min, y_min, x_max, y_max]
self.db_manager.add_detections_batch(detection_records)
logger.info(f"Saved {len(detection_records)} detections to database")
metadata = {
"class_id": det["class_id"],
"source_path": str(Path(image_path).resolve()),
}
if repository_root:
metadata["repository_root"] = str(
Path(repository_root).resolve()
)
record = {
"image_id": image_id,
"model_id": self.model_id,
"class_name": det["class_name"],
"bbox": tuple(bbox_normalized),
"confidence": det["confidence"],
"segmentation_mask": det.get("segmentation_mask"),
"metadata": metadata,
}
detection_records.append(record)
inserted_count = self.db_manager.add_detections_batch(
detection_records
)
logger.info(
f"Saved {inserted_count} detections to database (replaced {deleted_count})"
)
else:
logger.info(
f"Detection run removed {deleted_count} stale entries but produced no new detections"
)
return {
"success": True,
@@ -141,7 +173,12 @@ class InferenceEngine:
rel_path = get_relative_path(image_path, repository_root)
# Perform detection
result = self.detect_single(image_path, rel_path, conf)
result = self.detect_single(
image_path,
rel_path,
conf=conf,
repository_root=repository_root,
)
results.append(result)
# Update progress
@@ -160,6 +197,7 @@ class InferenceEngine:
conf: float = 0.25,
bbox_thickness: int = 2,
bbox_colors: Optional[Dict[str, str]] = None,
draw_masks: bool = True,
) -> tuple:
"""
Detect objects and return annotated image.
@@ -169,6 +207,7 @@ class InferenceEngine:
conf: Confidence threshold
bbox_thickness: Thickness of bounding boxes
bbox_colors: Dictionary mapping class names to hex colors
draw_masks: Whether to draw segmentation masks (if available)
Returns:
Tuple of (detections, annotated_image_array)
@@ -189,12 +228,8 @@ class InferenceEngine:
bbox_colors = {}
default_color = self._hex_to_bgr(bbox_colors.get("default", "#00FF00"))
# Draw bounding boxes
# Draw detections
for det in detections:
# Get absolute coordinates
bbox_abs = det["bbox_absolute"]
x1, y1, x2, y2 = [int(v) for v in bbox_abs]
# Get color for this class
class_name = det["class_name"]
color_hex = bbox_colors.get(
@@ -202,7 +237,33 @@ class InferenceEngine:
)
color = self._hex_to_bgr(color_hex)
# Draw box
# Draw segmentation mask if available and requested
if draw_masks and det.get("segmentation_mask"):
mask_normalized = det["segmentation_mask"]
if mask_normalized and len(mask_normalized) > 0:
# Convert normalized coordinates to absolute pixels
mask_points = np.array(
[
[int(pt[0] * width), int(pt[1] * height)]
for pt in mask_normalized
],
dtype=np.int32,
)
# Create a semi-transparent overlay
overlay = img.copy()
cv2.fillPoly(overlay, [mask_points], color)
# Blend with original image (30% opacity)
cv2.addWeighted(overlay, 0.3, img, 0.7, 0, img)
# Draw mask contour
cv2.polylines(img, [mask_points], True, color, bbox_thickness)
# Get absolute coordinates for bounding box
bbox_abs = det["bbox_absolute"]
x1, y1, x2, y2 = [int(v) for v in bbox_abs]
# Draw bounding box
cv2.rectangle(img, (x1, y1), (x2, y2), color, bbox_thickness)
# Prepare label

View File

@@ -1,13 +1,21 @@
"""
YOLO model wrapper for the microscopy object detection application.
Provides a clean interface to YOLOv8 for training, validation, and inference.
"""YOLO model wrapper for the microscopy object detection application.
Notes on 16-bit TIFF support:
- Ultralytics training defaults assume 8-bit images and normalize by dividing by 255.
- This project can patch Ultralytics at runtime to decode TIFFs via `tifffile` and
normalize `uint16` correctly.
See [`apply_ultralytics_16bit_tiff_patches()`](src/utils/ultralytics_16bit_patch.py:1).
"""
from ultralytics import YOLO
from pathlib import Path
from typing import Optional, List, Dict, Callable, Any
import torch
import tempfile
import os
from src.utils.image import Image
from src.utils.logger import get_logger
from src.utils.ultralytics_16bit_patch import apply_ultralytics_16bit_tiff_patches
logger = get_logger(__name__)
@@ -16,7 +24,7 @@ logger = get_logger(__name__)
class YOLOWrapper:
"""Wrapper for YOLOv8 model operations."""
def __init__(self, model_path: str = "yolov8s.pt"):
def __init__(self, model_path: str = "yolov8s-seg.pt"):
"""
Initialize YOLO model.
@@ -28,6 +36,9 @@ class YOLOWrapper:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"YOLOWrapper initialized with device: {self.device}")
# Apply Ultralytics runtime patches early (before first import/instantiation of YOLO datasets/trainers).
apply_ultralytics_16bit_tiff_patches()
def load_model(self) -> bool:
"""
Load YOLO model from path.
@@ -37,6 +48,9 @@ class YOLOWrapper:
"""
try:
logger.info(f"Loading YOLO model from {self.model_path}")
# Import YOLO lazily to ensure runtime patches are applied first.
from ultralytics import YOLO
self.model = YOLO(self.model_path)
self.model.to(self.device)
logger.info("Model loaded successfully")
@@ -55,6 +69,7 @@ class YOLOWrapper:
save_dir: str = "data/models",
name: str = "custom_model",
resume: bool = False,
callbacks: Optional[Dict[str, Callable]] = None,
**kwargs,
) -> Dict[str, Any]:
"""
@@ -69,19 +84,29 @@ class YOLOWrapper:
save_dir: Directory to save trained model
name: Name for the training run
resume: Resume training from last checkpoint
callbacks: Optional Ultralytics callback dictionary
**kwargs: Additional training arguments
Returns:
Dictionary with training results
"""
if self.model is None:
self.load_model()
if not self.load_model():
raise RuntimeError(f"Failed to load model from {self.model_path}")
try:
logger.info(f"Starting training: {name}")
logger.info(
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
)
logger.info(f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}")
# Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255.
# Users can override by passing explicit kwargs.
kwargs.setdefault("mosaic", 0.0)
kwargs.setdefault("mixup", 0.0)
kwargs.setdefault("cutmix", 0.0)
kwargs.setdefault("copy_paste", 0.0)
kwargs.setdefault("hsv_h", 0.0)
kwargs.setdefault("hsv_s", 0.0)
kwargs.setdefault("hsv_v", 0.0)
# Train the model
results = self.model.train(
@@ -117,13 +142,12 @@ class YOLOWrapper:
Dictionary with validation metrics
"""
if self.model is None:
self.load_model()
if not self.load_model():
raise RuntimeError(f"Failed to load model from {self.model_path}")
try:
logger.info(f"Starting validation on {split} split")
results = self.model.val(
data=data_yaml, split=split, device=self.device, **kwargs
)
results = self.model.val(data=data_yaml, split=split, device=self.device, **kwargs)
logger.info("Validation completed successfully")
return self._format_validation_results(results)
@@ -158,10 +182,13 @@ class YOLOWrapper:
List of detection dictionaries
"""
if self.model is None:
self.load_model()
if not self.load_model():
raise RuntimeError(f"Failed to load model from {self.model_path}")
prepared_source, cleanup_path = self._prepare_source(source)
imgsz = 1088
try:
logger.info(f"Running inference on {source}")
logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
results = self.model.predict(
source=source,
conf=conf,
@@ -170,6 +197,7 @@ class YOLOWrapper:
save_txt=save_txt,
save_conf=save_conf,
device=self.device,
imgsz=imgsz,
**kwargs,
)
@@ -180,10 +208,14 @@ class YOLOWrapper:
except Exception as e:
logger.error(f"Error during inference: {e}")
raise
finally:
if 0: # cleanup_path:
try:
os.remove(cleanup_path)
except OSError as cleanup_error:
logger.warning(f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}")
def export(
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
) -> str:
def export(self, format: str = "onnx", output_path: Optional[str] = None, **kwargs) -> str:
"""
Export model to different format.
@@ -196,7 +228,8 @@ class YOLOWrapper:
Path to exported model
"""
if self.model is None:
self.load_model()
if not self.load_model():
raise RuntimeError(f"Failed to load model from {self.model_path}")
try:
logger.info(f"Exporting model to {format} format")
@@ -208,13 +241,35 @@ class YOLOWrapper:
logger.error(f"Error exporting model: {e}")
raise
def _prepare_source(self, source):
"""Convert single-channel images to RGB temporarily for inference."""
cleanup_path = None
if isinstance(source, (str, Path)):
source_path = Path(source)
if source_path.is_file():
try:
img_obj = Image(source_path)
suffix = source_path.suffix or ".png"
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
tmp_path = tmp.name
tmp.close()
img_obj.save(tmp_path)
cleanup_path = tmp_path
logger.info(f"Converted image {source_path} to RGB for inference at {tmp_path}")
return tmp_path, cleanup_path
except Exception as convert_error:
logger.warning(
f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"
)
return source, cleanup_path
def _format_training_results(self, results) -> Dict[str, Any]:
"""Format training results into dictionary."""
try:
# Get the results dict
results_dict = (
results.results_dict if hasattr(results, "results_dict") else {}
)
results_dict = results.results_dict if hasattr(results, "results_dict") else {}
formatted = {
"success": True,
@@ -247,9 +302,7 @@ class YOLOWrapper:
"mAP50-95": float(box_metrics.map),
"precision": float(box_metrics.mp),
"recall": float(box_metrics.mr),
"fitness": (
float(results.fitness) if hasattr(results, "fitness") else 0.0
),
"fitness": (float(results.fitness) if hasattr(results, "fitness") else 0.0),
}
# Add per-class metrics if available
@@ -259,11 +312,7 @@ class YOLOWrapper:
if idx < len(box_metrics.ap):
class_metrics[name] = {
"ap": float(box_metrics.ap[idx]),
"ap50": (
float(box_metrics.ap50[idx])
if hasattr(box_metrics, "ap50")
else 0.0
),
"ap50": (float(box_metrics.ap50[idx]) if hasattr(box_metrics, "ap50") else 0.0),
}
formatted["class_metrics"] = class_metrics
@@ -282,6 +331,10 @@ class YOLOWrapper:
boxes = result.boxes
image_path = str(result.path)
orig_shape = result.orig_shape # (height, width)
height, width = orig_shape
# Check if this is a segmentation model with masks
has_masks = hasattr(result, "masks") and result.masks is not None
for i in range(len(boxes)):
# Get normalized coordinates
@@ -292,13 +345,32 @@ class YOLOWrapper:
"class_id": int(boxes.cls[i]),
"class_name": result.names[int(boxes.cls[i])],
"confidence": float(boxes.conf[i]),
"bbox_normalized": [
float(v) for v in xyxyn
], # [x_min, y_min, x_max, y_max]
"bbox_absolute": [
float(v) for v in boxes.xyxy[i].cpu().numpy()
], # Absolute pixels
"bbox_normalized": [float(v) for v in xyxyn], # [x_min, y_min, x_max, y_max]
"bbox_absolute": [float(v) for v in boxes.xyxy[i].cpu().numpy()], # Absolute pixels
}
# Extract segmentation mask if available
if has_masks:
try:
# Get the mask for this detection
mask_data = result.masks.xy[i] # Polygon coordinates in absolute pixels
# Convert to normalized coordinates
if len(mask_data) > 0:
mask_normalized = []
for point in mask_data:
x_norm = float(point[0]) / width
y_norm = float(point[1]) / height
mask_normalized.append([x_norm, y_norm])
detection["segmentation_mask"] = mask_normalized
else:
detection["segmentation_mask"] = None
except Exception as mask_error:
logger.warning(f"Error extracting mask for detection {i}: {mask_error}")
detection["segmentation_mask"] = None
else:
detection["segmentation_mask"] = None
detections.append(detection)
return detections
@@ -308,9 +380,7 @@ class YOLOWrapper:
return []
@staticmethod
def convert_bbox_format(
bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
) -> List[float]:
def convert_bbox_format(bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy") -> List[float]:
"""
Convert bounding box between formats.

View File

@@ -0,0 +1,7 @@
"""
Utility modules for the microscopy object detection application.
"""
from src.utils.image import Image, ImageLoadError
__all__ = ["Image", "ImageLoadError"]

View File

@@ -7,6 +7,7 @@ import yaml
from pathlib import Path
from typing import Any, Dict, Optional
from src.utils.logger import get_logger
from src.utils.image import Image
logger = get_logger(__name__)
@@ -46,18 +47,15 @@ class ConfigManager:
"database": {"path": "data/detections.db"},
"image_repository": {
"base_path": "",
"allowed_extensions": [
".jpg",
".jpeg",
".png",
".tif",
".tiff",
".bmp",
],
"allowed_extensions": Image.SUPPORTED_EXTENSIONS,
},
"models": {
"default_base_model": "yolov8s.pt",
"default_base_model": "yolov8s-seg.pt",
"models_directory": "data/models",
"base_model_choices": [
"yolov8s-seg.pt",
"yolo11s-seg.pt",
],
},
"training": {
"default_epochs": 100,
@@ -65,6 +63,20 @@ class ConfigManager:
"default_imgsz": 640,
"default_patience": 50,
"default_lr0": 0.01,
"two_stage": {
"enabled": False,
"stage1": {
"epochs": 20,
"lr0": 0.0005,
"patience": 10,
"freeze": 10,
},
"stage2": {
"epochs": 150,
"lr0": 0.0003,
"patience": 30,
},
},
},
"detection": {
"default_confidence": 0.25,
@@ -213,6 +225,4 @@ class ConfigManager:
def get_allowed_extensions(self) -> list:
"""Get list of allowed image file extensions."""
return self.get(
"image_repository.allowed_extensions", [".jpg", ".jpeg", ".png"]
)
return self.get("image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS)

View File

@@ -0,0 +1,103 @@
import numpy as np
from pathlib import Path
from skimage.draw import polygon
from tifffile import TiffFile
from src.database.db_manager import DatabaseManager
def read_image(image_path: Path) -> np.ndarray:
metadata = {}
with TiffFile(image_path) as tif:
image = tif.asarray()
metadata = tif.imagej_metadata
return image, metadata
def main():
polygon_vertices = np.array([[10, 10], [50, 10], [50, 50], [10, 50]])
image = np.zeros((100, 100), dtype=np.uint8)
rr, cc = polygon(polygon_vertices[:, 0], polygon_vertices[:, 1])
image[rr, cc] = 255
if __name__ == "__main__":
db = DatabaseManager()
model_name = "c17"
model_id = db.get_models(filters={"model_name": model_name})[0]["id"]
print(f"Model name {model_name}, id {model_id}")
detections = db.get_detections(filters={"model_id": model_id})
file_stems = set()
for detection in detections:
file_stems.add(detection["image_filename"].split("_")[0])
print("Files:", file_stems)
for stem in file_stems:
print(stem)
detections = db.get_detections(filters={"model_id": model_id, "i.filename": f"LIKE %{stem}%"})
annotations = []
for detection in detections:
source_path = Path(detection["metadata"]["source_path"])
image, metadata = read_image(source_path)
offset = np.array(list(map(int, metadata["tile_section"].split(","))))[::-1]
scale = np.array(list(map(int, metadata["patch_size"].split(","))))[::-1]
# tile_size = np.array(list(map(int, metadata["tile_size"].split(","))))
segmentation = np.array(detection["segmentation_mask"]) # * tile_size
# print(source_path, image, metadata, segmentation.shape)
# print(offset)
# print(scale)
# print(segmentation)
# segmentation = (segmentation + offset * tile_size) / (tile_size * scale)
segmentation = (segmentation + offset) / scale
yolo_annotation = f"{detection['metadata']['class_id']} " + " ".join(
[f"{x:.6f} {y:.6f}" for x, y in segmentation]
)
annotations.append(yolo_annotation)
# print(segmentation)
# print(yolo_annotation)
# aa
print(
" ",
detection["model_name"],
detection["image_id"],
detection["image_filename"],
source_path,
metadata["label_path"],
)
# section_i_section_j = detection["image_filename"].split("_")[1].split(".")[0]
# print(" ", section_i_section_j)
label_path = metadata["label_path"]
print(" ", label_path)
with open(label_path, "w") as f:
f.write("\n".join(annotations))
exit()
for detection in detections:
print(detection["model_name"], detection["image_id"], detection["image_filename"])
print(detections[0])
# polygon_vertices = np.array([[10, 10], [50, 10], [50, 50], [10, 50]])
# image = np.zeros((100, 100), dtype=np.uint8)
# rr, cc = polygon(polygon_vertices[:, 0], polygon_vertices[:, 1])
# image[rr, cc] = 255
# import matplotlib.pyplot as plt
# plt.imshow(image, cmap='gray')
# plt.show()

View File

@@ -28,7 +28,9 @@ def get_image_files(
List of absolute paths to image files
"""
if allowed_extensions is None:
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
from src.utils.image import Image
allowed_extensions = Image.SUPPORTED_EXTENSIONS
# Normalize extensions to lowercase
allowed_extensions = [ext.lower() for ext in allowed_extensions]
@@ -204,7 +206,9 @@ def is_image_file(
True if file is an image
"""
if allowed_extensions is None:
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
from src.utils.image import Image
allowed_extensions = Image.SUPPORTED_EXTENSIONS
extension = Path(file_path).suffix.lower()
return extension in [ext.lower() for ext in allowed_extensions]

370
src/utils/image.py Normal file
View File

@@ -0,0 +1,370 @@
"""
Image loading and management utilities for the microscopy object detection application.
"""
import cv2
import numpy as np
from pathlib import Path
from typing import Optional, Tuple, Union
from src.utils.logger import get_logger
from src.utils.file_utils import validate_file_path, is_image_file
from PySide6.QtGui import QImage
from tifffile import imread, imwrite
logger = get_logger(__name__)
def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
"""
Convert a grayscale image to a pseudo-RGB image using a gamma correction.
Args:
arr: Input grayscale image as numpy array
Returns:
Pseudo-RGB image as numpy array
"""
if arr.ndim != 2:
raise ValueError("Input array must be a grayscale image with shape (H, W)")
a1 = arr.copy().astype(np.float32)
a1 -= np.percentile(a1, 2)
a1[a1 < 0] = 0
p999 = np.percentile(a1, 99.9)
a1[a1 > p999] = p999
a1 /= a1.max()
if 1:
a2 = a1.copy()
a2 = a2**gamma
a2 /= a2.max()
a3 = a1.copy()
p9999 = np.percentile(a3, 99.99)
a3[a3 > p9999] = p9999
a3 /= a3.max()
# return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
out = np.stack([a1, a2, a3], axis=0)
# print(any(np.isnan(out).flatten()))
return out
class ImageLoadError(Exception):
"""Exception raised when an image cannot be loaded."""
pass
class Image:
"""
A class for loading and managing images from file paths.
Supports multiple image formats: .jpg, .jpeg, .png, .tif, .tiff, .bmp
Provides access to image data in multiple formats (OpenCV/numpy, PIL).
Attributes:
path: Path to the image file
data: Image data as numpy array (OpenCV format, BGR)
pil_image: Image data as PIL Image (RGB)
width: Image width in pixels
height: Image height in pixels
channels: Number of color channels
format: Image file format
size_bytes: File size in bytes
"""
SUPPORTED_EXTENSIONS = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
def __init__(self, image_path: Union[str, Path]):
"""
Initialize an Image object by loading from a file path.
Args:
image_path: Path to the image file (string or Path object)
Raises:
ImageLoadError: If the image cannot be loaded or is invalid
"""
self.path = Path(image_path)
self._data: Optional[np.ndarray] = None
self._width: int = 0
self._height: int = 0
self._channels: int = 0
self._format: str = ""
self._size_bytes: int = 0
self._dtype: Optional[np.dtype] = None
# Load the image
self._load()
def _load(self) -> None:
"""
Load the image from disk.
Raises:
ImageLoadError: If the image cannot be loaded
"""
# Validate path
if not validate_file_path(str(self.path), must_exist=True):
raise ImageLoadError(f"Invalid or non-existent file path: {self.path}")
# Check file extension
if not is_image_file(str(self.path), self.SUPPORTED_EXTENSIONS):
ext = self.path.suffix.lower()
raise ImageLoadError(
f"Unsupported image format: {ext}. " f"Supported formats: {', '.join(self.SUPPORTED_EXTENSIONS)}"
)
try:
if self.path.suffix.lower() in [".tif", ".tiff"]:
self._data = imread(str(self.path))
else:
# raise NotImplementedError("RGB is not implemented")
# Load with OpenCV (returns BGR format)
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
if self._data is None:
raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")
# Extract metadata
# print(self._data.shape)
if len(self._data.shape) == 2:
self._height, self._width = self._data.shape[:2]
self._channels = 1
else:
self._height, self._width = self._data.shape[1:]
self._channels = self._data.shape[0]
# self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
self._format = self.path.suffix.lower().lstrip(".")
self._size_bytes = self.path.stat().st_size
self._dtype = self._data.dtype
if 0:
logger.info(
f"Successfully loaded image: {self.path.name} "
f"({self._width}x{self._height}, {self._channels} channels, "
f"{self._format.upper()})"
)
except Exception as e:
logger.error(f"Error loading image {self.path}: {e}")
raise ImageLoadError(f"Failed to load image: {e}") from e
@property
def data(self) -> np.ndarray:
"""
Get image data as numpy array (OpenCV format, BGR or grayscale).
Returns:
Image data as numpy array
"""
if self._data is None:
raise ImageLoadError("Image data not available")
return self._data
@property
def width(self) -> int:
"""Get image width in pixels."""
return self._width
@property
def height(self) -> int:
"""Get image height in pixels."""
return self._height
@property
def shape(self) -> Tuple[int, int, int]:
"""
Get image shape as (height, width, channels).
Returns:
Tuple of (height, width, channels)
"""
print("shape", self._height, self._width, self._channels)
return (self._height, self._width, self._channels)
@property
def channels(self) -> int:
"""Get number of color channels."""
return self._channels
@property
def format(self) -> str:
"""Get image file format (e.g., 'jpg', 'png')."""
return self._format
@property
def size_bytes(self) -> int:
"""Get file size in bytes."""
return self._size_bytes
@property
def size_mb(self) -> float:
"""Get file size in megabytes."""
return self._size_bytes / (1024 * 1024)
@property
def dtype(self) -> np.dtype:
"""Get the data type of the image array."""
if self._dtype is None:
raise ImageLoadError("Image dtype not available")
return self._dtype
@property
def qtimage_format(self) -> QImage.Format:
"""
Get the appropriate QImage format for the image.
Returns:
QImage.Format enum value
"""
if self._channels == 3:
return QImage.Format_RGB888
elif self._channels == 4:
return QImage.Format_RGBA8888
elif self._channels == 1:
if self._dtype == np.uint16:
return QImage.Format_Grayscale16
elif self._dtype == np.uint8:
return QImage.Format_Grayscale8
elif self._dtype == np.float32:
return QImage.Format_BGR30
else:
raise ImageLoadError(f"Unsupported number of channels: {self._channels}")
def get_rgb(self) -> np.ndarray:
"""
Get image data as RGB numpy array.
Returns:
Image data in RGB format as numpy array
"""
if self.channels == 1:
img = get_pseudo_rgb(self.data)
self._dtype = img.dtype
return img, True
elif self._channels == 3:
return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB), False
elif self._channels == 4:
return cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA), False
else:
raise NotImplementedError
# else:
# return self._data
def get_qt_rgb(self) -> np.ascontiguousarray:
# we keep data as (C, H, W)
_img, pseudo = self.get_rgb()
if pseudo:
img = np.zeros((self.height, self.width, 4), dtype=np.float32)
img[..., 0] = _img[0] # R gradient
img[..., 1] = _img[1] # G gradient
img[..., 2] = _img[2] # B constant
img[..., 3] = 1.0 # A = 1.0 (opaque)
return np.ascontiguousarray(img)
else:
return np.ascontiguousarray(_img)
def get_grayscale(self) -> np.ndarray:
"""
Get image as grayscale numpy array.
Returns:
Grayscale image as numpy array
"""
if self._channels == 1:
return self._data
else:
return cv2.cvtColor(self._data, cv2.COLOR_BGR2GRAY)
def copy(self) -> np.ndarray:
"""
Get a copy of the image data.
Returns:
Copy of image data as numpy array
"""
return self._data.copy()
def resize(self, width: int, height: int) -> np.ndarray:
"""
Resize the image to specified dimensions.
Args:
width: Target width in pixels
height: Target height in pixels
Returns:
Resized image as numpy array (does not modify original)
"""
return cv2.resize(self._data, (width, height))
def is_grayscale(self) -> bool:
"""
Check if image is grayscale.
Returns:
True if image is grayscale (1 channel)
"""
return self._channels == 1
def is_color(self) -> bool:
"""
Check if image is color.
Returns:
True if image has 3 or more channels
"""
return self._channels >= 3
def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None:
if self.channels == 1:
if pseudo_rgb:
img = get_pseudo_rgb(self.data)
print("Image.save", img.shape)
else:
img = np.repeat(self.data, 3, axis=2)
else:
raise NotImplementedError("Only grayscale images are supported for now.")
imwrite(path, data=img)
def __repr__(self) -> str:
"""String representation of the Image object."""
return (
f"Image(path='{self.path.name}', "
# Display as HxWxC to match the conventional NumPy shape semantics.
f"shape=({self._height}x{self._width}x{self._channels}), "
f"format={self._format}, "
f"size={self.size_mb:.2f}MB)"
)
def __str__(self) -> str:
"""String representation of the Image object."""
return self.__repr__()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--path", type=str, required=True)
args = parser.parse_args()
img = Image(args.path)
img.save(args.path + "test.tif")
print(img)

View File

@@ -0,0 +1,168 @@
import numpy as np
from roifile import ImagejRoi
from tifffile import TiffFile, TiffWriter
from pathlib import Path
class UT:
"""
Docstring for UT
Operetta files along with rois drawn in ImageJ
"""
def __init__(self, roifile_fn: Path, no_labels: bool):
self.roifile_fn = roifile_fn
print("is file", self.roifile_fn.is_file())
self.rois = None
if no_labels:
self.rois = ImagejRoi.fromfile(self.roifile_fn)
print(self.roifile_fn.stem)
print(self.roifile_fn.parent.parts[-1])
if "Roi-" in self.roifile_fn.stem:
self.stem = self.roifile_fn.stem.split("Roi-")[1]
else:
self.stem = self.roifile_fn.parent.parts[-1]
else:
self.roifile_fn = roifile_fn / roifile_fn.parts[-1]
self.stem = self.roifile_fn.stem
print(self.roifile_fn)
print(self.stem)
self.image, self.image_props = self._load_images()
def _load_images(self):
"""Loading sequence of tif files
array sequence is CZYX
"""
print("Loading images:", self.roifile_fn.parent, self.stem)
fns = list(self.roifile_fn.parent.glob(f"{self.stem.lower()}*.tif*"))
stems = [fn.stem.split(self.stem)[-1] for fn in fns]
n_ch = len(set([stem.split("-ch")[-1].split("t")[0] for stem in stems]))
n_p = len(set([stem.split("-")[0] for stem in stems]))
n_t = len(set([stem.split("t")[1] for stem in stems]))
with TiffFile(fns[0]) as tif:
img = tif.asarray()
w, h = img.shape
dtype = img.dtype
self.image_props = {
"channels": n_ch,
"planes": n_p,
"tiles": n_t,
"width": w,
"height": h,
"dtype": dtype,
}
print("Image props", self.image_props)
image_stack = np.zeros((n_ch, n_p, w, h), dtype=dtype)
for fn in fns:
with TiffFile(fn) as tif:
img = tif.asarray()
stem = fn.stem.split(self.stem)[-1]
ch = int(stem.split("-ch")[-1].split("t")[0])
p = int(stem.split("-")[0].split("p")[1])
t = int(stem.split("t")[1])
print(fn.stem, "ch", ch, "p", p, "t", t)
image_stack[ch - 1, p - 1] = img
print(image_stack.shape)
return image_stack, self.image_props
@property
def width(self):
return self.image_props["width"]
@property
def height(self):
return self.image_props["height"]
@property
def nchannels(self):
return self.image_props["channels"]
@property
def nplanes(self):
return self.image_props["planes"]
def export_rois(
self,
path: Path,
subfolder: str = "labels",
class_index: int = 0,
):
"""Export rois to a file"""
with open(path / subfolder / f"{self.stem}.txt", "w") as f:
for i, roi in enumerate(self.rois):
rc = roi.subpixel_coordinates
if rc is None:
print(f"No coordinates: {self.roifile_fn}, element {i}, out of {len(self.rois)}")
continue
xmn, ymn = rc.min(axis=0)
xmx, ymx = rc.max(axis=0)
xc = (xmn + xmx) / 2
yc = (ymn + ymx) / 2
bw = xmx - xmn
bh = ymx - ymn
coords = f"{xc/self.width} {yc/self.height} {bw/self.width} {bh/self.height} "
for x, y in rc:
coords += f"{x/self.width} {y/self.height} "
f.write(f"{class_index} {coords}\n")
return
def export_image(
self,
path: Path,
subfolder: str = "images",
plane_mode: str = "max projection",
channel: int = 0,
):
"""Export image to a file"""
if plane_mode == "max projection":
self.image = np.max(self.image[channel], axis=0)
print(self.image.shape)
print(path / subfolder / f"{self.stem}.tif")
with TiffWriter(path / subfolder / f"{self.stem}.tif") as tif:
tif.write(self.image)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input", nargs="*", type=Path)
parser.add_argument("-o", "--output", type=Path)
parser.add_argument(
"--no-labels",
action="store_false",
help="Source does not have labels, export only images",
)
args = parser.parse_args()
# print(args)
# aa
for path in args.input:
print("Path:", path)
if not args.no_labels:
print("No labels")
ut = UT(path, args.no_labels)
ut.export_image(args.output, plane_mode="max projection", channel=0)
else:
for rfn in Path(path).glob("*.zip"):
# if Path(path).suffix == ".zip":
print("Roi FN:", rfn)
ut = UT(rfn, args.no_labels)
ut.export_rois(args.output, class_index=0)
ut.export_image(args.output, plane_mode="max projection", channel=0)
print()

368
src/utils/image_splitter.py Normal file
View File

@@ -0,0 +1,368 @@
import numpy as np
from pathlib import Path
from tifffile import imread, imwrite
from shapely.geometry import LineString
from copy import deepcopy
from scipy.ndimage import zoom
# debug
from src.utils.image import Image
from show_yolo_seg import draw_annotations
import pylab as plt
import cv2
class Label:
def __init__(self, yolo_annotation: str):
class_id, bbox, polygon = self.parse_yolo_annotation(yolo_annotation)
self.class_id = class_id
self.bbox = bbox
self.polygon = polygon
def parse_yolo_annotation(self, yolo_annotation: str):
class_id, *coords = yolo_annotation.split()
class_id = int(class_id)
bbox = np.array(coords[:4], dtype=np.float32)
polygon = np.array(coords[4:], dtype=np.float32).reshape(-1, 2) if len(coords) > 4 else None
if not any(np.isclose(polygon[0], polygon[-1])):
polygon = np.vstack([polygon, polygon[0]])
return class_id, bbox, polygon
def offset_label(
self,
img_w,
img_h,
distance: float = 1.0,
cap_style: int = 2,
join_style: int = 2,
):
if self.polygon is None:
self.bbox = np.array(
[
self.bbox[0] - distance if self.bbox[0] - distance > 0 else 0,
self.bbox[1] - distance if self.bbox[1] - distance > 0 else 0,
self.bbox[2] + distance if self.bbox[2] + distance < 1 else 1,
self.bbox[3] + distance if self.bbox[3] + distance < 1 else 1,
],
dtype=np.float32,
)
return self.bbox
def coords_are_normalized(coords):
# If every coordinate is between 0 and 1 (inclusive-ish), assume normalized
print(coords)
# if not coords:
# return False
return all(max(coords.flatten)) <= 1.001
def poly_to_pts(coords, img_w, img_h):
# coords: [x1 y1 x2 y2 ...] either normalized or absolute
# if coords_are_normalized(coords):
coords = [coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))]
pts = np.array(coords, dtype=np.int32).reshape(-1, 2)
return pts
pts = poly_to_pts(self.polygon, img_w, img_h)
line = LineString(pts)
# Buffer distance in pixels
buffered = line.buffer(distance=distance, cap_style=cap_style, join_style=join_style)
self.polygon = np.array(buffered.exterior.coords, dtype=np.float32) / (img_w, img_h)
xmn, ymn = self.polygon.min(axis=0)
xmx, ymx = self.polygon.max(axis=0)
xc = (xmn + xmx) / 2
yc = (ymn + ymx) / 2
bw = xmx - xmn
bh = ymx - ymn
self.bbox = np.array([xc, yc, bw, bh], dtype=np.float32)
return self.bbox, self.polygon
def translate(self, x, y, scale_x, scale_y):
self.bbox[0] -= x
self.bbox[0] *= scale_x
self.bbox[1] -= y
self.bbox[1] *= scale_y
self.bbox[2] *= scale_x
self.bbox[3] *= scale_y
if self.polygon is not None:
self.polygon[:, 0] -= x
self.polygon[:, 0] *= scale_x
self.polygon[:, 1] -= y
self.polygon[:, 1] *= scale_y
def in_range(self, hrange, wrange):
xc, yc, h, w = self.bbox
x1 = xc - w / 2
y1 = yc - h / 2
x2 = xc + w / 2
y2 = yc + h / 2
truth_val = (
xc >= wrange[0]
and x1 <= wrange[1]
and x2 >= wrange[0]
and x2 <= wrange[1]
and y1 >= hrange[0]
and y1 <= hrange[1]
and y2 >= hrange[0]
and y2 <= hrange[1]
)
print(x1, x2, wrange, y1, y2, hrange, truth_val)
return truth_val
def to_string(self, bbox: list = None, polygon: list = None):
coords = ""
if bbox is None:
bbox = self.bbox
# coords += " ".join([f"{x:.6f}" for x in self.bbox])
if polygon is None:
polygon = self.polygon
if self.polygon is not None:
coords += " " + " ".join([f"{x:.6f} {y:.6f}" for x, y in self.polygon])
return f"{self.class_id} {coords}"
def __str__(self):
return f"Class: {self.class_id}, BBox: {self.bbox}, Polygon: {self.polygon}"
class YoloLabelReader:
def __init__(self, label_path: Path):
self.label_path = label_path
self.labels = self._read_labels()
def _read_labels(self):
with open(self.label_path, "r") as f:
labels = [Label(line) for line in f.readlines()]
return labels
def get_labels(self, hrange, wrange):
"""hrange and wrange are tuples of (start, end) normalized to [0, 1]"""
labels = []
# print(hrange, wrange)
for lbl in self.labels:
# print(lbl)
if lbl.in_range(hrange, wrange):
labels.append(lbl)
return labels if len(labels) > 0 else None
def __get_item__(self, index):
return self.labels[index]
def __len__(self):
return len(self.labels)
def __iter__(self):
return iter(self.labels)
class ImageSplitter:
def __init__(self, image_path: Path, label_path: Path):
self.image = imread(image_path)
self.image_path = image_path
self.label_path = label_path
if not label_path.exists():
print(f"Label file {label_path} not found")
self.labels = None
else:
self.labels = YoloLabelReader(label_path)
def split_into_tiles(self, patch_size: tuple = (2, 2)):
"""Split image into patches of size patch_size"""
hstep, wstep = (
self.image.shape[0] // patch_size[0],
self.image.shape[1] // patch_size[1],
)
h, w = self.image.shape[:2]
for i in range(patch_size[0]):
for j in range(patch_size[1]):
metadata = {
"image_path": str(self.image_path),
"label_path": str(self.label_path),
"tile_section": f"{i}, {j}",
"tile_size": f"{hstep}, {wstep}",
"patch_size": f"{patch_size[0]}, {patch_size[1]}",
}
tile_reference = f"i{i}j{j}"
hrange = (i * hstep / h, (i + 1) * hstep / h)
wrange = (j * wstep / w, (j + 1) * wstep / w)
tile = self.image[i * hstep : (i + 1) * hstep, j * wstep : (j + 1) * wstep]
labels = None
if self.labels is not None:
labels = deepcopy(self.labels.get_labels(hrange, wrange))
print(id(labels))
if labels is not None:
print(hrange[0], wrange[0])
for l in labels:
print(l.bbox)
[l.translate(wrange[0], hrange[0], 2, 2) for l in labels]
print("translated")
for l in labels:
print(l.bbox)
# print(labels)
yield tile_reference, tile, labels, metadata
def split_respective_to_label(self, padding: int = 67):
if self.labels is None:
raise ValueError("No labels found. Only images having labels can be split.")
for i, label in enumerate(self.labels):
tile_reference = f"_lbl-{i+1:02d}"
# print(label.bbox)
metadata = {"image_path": str(self.image_path), "label_path": str(self.label_path), "label_index": str(i)}
xc_norm, yc_norm, h_norm, w_norm = label.bbox # normalized coords
xc, yc, h, w = [
int(np.round(f))
for f in [
xc_norm * self.image.shape[1],
yc_norm * self.image.shape[0],
h_norm * self.image.shape[0],
w_norm * self.image.shape[1],
]
] # image coords
# print("img coords:", xc, yc, h, w)
pad_xneg = padding + 1 # int(w / 2) + padding
pad_xpos = padding # int(w / 2) + padding
pad_yneg = padding + 1 # int(h / 2) + padding
pad_ypos = padding # int(h / 2) + padding
if xc - pad_xneg < 0:
pad_xneg = xc
if pad_xpos + xc > self.image.shape[1]:
pad_xpos = self.image.shape[1] - xc
if yc - pad_yneg < 0:
pad_yneg = yc
if pad_ypos + yc > self.image.shape[0]:
pad_ypos = self.image.shape[0] - yc
# print("pads:", pad_xneg, pad_xpos, pad_yneg, pad_ypos)
tile = self.image[
yc - pad_yneg : yc + pad_ypos,
xc - pad_xneg : xc + pad_xpos,
]
ny, nx = tile.shape
x_offset = pad_xneg
y_offset = pad_yneg
# print("tile shape:", tile.shape)
yolo_annotation = f"{label.class_id} " # {x_offset/nx} {y_offset/ny} {h_norm} {w_norm} "
yolo_annotation += " ".join(
[
f"{(x*self.image.shape[1]-(xc - x_offset))/nx:.6f} {(y*self.image.shape[0]-(yc-y_offset))/ny:.6f}"
for x, y in label.polygon
]
)
print(yolo_annotation)
new_label = Label(yolo_annotation=yolo_annotation)
yield tile_reference, tile, [new_label], metadata
def main(args):
if args.output:
args.output.mkdir(exist_ok=True, parents=True)
(args.output / "images").mkdir(exist_ok=True)
(args.output / "images-zoomed").mkdir(exist_ok=True)
(args.output / "labels").mkdir(exist_ok=True)
for image_path in (args.input / "images").glob("*.tif"):
data = ImageSplitter(
image_path=image_path,
label_path=(args.input / "labels" / image_path.stem).with_suffix(".txt"),
)
if args.split_around_label:
data = data.split_respective_to_label(padding=args.padding)
else:
data = data.split_into_tiles(patch_size=args.patch_size)
for tile_reference, tile, labels, metadata in data:
print()
print(tile_reference, tile.shape, labels, metadata) # len(labels) if labels else None)
# { debug
debug = False
if debug:
plt.figure(figsize=(10, 10 * tile.shape[0] / tile.shape[1]))
if labels is None:
plt.imshow(tile, cmap="gray")
plt.axis("off")
plt.title(f"{image_path.name} ({tile_reference})")
plt.show()
continue
print(labels[0].bbox)
# Draw annotations
out = draw_annotations(
cv2.cvtColor((tile / tile.max() * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR),
[l.to_string() for l in labels],
alpha=0.1,
)
# Convert BGR -> RGB for matplotlib display
out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
plt.imshow(out_rgb)
plt.axis("off")
plt.title(f"{image_path.name} ({tile_reference})")
plt.show()
# } debug
if args.output:
# imwrite(args.output / "images" / f"{image_path.stem}_{tile_reference}.tif", tile, metadata=metadata)
scale = 5
tile_zoomed = zoom(tile, zoom=scale)
metadata["scale"] = scale
imwrite(
args.output / "images" / f"{image_path.stem}_{tile_reference}.tif",
tile_zoomed,
metadata=metadata,
imagej=True,
)
if labels is not None:
with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f:
for label in labels:
# label.offset_label(tile.shape[1], tile.shape[0])
f.write(label.to_string() + "\n")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input", type=Path)
parser.add_argument("-o", "--output", type=Path)
parser.add_argument(
"-p",
"--patch-size",
nargs=2,
type=int,
default=[2, 2],
help="Number of patches along height and width, rows and columns, respectively",
)
parser.add_argument(
"-sal",
"--split-around-label",
action="store_true",
help="If enabled, the image will be split around the label and for each label, a separate image will be created.",
)
parser.add_argument(
"--padding",
type=int,
default=67,
help="Padding around the label when splitting around the label.",
)
args = parser.parse_args()
main(args)

1
src/utils/show_yolo_seg.py Symbolic link
View File

@@ -0,0 +1 @@
../../tests/show_yolo_seg.py

View File

@@ -0,0 +1,157 @@
"""Ultralytics runtime patches for 16-bit TIFF training.
Goals:
- Use `tifffile` to decode `.tif/.tiff` reliably (OpenCV can silently drop bit-depth depending on codec).
- Preserve 16-bit data through the dataloader as `uint16` tensors.
- Fix Ultralytics trainer normalization (default divides by 255) to scale `uint16` correctly.
- Avoid uint8-forcing augmentations by recommending/setting hyp values (handled by caller).
This module is intended to be imported/called **before** instantiating/using YOLO.
"""
from __future__ import annotations
from typing import Optional
from src.utils.logger import get_logger
logger = get_logger(__name__)
def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
"""Apply runtime monkey-patches to Ultralytics to better support 16-bit TIFFs.
This function is safe to call multiple times.
Args:
force: If True, re-apply patches even if already applied.
"""
# Import inside function to ensure patching occurs before YOLO model/dataset is created.
import os
import cv2
import numpy as np
# import tifffile
import torch
from src.utils.image import Image
from ultralytics.utils import patches as ul_patches
already_patched = getattr(ul_patches.imread, "__name__", "") == "tifffile_imread"
if already_patched and not force:
return
_original_imread = ul_patches.imread
def tifffile_imread(filename: str, flags: int = cv2.IMREAD_COLOR, pseudo_rgb: bool = True) -> Optional[np.ndarray]:
"""Replacement for [`ultralytics.utils.patches.imread()`](venv/lib/python3.12/site-packages/ultralytics/utils/patches.py:20).
- For `.tif/.tiff`, uses `tifffile.imread()` and preserves dtype (e.g. uint16).
- For other formats, falls back to Ultralytics' original implementation.
- Always returns HWC (3 dims). For grayscale, returns (H, W, 1) or (H, W, 3) depending on requested flags.
"""
# print("here")
# return _original_imread(filename, flags)
ext = os.path.splitext(filename)[1].lower()
if ext in (".tif", ".tiff"):
arr = Image(filename).get_qt_rgb()[:, :, :3]
# Normalize common shapes:
# - (H, W) -> (H, W, 1)
# - (C, H, W) -> (H, W, C) (heuristic)
if arr is None:
return None
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[0] < arr.shape[1]:
arr = np.transpose(arr, (1, 2, 0))
if arr.ndim == 2:
arr = arr[..., None]
# Ensure contiguous array for downstream OpenCV ops.
# logger.info(f"Loading with monkey-patched imread: {filename}")
arr = arr.astype(np.float32)
arr /= arr.max()
arr *= 2**8 - 1
arr = arr.astype(np.uint8)
# print(arr.shape, arr.dtype, any(np.isnan(arr).flatten()), np.where(np.isnan(arr)), arr.min(), arr.max())
return np.ascontiguousarray(arr)
# logger.info(f"Loading with original imread: {filename}")
return _original_imread(filename, flags)
# Patch the canonical reference.
ul_patches.imread = tifffile_imread
# Patch common module-level imports (some Ultralytics modules do `from ... import imread`).
# Importing these modules is safe and helps ensure the patched function is used.
try:
import ultralytics.data.base as _ul_base
_ul_base.imread = tifffile_imread
except Exception:
pass
try:
import ultralytics.data.loaders as _ul_loaders
_ul_loaders.imread = tifffile_imread
except Exception:
pass
# Patch trainer normalization: default divides by 255 regardless of input dtype.
from ultralytics.models.yolo.detect import train as detect_train
_orig_preprocess_batch = detect_train.DetectionTrainer.preprocess_batch
def preprocess_batch_16bit(self, batch: dict) -> dict: # type: ignore[override]
# Start from upstream behavior to keep device placement + multiscale identical,
# but replace the 255 division with dtype-aware scaling.
# logger.info(f"Preprocessing batch with monkey-patched preprocess_batch")
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
img = batch.get("img")
if isinstance(img, torch.Tensor):
# Decide scaling denom based on dtype (avoid expensive reductions if possible).
if img.dtype == torch.uint8:
denom = 255.0
elif img.dtype == torch.uint16:
denom = 65535.0
elif img.dtype.is_floating_point:
# Assume already in 0-1 range if float.
denom = 1.0
else:
# Generic integer fallback.
try:
denom = float(torch.iinfo(img.dtype).max)
except Exception:
denom = 255.0
batch["img"] = img.float() / denom
# Multi-scale branch copied from upstream to avoid re-introducing `/255` scaling.
if getattr(self.args, "multi_scale", False):
import math
import random
import torch.nn as nn
imgs = batch["img"]
sz = (
random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
// self.stride
* self.stride
)
sf = sz / max(imgs.shape[2:])
if sf != 1:
ns = [math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]]
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
batch["img"] = imgs
return batch
detect_train.DetectionTrainer.preprocess_batch = preprocess_batch_16bit
# Tag function to make it easier to detect patch state.
setattr(detect_train.DetectionTrainer.preprocess_batch, "_ultralytics_16bit_patch", True)

231
tests/show_yolo_seg.py Normal file
View File

@@ -0,0 +1,231 @@
#!/usr/bin/env python3
"""
show_yolo_seg.py
Usage:
python show_yolo_seg.py /path/to/image.jpg /path/to/labels.txt
Supports:
- Segmentation polygons: "class x1 y1 x2 y2 ... xn yn"
- YOLO bbox lines as fallback: "class x_center y_center width height"
Coordinates can be normalized [0..1] or absolute pixels (auto-detected).
"""
import sys
import cv2
import numpy as np
import matplotlib.pyplot as plt
import argparse
from pathlib import Path
import random
from shapely.geometry import LineString
from src.utils.image import Image
def parse_label_line(line):
parts = line.strip().split()
if not parts:
return None
cls = int(float(parts[0]))
coords = [float(x) for x in parts[1:]]
return cls, coords
def coords_are_normalized(coords):
# If every coordinate is between 0 and 1 (inclusive-ish), assume normalized
if not coords:
return False
return max(coords) <= 1.001
def yolo_bbox_to_xyxy(coords, img_w, img_h):
# coords: [xc, yc, w, h] normalized or absolute
xc, yc, w, h = coords[:4]
if max(coords) <= 1.001:
xc *= img_w
yc *= img_h
w *= img_w
h *= img_h
x1 = int(round(xc - w / 2))
y1 = int(round(yc - h / 2))
x2 = int(round(xc + w / 2))
y2 = int(round(yc + h / 2))
return x1, y1, x2, y2
def poly_to_pts(coords, img_w, img_h):
# coords: [x1 y1 x2 y2 ...] either normalized or absolute
if coords_are_normalized(coords[4:]):
coords = [coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))]
pts = np.array(coords, dtype=np.int32).reshape(-1, 2)
return pts
def random_color_for_class(cls):
random.seed(cls) # deterministic per class
return (
0,
0,
255,
) # tuple(int(x) for x in np.array([random.randint(0, 255) for _ in range(3)]))
def draw_annotations(img, labels, alpha=0.4, draw_bbox_for_poly=True):
# img: BGR numpy array
overlay = img.copy()
h, w = img.shape[:2]
for line in labels:
if isinstance(line, str):
cls, coords = parse_label_line(line)
if isinstance(line, tuple):
cls, coords = line
if not coords:
continue
# polygon case (>=6 coordinates)
if len(coords) >= 6:
color = random_color_for_class(cls)
x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords[:4], w, h)
print(x1, y1, x2, y2)
cv2.rectangle(img, (x1, y1), (x2, y2), color, 1)
pts = poly_to_pts(coords[4:], w, h)
# line = LineString(pts)
# # Buffer distance in pixels
# buffered = line.buffer(3, cap_style=2, join_style=2)
# coords = np.array(buffered.exterior.coords, dtype=np.int32)
# cv2.fillPoly(overlay, [coords], color=(255, 255, 255))
# fill on overlay
cv2.fillPoly(overlay, [pts], color)
# outline on base image
cv2.polylines(img, [pts], isClosed=True, color=color, thickness=1)
# put class text at first point
x, y = int(pts[0, 0]), int(pts[0, 1]) - 6
if 0:
cv2.putText(
img,
str(cls),
(x, max(6, y)),
cv2.FONT_HERSHEY_SIMPLEX,
0.6,
(255, 255, 255),
2,
cv2.LINE_AA,
)
# YOLO bbox case (4 coords)
elif len(coords) == 4:
x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords, w, h)
color = random_color_for_class(cls)
cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
cv2.putText(
img,
str(cls),
(x1, max(6, y1 - 4)),
cv2.FONT_HERSHEY_SIMPLEX,
0.6,
(255, 255, 255),
2,
cv2.LINE_AA,
)
else:
# Unknown / invalid format, skip
continue
# blend overlay for filled polygons
cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img)
return img
def load_labels_file(label_path):
labels = []
with open(label_path, "r") as f:
for raw in f:
line = raw.strip()
if not line:
continue
parsed = parse_label_line(line)
if parsed:
labels.append(parsed)
return labels
def main():
parser = argparse.ArgumentParser(description="Show YOLO segmentation / polygon annotations")
parser.add_argument("image", type=str, help="Path to image file")
parser.add_argument("--labels", type=str, help="Path to YOLO label file (polygons)")
parser.add_argument("--alpha", type=float, default=0.4, help="Polygon fill alpha (0..1)")
parser.add_argument("--no-bbox", action="store_true", help="Don't draw bounding boxes for polygons")
args = parser.parse_args()
print(args)
img_path = Path(args.image)
if args.labels:
lbl_path = Path(args.labels)
else:
lbl_path = img_path.with_suffix(".txt")
lbl_path = Path(str(lbl_path).replace("images", "labels"))
if not img_path.exists():
print("Image not found:", img_path)
sys.exit(1)
if not lbl_path.exists():
print("Label file not found:", lbl_path)
sys.exit(1)
# img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
img = (Image(img_path).get_qt_rgb() * 255).astype(np.uint8)
if img is None:
print("Could not load image:", img_path)
sys.exit(1)
labels = load_labels_file(str(lbl_path))
if not labels:
print("No labels parsed from", lbl_path)
# continue and just show image
out = draw_annotations(img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox))
out_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
if 0:
plt.imshow(out_rgb.transpose(1, 0, 2))
else:
plt.imshow(out_rgb)
for label in labels:
lclass, coords = label
# print(lclass, coords)
bbox = coords[:4]
# print("bbox", bbox)
bbox = np.array(bbox) * np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
yc, xc, h, w = bbox
# print("bbox", bbox)
# polyline = np.array(coords[4:]).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
polyline = np.array(coords).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
# print("pl", coords[4:])
# print("pl", polyline)
# Convert BGR -> RGB for matplotlib display
# out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
# out_rgb = Image()
plt.plot(polyline[:, 0], polyline[:, 1], "y", linewidth=2)
if 0:
plt.plot(
[yc - h / 2, yc - h / 2, yc + h / 2, yc + h / 2, yc - h / 2],
[xc - w / 2, xc + w / 2, xc + w / 2, xc - w / 2, xc - w / 2],
"r",
linewidth=2,
)
# plt.axis("off")
plt.title(f"{img_path.name} ({lbl_path.name})")
plt.show()
if __name__ == "__main__":
main()

145
tests/test_image.py Normal file
View File

@@ -0,0 +1,145 @@
"""
Tests for the Image class.
"""
import pytest
import numpy as np
from pathlib import Path
from src.utils.image import Image, ImageLoadError
class TestImage:
"""Test cases for the Image class."""
def test_load_nonexistent_file(self):
"""Test loading a non-existent file raises ImageLoadError."""
with pytest.raises(ImageLoadError):
Image("nonexistent_file.jpg")
def test_load_unsupported_format(self, tmp_path):
"""Test loading an unsupported format raises ImageLoadError."""
# Create a dummy file with unsupported extension
test_file = tmp_path / "test.txt"
test_file.write_text("not an image")
with pytest.raises(ImageLoadError):
Image(test_file)
def test_supported_extensions(self):
"""Test that supported extensions are correctly defined."""
expected_extensions = Image.SUPPORTED_EXTENSIONS
assert Image.SUPPORTED_EXTENSIONS == expected_extensions
def test_image_properties(self, tmp_path):
"""Test image properties after loading."""
# Create a simple test image using numpy and cv2
import cv2
test_img = np.zeros((100, 200, 3), dtype=np.uint8)
test_img[:, :] = [255, 0, 0] # Blue in BGR
test_file = tmp_path / "test.jpg"
cv2.imwrite(str(test_file), test_img)
# Load the image
img = Image(test_file)
# Check properties
assert img.width == 200
assert img.height == 100
assert img.channels == 3
assert img.format == "jpg"
assert img.shape == (100, 200, 3)
assert img.size_bytes > 0
assert img.is_color()
assert not img.is_grayscale()
def test_get_rgb(self, tmp_path):
"""Test RGB conversion."""
import cv2
# Create BGR image
test_img = np.zeros((50, 50, 3), dtype=np.uint8)
test_img[:, :] = [255, 0, 0] # Blue in BGR
test_file = tmp_path / "test_rgb.png"
cv2.imwrite(str(test_file), test_img)
img = Image(test_file)
rgb_data = img.get_rgb()
# RGB should have red channel at 255
assert rgb_data[0, 0, 0] == 0 # R
assert rgb_data[0, 0, 1] == 0 # G
assert rgb_data[0, 0, 2] == 255 # B (was BGR blue)
def test_get_grayscale(self, tmp_path):
"""Test grayscale conversion."""
import cv2
test_img = np.zeros((50, 50, 3), dtype=np.uint8)
test_img[:, :] = [128, 128, 128]
test_file = tmp_path / "test_gray.png"
cv2.imwrite(str(test_file), test_img)
img = Image(test_file)
gray_data = img.get_grayscale()
assert len(gray_data.shape) == 2 # Should be 2D
assert gray_data.shape == (50, 50)
def test_copy(self, tmp_path):
"""Test copying image data."""
import cv2
test_img = np.zeros((50, 50, 3), dtype=np.uint8)
test_file = tmp_path / "test_copy.png"
cv2.imwrite(str(test_file), test_img)
img = Image(test_file)
copied = img.copy()
# Modify copy
copied[0, 0] = [255, 255, 255]
# Original should be unchanged
assert not np.array_equal(img.data[0, 0], copied[0, 0])
def test_resize(self, tmp_path):
"""Test image resizing."""
import cv2
test_img = np.zeros((100, 100, 3), dtype=np.uint8)
test_file = tmp_path / "test_resize.png"
cv2.imwrite(str(test_file), test_img)
img = Image(test_file)
resized = img.resize(50, 50)
assert resized.shape == (50, 50, 3)
# Original should be unchanged
assert img.width == 100
assert img.height == 100
def test_str_repr(self, tmp_path):
"""Test string representation."""
import cv2
test_img = np.zeros((100, 200, 3), dtype=np.uint8)
test_file = tmp_path / "test_str.jpg"
cv2.imwrite(str(test_file), test_img)
img = Image(test_file)
str_repr = str(img)
assert "test_str.jpg" in str_repr
assert "100x200x3" in str_repr
assert "jpg" in str_repr
repr_str = repr(img)
assert "Image" in repr_str
assert "test_str.jpg" in repr_str