Compare commits

40 Commits

c7e1271193...monkey-pat
| SHA1 |
|------|
| d03ffdc4d0 |
| 8d30e6bb7a |
| f810fec4d8 |
| 9c8931e6f3 |
| 20578c1fdf |
| 2c494dac49 |
| 506c74e53a |
| eefda5b878 |
| 31cb6a6c8e |
| 0c19ea2557 |
| 89e47591db |
| 69cde09e53 |
| fcbd5fb16d |
| ca52312925 |
| 0a93bf797a |
| d998c65665 |
| 510eabfa94 |
| 395d263900 |
| e98d287b8a |
| d25101de2d |
| f88beef188 |
| 2fd9a2acf4 |
| 2bcd18cc75 |
| 5d25378c46 |
| 2b0b48921e |
| b0c05f0225 |
| 97badaa390 |
| 8f8132ce61 |
| 6ae7481e25 |
| 061f8b3ca2 |
| a8e5db3135 |
| 268ed5175e |
| 5e9d3b1dc4 |
| 7d83e9b9b1 |
| e364d06217 |
| e5036c10cf |
| c7e388d9ae |
| 6b995e7325 |
| 0e0741d323 |
| dd99a0677c |
@@ -1,57 +0,0 @@
database:
  path: data/detections.db
image_repository:
  base_path: ''
  allowed_extensions:
  - .jpg
  - .jpeg
  - .png
  - .tif
  - .tiff
  - .bmp
models:
  default_base_model: yolov8s-seg.pt
  models_directory: data/models
  base_model_choices:
  - yolov8s-seg.pt
  - yolo11s-seg.pt
training:
  default_epochs: 100
  default_batch_size: 16
  default_imgsz: 1024
  default_patience: 50
  default_lr0: 0.01
  two_stage:
    enabled: false
    stage1:
      epochs: 20
      lr0: 0.0005
      patience: 10
      freeze: 10
    stage2:
      epochs: 150
      lr0: 0.0003
      patience: 30
  last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml
  last_dataset_dir: /home/martin/code/object_detection/data/datasets
detection:
  default_confidence: 0.25
  default_iou: 0.45
  max_batch_size: 100
visualization:
  bbox_colors:
    organelle: '#FF6B6B'
    membrane_branch: '#4ECDC4'
    default: '#00FF00'
  bbox_thickness: 2
  font_size: 12
export:
  formats:
  - csv
  - json
  - excel
  default_format: csv
logging:
  level: INFO
  file: logs/app.log
  format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
@@ -1,300 +0,0 @@
# 16-bit TIFF Support for YOLO Object Detection

## Overview

This document describes the implementation of 16-bit grayscale TIFF support for YOLO object detection. The system loads 16-bit TIFF images, normalizes them to float32 [0-1], and handles them appropriately for both **inference** and **training**, **without uint8 conversion**, preserving the full dynamic range and avoiding data loss.

## Key Features

✅ Reads 16-bit or float32 images using tifffile
✅ Converts to float32 [0-1] (NO uint8 conversion)
✅ Replicates grayscale → RGB (3 channels)
✅ **Inference**: passes numpy arrays directly to YOLO (no file I/O)
✅ **Training**: on-the-fly float32 conversion (NO disk caching)
✅ Uses Ultralytics YOLOv8/v11 models
✅ Works with segmentation models
✅ No data loss, no double normalization, no silent clipping

## Changes Made

### 1. Dependencies ([`requirements.txt`](../requirements.txt:14))
- Added `tifffile>=2023.0.0` for reliable 16-bit TIFF loading

### 2. Image Loading ([`src/utils/image.py`](../src/utils/image.py))

#### Enhanced TIFF Loading
- Modified [`Image._load()`](../src/utils/image.py:87) to use `tifffile` for `.tif` and `.tiff` files
- Preserves the original 16-bit data type during loading
- Handles both grayscale and multi-channel TIFF files

#### New Normalization Method
Added an [`Image.to_normalized_float32()`](../src/utils/image.py:280) method that:
- Converts image data to `float32`
- Scales values to the [0, 1] range:
  - **16-bit images**: divides by 65535 (full dynamic range)
  - 8-bit images: divides by 255
  - Float images: clips to [0, 1]
- Handles various data types automatically
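A minimal standalone sketch of this normalization logic (illustrative only; the project's actual method lives in `src/utils/image.py`):

```python
import numpy as np

def to_normalized_float32(data: np.ndarray) -> np.ndarray:
    """Normalize an image array to float32 in [0, 1] based on its dtype."""
    if data.dtype == np.uint16:
        return data.astype(np.float32) / 65535.0  # full 16-bit range
    if data.dtype == np.uint8:
        return data.astype(np.float32) / 255.0    # 8-bit range
    # Already floating point: just clamp into the expected range.
    return np.clip(data.astype(np.float32), 0.0, 1.0)
```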
### 3. YOLO Preprocessing ([`src/model/yolo_wrapper.py`](../src/model/yolo_wrapper.py))

Enhanced [`YOLOWrapper._prepare_source()`](../src/model/yolo_wrapper.py:231) to:
1. Detect 16-bit TIFF files automatically
2. Load and normalize to float32 [0-1] using the new method
3. Replicate grayscale to RGB (3 channels)
4. **Return the numpy array directly** (NO file saving, NO uint8 conversion)
5. Pass the float32 array directly to YOLO for inference
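A condensed sketch of that dispatch, reusing the helper above (names are illustrative; the real logic lives in `_prepare_source()`):

```python
from pathlib import Path

import numpy as np
import tifffile

def prepare_source(path: str):
    """Return a float32 [0-1] RGB array for 16-bit TIFFs, else the path itself."""
    if Path(path).suffix.lower() in {".tif", ".tiff"}:
        data = tifffile.imread(path)  # preserves uint16
        if data.dtype == np.uint16 and data.ndim == 2:  # single-channel assumption
            gray = data.astype(np.float32) / 65535.0
            return np.stack([gray] * 3, axis=-1)  # handed to YOLO, no temp file
    return path  # regular images: let Ultralytics load them as usual
```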
## Processing Pipeline

### For Inference (predict)

For 16-bit TIFF files during inference:

1. **Load**: the file is read with `tifffile`, preserving the uint16 data
2. **Normalize**: convert to float32 and scale to [0, 1]
   ```python
   float_data = uint16_data.astype(np.float32) / 65535.0
   ```
3. **RGB Conversion**: replicate grayscale to 3 channels
   ```python
   rgb_float = np.stack([float_data] * 3, axis=-1)
   ```
4. **Pass to YOLO**: return the float32 array directly (no uint8, no file I/O)
5. **Inference**: YOLO processes the float32 [0-1] RGB array

### For Training (train)

Training uses a custom dataset loader with on-the-fly conversion (NO disk caching):

1. **Custom Dataset**: a `Float32Dataset` class extends Ultralytics' `YOLODataset`
2. **Load On-The-Fly**: each image is loaded and converted during training:
   - Detect 16-bit TIFF files automatically
   - Load with `tifffile` (preserves uint16)
   - Convert to float32 [0-1] in memory
   - Replicate to 3 channels (RGB)
3. **No Disk Cache**: conversion happens in memory; no files are written
4. **Train**: YOLO trains on float32 [0-1] RGB arrays directly

See [`src/utils/train_ultralytics_float.py`](../src/utils/train_ultralytics_float.py) for the implementation; a rough sketch follows.
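A rough sketch of such a loader, assuming Ultralytics' `BaseDataset.load_image(i, rect_mode=True)` hook (the hook's signature and return convention vary across Ultralytics versions, and the resize to `imgsz` that the stock loader performs is omitted for brevity):

```python
import numpy as np
import tifffile
from ultralytics.data.dataset import YOLODataset

class Float32Dataset(YOLODataset):
    """YOLODataset that feeds 16-bit TIFFs as float32 [0-1] RGB arrays."""

    def load_image(self, i, rect_mode=True):
        path = self.im_files[i]
        if path.lower().endswith((".tif", ".tiff")):
            im = tifffile.imread(path)
            if im.dtype == np.uint16:
                im = im.astype(np.float32) / 65535.0          # keep the full range
                if im.ndim == 2:
                    im = np.repeat(im[..., None], 3, axis=2)  # gray -> RGB
                h0, w0 = im.shape[:2]
                return im, (h0, w0), im.shape[:2]             # (image, orig hw, resized hw)
        return super().load_image(i, rect_mode)               # everything else: stock path
```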
### No Data Loss!

Unlike approaches that convert to uint8 (256 levels), this implementation:
- Preserves the full 16-bit dynamic range (65536 levels)
- Maintains precision with a float32 representation
- For inference: passes data directly, without file conversions
- For training: converts to float32 in memory (no uint8 PNGs)

## Usage

### Basic Image Loading

```python
from src.utils.image import Image

# Load a 16-bit TIFF file
img = Image("path/to/16bit_image.tif")

# Get normalized float32 data [0-1]
normalized = img.to_normalized_float32()  # Shape: (H, W), dtype: float32

# Original data is preserved
original = img.data  # Still uint16
```

### YOLO Inference

The preprocessing is automatic; just use YOLO as normal:

```python
from src.model.yolo_wrapper import YOLOWrapper

# Initialize model
yolo = YOLOWrapper("yolov8s-seg.pt")
yolo.load_model()

# Perform inference on a 16-bit TIFF
# The image will be automatically normalized and passed as float32 [0-1]
detections = yolo.predict("path/to/16bit_image.tif", conf=0.25)
```

### With InferenceEngine

```python
from src.model.inference import InferenceEngine
from src.database.db_manager import DatabaseManager

# Setup
db = DatabaseManager("database.db")
engine = InferenceEngine("model.pt", db, model_id=1)

# Detect objects in a 16-bit TIFF
result = engine.detect_single(
    image_path="path/to/16bit_image.tif",
    relative_path="images/16bit_image.tif",
    conf=0.25,
)
```

## Testing

Three test scripts are provided:

### 1. Image Loading Test
```bash
./venv/bin/python tests/test_16bit_tiff_loading.py
```

Tests:
- Loading 16-bit TIFF files with tifffile
- Normalization to float32 [0-1]
- Data type and value range verification

### 2. Float32 Passthrough Test (Most Important!)
```bash
./venv/bin/python tests/test_yolo_16bit_float32.py
```

Tests:
- YOLO preprocessing returns a numpy array (not a file path)
- Data is float32 [0-1] (not uint8)
- No quantization to 256 levels (proves no uint8 conversion)

Sample output:
```
✓ SUCCESS: Prepared source is a numpy array (float32 passthrough)
  Shape: (200, 200, 3)
  Dtype: float32
  Min value: 0.000000
  Max value: 1.000000
  Unique values: 399

✓ SUCCESS: Data has 399 unique values (> 256)
  This confirms NO uint8 quantization occurred!
```

### 3. Legacy Test (Shows Old Behavior)
```bash
./venv/bin/python tests/test_yolo_16bit_preprocessing.py
```

This test shows the old behavior (uint8 conversion); it is kept for comparison.

## Benefits

1. **No Data Loss**: preserves the full 16-bit dynamic range (65536 levels vs 256)
2. **High Precision**: float32 maintains fine-grained intensity differences
3. **Automatic Processing**: no manual preprocessing needed
4. **YOLO Compatible**: Ultralytics YOLO accepts float32 [0-1] arrays
5. **Performance**: no intermediate file I/O for 16-bit TIFFs
6. **Backwards Compatible**: regular images (8-bit PNG, JPEG, etc.) still work as before

## Technical Notes

### Float32 vs uint8

**With uint8 conversion (OLD - BAD):**
- 16-bit (65536 levels) → uint8 (256 levels) = **99.6% data loss!**
- Fine intensity differences are lost
- Quantization artifacts

**With float32 [0-1] (NEW - GOOD):**
- 16-bit (65536 levels) → float32 (continuous) = **no data loss**
- Full dynamic range preserved
- Smooth gradients maintained

### Memory Considerations

For a 2048×2048 single-channel image:

| Format | Memory | Disk Space | Notes |
|--------|--------|------------|-------|
| Original 16-bit | 8 MB | ~8 MB | uint16 grayscale TIFF |
| Float32 grayscale | 16 MB | - | Intermediate |
| Float32 3-channel | 48 MB | ~48 MB | Training cache (old cached approach) |
| uint8 RGB (old) | 12 MB | ~12 MB | OLD approach, with data loss |

The float32 approach uses ~3× more memory than uint8 during training but preserves **all information**.

**No Disk Cache**: the on-the-fly approach eliminates the need for cached datasets on disk.

### Why Direct Numpy Array?

Passing numpy arrays directly to YOLO (instead of saving to file):

1. **Faster**: no disk I/O overhead
2. **No Quantization**: avoids PNG/JPEG quantization
3. **Memory Efficient**: single copy in memory
4. **Cleaner**: no temp file management

Ultralytics YOLO supports various input types:
- File paths (str): `"image.jpg"`
- Numpy arrays: `np.ndarray` ← **we use this**
- PIL Images: `PIL.Image`
- Torch tensors: `torch.Tensor`
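For illustration, a preprocessed array can then be handed straight to `predict()`; this assumes, as the notes above state, that the installed Ultralytics version accepts float32 [0-1] arrays:

```python
import numpy as np
from ultralytics import YOLO

model = YOLO("yolov8s-seg.pt")
frame = np.random.rand(640, 640, 3).astype(np.float32)  # stand-in float32 [0-1] RGB
results = model.predict(frame, conf=0.25)               # no file ever touches disk
```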
## Training with Float32 Dataset Loader

The system includes a custom dataset loader for 16-bit TIFF training:

```python
from src.utils.train_ultralytics_float import train_with_float32_loader

# Train with on-the-fly float32 conversion
results = train_with_float32_loader(
    model_path="yolov8s-seg.pt",
    data_yaml="data/my_dataset/data.yaml",
    epochs=100,
    batch=16,
    imgsz=640,
)
```

The `Float32Dataset` class automatically:
- Detects 16-bit TIFF files
- Loads them with `tifffile` (not PIL/cv2)
- Converts to float32 [0-1] on the fly
- Replicates to 3 channels
- Integrates seamlessly with the Ultralytics training pipeline

This is used automatically by the training tab in the GUI.

## Installation

Install the updated dependencies:

```bash
./venv/bin/pip install -r requirements.txt
```

Or install tifffile directly (quoted so the shell doesn't treat `>=` as a redirect):

```bash
./venv/bin/pip install "tifffile>=2023.0.0"
```

## Example Test Output

```
=== Testing Float32 Passthrough (NO uint8) ===
Created test 16-bit TIFF: /tmp/tmpdt5hm0ab.tif
  Shape: (200, 200)
  Dtype: uint16
  Min value: 0
  Max value: 65535

Preprocessing result:
  Prepared source type: <class 'numpy.ndarray'>

✓ SUCCESS: Prepared source is a numpy array (float32 passthrough)
  Shape: (200, 200, 3)
  Dtype: float32
  Min value: 0.000000
  Max value: 1.000000
  Mean value: 0.499992
  Unique values: 399

✓ SUCCESS: Data has 399 unique values (> 256)
  This confirms NO uint8 quantization occurred!

✓ All float32 passthrough tests passed!
```
@@ -1,269 +0,0 @@
# Training YOLO with 16-bit TIFF Datasets

## Quick Start

If your dataset contains 16-bit grayscale TIFF files, the training tab will automatically:

1. Detect 16-bit TIFF images in your dataset
2. Convert them to float32 [0-1] RGB **on-the-fly** during training
3. Train without any disk caching (memory-efficient)

**No manual intervention or extra disk space needed!**

## Why Float32 On-The-Fly Conversion?

### The Problem

YOLO's training expects:
- 3-channel images (RGB)
- Images loaded from disk by the dataloader

16-bit grayscale TIFFs are:
- 1-channel (grayscale)
- In need of conversion to RGB format

### The Solution

**NEW APPROACH (Current)**: on-the-fly float32 conversion
- Load 16-bit TIFF with `tifffile` (not PIL/cv2)
- Convert uint16 [0-65535] → float32 [0-1] in memory
- Replicate grayscale to 3 channels
- Pass directly to the YOLO training pipeline
- **No disk caching required!**

**OLD APPROACH (Deprecated)**: disk caching
- Created 16-bit RGB PNG cache files on disk
- Required ~2× the dataset size in disk space
- Slower first training run

## How It Works

### Custom Dataset Loader

The system uses a custom `Float32Dataset` class that extends Ultralytics' `YOLODataset`:

```python
from src.utils.train_ultralytics_float import Float32Dataset

# This dataset loader:
# 1. Intercepts image loading
# 2. Detects 16-bit TIFFs
# 3. Converts to float32 [0-1] RGB on-the-fly
# 4. Passes to the training pipeline
```

### Conversion Process

For each 16-bit grayscale TIFF during training:

```
1. Load with tifffile   → uint16 [0, 65535]
2. Convert to float32   → img.astype(float32) / 65535.0
3. Replicate to RGB     → np.stack([img] * 3, axis=-1)
4. Result: float32 [0, 1] RGB array, shape (H, W, 3)
```

### Memory vs Disk

| Aspect | On-the-fly (NEW) | Disk Cache (OLD) |
|--------|------------------|------------------|
| Disk Space | Dataset size only | ~2× dataset size |
| First Training | Fast | Slow (creates cache) |
| Subsequent Training | Fast | Fast |
| Data Loss | None | None |
| Setup Required | None | Cache creation |

## Data Preservation

### Float32 Precision

16-bit TIFF: 65,536 levels (0-65535)
Float32: ~7 decimal digits of precision

**Conversion accuracy:**
```
Original: 32768 (uint16, middle intensity)
Float32:  32768 / 65535 ≈ 0.50000763
```

Full 16-bit precision is preserved in the float32 representation; a quick check follows.
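A quick numpy check of that claim; every 16-bit value survives the float32 round trip:

```python
import numpy as np

u = np.arange(65536, dtype=np.uint16)          # every possible 16-bit value
f = u.astype(np.float32) / 65535.0             # forward conversion
back = np.rint(f * 65535.0).astype(np.uint16)  # inverse mapping
assert np.array_equal(u, back)                 # lossless round trip
print(np.unique(f).size)                       # 65536 distinct float32 values
```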
### Comparison to uint8

| Approach | Precision Loss | Recommended |
|----------|----------------|-------------|
| **float32 [0-1]** | None | ✓ YES |
| uint16 RGB | None | ✓ YES (but disk-heavy) |
| uint8 | 99.6% data loss | ✗ NO |

**Why NO uint8:**
```
Original values:    32768, 32769, 32770 (distinct)
Converted to uint8: 128, 128, 128 (collapsed!)
```

Multiple 16-bit values collapse to the same uint8 value, as the demo below shows.
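The same collapse, reproduced with numpy (using round-to-nearest uint8 downscaling):

```python
import numpy as np

vals = np.array([32768, 32769, 32770], dtype=np.uint16)
as_u8 = np.rint(vals.astype(np.float32) / 65535.0 * 255.0).astype(np.uint8)
print(as_u8)  # [128 128 128]: three distinct inputs, one output level
```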
## Training Tab Behavior

When you click "Start Training" with a 16-bit TIFF dataset:

```
[01:23:45] Exported 150 annotations across 50 image(s).
[01:23:45] Using Float32 on-the-fly loader for 16-bit TIFF support (no disk caching)
[01:23:45] Starting training run 'my_model_v1' using yolov8s-seg.pt
[01:23:46] Using Float32Dataset loader for 16-bit TIFF support
```

Every training run uses the same approach: fast and efficient.

## Inference vs Training

| Operation | Input | Processing | Output to YOLO |
|-----------|-------|------------|----------------|
| **Inference** | 16-bit TIFF file | Load → float32 [0-1] → 3ch | numpy array (float32) |
| **Training** | 16-bit TIFF dataset | Load on-the-fly → float32 [0-1] → 3ch | numpy array (float32) |

Both preserve full 16-bit precision using the float32 representation.

## Technical Details

### Custom Dataset Class

Located in `src/utils/train_ultralytics_float.py`:

```python
class Float32Dataset(YOLODataset):
    """
    Extends Ultralytics YOLODataset to handle 16-bit TIFFs.

    Key methods:
    - load_image(): intercepts image loading
    - Detects .tif/.tiff with dtype == uint16
    - Converts: uint16 → float32 [0-1] → RGB (3-channel)
    """
```

### Integration with YOLO

The `YOLOWrapper.train()` method automatically uses the custom loader:

```python
# In src/model/yolo_wrapper.py
def train(self, data_yaml, use_float32_loader=True, **kwargs):
    if use_float32_loader:
        # Use the custom Float32Dataset
        return train_with_float32_loader(...)
    else:
        # Standard YOLO training
        return self.model.train(...)
```

### No PIL or cv2 for 16-bit

16-bit TIFF loading uses `tifffile` directly:
- PIL: can load 16-bit but converts during processing
- cv2: limited 16-bit TIFF support
- tifffile: native 16-bit support, numpy output
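A small sanity check of the tifffile path (the file name is a stand-in):

```python
import tifffile

img = tifffile.imread("sample_16bit.tif")  # hypothetical input file
print(img.dtype)  # uint16: nothing was converted behind your back
print(img.max())  # up to 65535, full range intact
```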
## Advantages Over Disk Caching

### 1. No Disk Space Required
```
Dataset:      1000 images × 12 MB = 12 GB
Old cache:    additional 24 GB (16-bit RGB PNGs)
New approach: 0 GB additional (on-the-fly)
```

### 2. Faster Setup
```
Old: first training requires cache creation (minutes)
New: start training immediately (seconds)
```

### 3. Always In Sync
```
Old: cache could become stale if images change
New: always loads the current version from disk
```

### 4. Simpler Workflow
```
Old: manage cache directory, cleanup, etc.
New: just point to the dataset and train
```

## Troubleshooting

### Error: "expected input to have 3 channels, but got 1"

This shouldn't happen with the new `Float32Dataset`, but if it does:

1. Check that `use_float32_loader=True` in the training call
2. Verify that `Float32Dataset` is being used (check the logs)
3. Ensure `tifffile` is installed: `pip install tifffile`

### Memory Usage

On-the-fly conversion uses memory during training:
- Image loaded: ~8 MB (2048×2048 uint16)
- Converted float32 RGB: ~48 MB (temporary)
- Released after the augmentation pipeline

**Mitigation:**
- Reduce the batch size if OOM errors occur
- Images are processed one at a time during loading
- Only the active batch is kept in memory

### Slow Training

If training seems slow:
- Check disk I/O (a slow disk can bottleneck loading)
- Verify images aren't being re-converted each epoch (they should be cached after the first load)
- Monitor CPU usage during loading

## Migration from Old Approach

If you have existing cached datasets:

```bash
# Old cache location (safe to delete)
rm -rf data/datasets/_float32_cache/

# The new approach doesn't use this directory
```

Your original dataset structure remains unchanged:
```
data/my_dataset/
├── train/
│   ├── images/   (original 16-bit TIFFs)
│   └── labels/
├── val/
│   ├── images/
│   └── labels/
└── data.yaml
```

Just point to the same `data.yaml` and train!

## Performance Comparison

| Metric | Old (Disk Cache) | New (On-the-fly) |
|--------|------------------|------------------|
| First training setup | 5-10 min | 0 sec |
| Disk space overhead | 100% | 0% |
| Training speed | Fast | Fast |
| Subsequent runs | Fast | Fast |
| Data accuracy | 16-bit preserved | 16-bit preserved |

## Summary

✓ **On-the-fly conversion**: load and convert during training
✓ **No disk caching**: zero additional disk space
✓ **Full precision**: float32 preserves the 16-bit dynamic range
✓ **No PIL/cv2**: direct tifffile loading
✓ **Automatic**: works transparently with the training tab
✓ **Fast**: efficient memory-based conversion

The new approach is simpler, faster to set up, and requires no disk space overhead.
@@ -82,12 +82,12 @@ include-package-data = true
 "src.database" = ["*.sql"]
 
 [tool.black]
-line-length = 88
+line-length = 120
 target-version = ['py38', 'py39', 'py310', 'py311']
 include = '\.pyi?$'
 
 [tool.pylint.messages_control]
-max-line-length = 88
+max-line-length = 120
 
 [tool.mypy]
 python_version = "3.8"
@@ -11,7 +11,6 @@ pyqtgraph>=0.13.0
 opencv-python>=4.8.0
 Pillow>=10.0.0
 numpy>=1.24.0
-tifffile>=2023.0.0
 
 # Database
 sqlalchemy>=2.0.0
@@ -1,179 +0,0 @@
# Standalone Float32 Training Script for 16-bit TIFFs

## Overview

This standalone script (`train_float32_standalone.py`) trains YOLO models on 16-bit grayscale TIFF datasets with **no data loss**.

- Loads 16-bit TIFFs with `tifffile` (not PIL/cv2)
- Converts to float32 [0-1] on-the-fly (preserves full 16-bit precision)
- Replicates grayscale → 3-channel RGB in memory
- **No disk caching required**
- Uses a custom PyTorch Dataset + training loop

## Quick Start

```bash
# Activate the virtual environment
source venv/bin/activate

# Train on your 16-bit TIFF dataset
python scripts/train_float32_standalone.py \
    --data data/my_dataset/data.yaml \
    --weights yolov8s-seg.pt \
    --epochs 100 \
    --batch 16 \
    --imgsz 640 \
    --lr 0.0001 \
    --save-dir runs/my_training \
    --device cuda
```

## Arguments

| Argument | Required | Default | Description |
|----------|----------|---------|-------------|
| `--data` | Yes | - | Path to the YOLO data.yaml file |
| `--weights` | No | yolov8s-seg.pt | Pretrained model weights |
| `--epochs` | No | 100 | Number of training epochs |
| `--batch` | No | 16 | Batch size |
| `--imgsz` | No | 640 | Input image size |
| `--lr` | No | 0.0001 | Learning rate |
| `--save-dir` | No | runs/train | Directory to save checkpoints |
| `--device` | No | cuda/cpu | Training device (auto-detected) |

## Dataset Format

Your data.yaml should follow the standard YOLO format:

```yaml
path: /path/to/dataset
train: train/images
val: val/images
test: test/images  # optional

names:
  0: class1
  1: class2

nc: 2
```

Directory structure:
```
dataset/
├── train/
│   ├── images/
│   │   ├── img1.tif   (16-bit grayscale TIFF)
│   │   └── img2.tif
│   └── labels/
│       ├── img1.txt   (YOLO format)
│       └── img2.txt
├── val/
│   ├── images/
│   └── labels/
└── data.yaml
```

## Output

The script saves:
- `epoch{N}.pt`: checkpoint after each epoch
- `best.pt`: best model weights (lowest loss)
- Training logs to the console

## Features

✅ **16-bit precision preserved**: float32 [0-1] maintains the full dynamic range
✅ **No disk caching**: conversion happens in memory
✅ **No PIL/cv2**: direct tifffile loading
✅ **Variable-length labels**: handles segmentation polygons
✅ **Checkpoint saving**: resume training if interrupted
✅ **Best model tracking**: automatically saves the best weights

## Example

Train a segmentation model on microscopy data:

```bash
python scripts/train_float32_standalone.py \
    --data data/microscopy/data.yaml \
    --weights yolo11s-seg.pt \
    --epochs 150 \
    --batch 8 \
    --imgsz 1024 \
    --lr 0.0003 \
    --save-dir data/models/microscopy_v1
```

## Troubleshooting

### Out of Memory (OOM)
Reduce the batch size:
```bash
--batch 4
```

### Slow Loading
Reduce num_workers (edit script line 208):
```python
num_workers=2  # instead of 4
```

### Different Image Sizes
The script expects all images to have the same dimensions. For variable sizes:
1. Implement letterbox/resize in the dataset's `_read_image()` (a sketch follows)
2. Or preprocess the images to the same size
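A minimal, dependency-free letterbox sketch that could be dropped into `_read_image()` (hypothetical helper; nearest-neighbor only, and label coordinates would need the matching scale and offset applied):

```python
import numpy as np

def letterbox(img: np.ndarray, size: int = 640, fill: float = 0.5) -> np.ndarray:
    """Resize a float32 (H, W, 3) image to (size, size, 3), preserving aspect ratio."""
    h, w = img.shape[:2]
    scale = size / max(h, w)
    nh, nw = int(round(h * scale)), int(round(w * scale))
    # Nearest-neighbor resampling via index maps (avoids a cv2/PIL dependency).
    rows = np.clip((np.arange(nh) / scale).astype(int), 0, h - 1)
    cols = np.clip((np.arange(nw) / scale).astype(int), 0, w - 1)
    resized = img[rows][:, cols]
    out = np.full((size, size, 3), fill, dtype=np.float32)  # neutral gray padding
    top, left = (size - nh) // 2, (size - nw) // 2
    out[top:top + nh, left:left + nw] = resized
    return out
```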
### Loss Computation Errors
If you see "Cannot determine loss", the script may need adjustment for your Ultralytics version. Check:
```python
# In the train() function, the preds format may vary.
# The current script assumes: preds is a tuple containing the loss, OR a dict with a 'loss' key.
```

## vs GUI Training

| Feature | Standalone Script | GUI Training Tab |
|---------|------------------|------------------|
| Float32 conversion | ✓ Yes | ✓ Yes (automatic) |
| Disk caching | ✗ None | ✗ None |
| Progress UI | ✗ Console only | ✓ Visual progress bar |
| Dataset selection | Manual CLI args | ✓ GUI browsing |
| Multi-stage training | Manual runs | ✓ Built-in |
| Use case | Advanced users | General users |

## Technical Details

### Data Loading Pipeline

```
16-bit TIFF file
  ↓ (tifffile.imread)
uint16 [0-65535]
  ↓ (/ 65535.0)
float32 [0-1]
  ↓ (replicate channels)
float32 RGB (H,W,3) [0-1]
  ↓ (permute to C,H,W)
torch.Tensor (3,H,W) float32
  ↓ (DataLoader stack)
Batch (B,3,H,W) float32
  ↓
YOLO Model
```

### Precision Comparison

| Method | Unique Values | Data Loss |
|--------|---------------|-----------|
| **float32 [0-1]** | ~65,536 | None ✓ |
| uint16 RGB | 65,536 | None ✓ |
| uint8 | 256 | 99.6% ✗ |

Example: pixel value 32,768 (middle intensity)
- Float32: 32768 / 65535.0 ≈ 0.50000763
- uint8: 32768 → 128; many distinct values collapse!

## License

Same as the main project.
@@ -1,349 +0,0 @@
#!/usr/bin/env python3
"""
Standalone training script for YOLO with 16-bit TIFF float32 support.

This script trains YOLO models on 16-bit grayscale TIFF datasets without data loss.
Converts images to float32 [0-1] on-the-fly using tifffile (no PIL/cv2).

Usage:
    python scripts/train_float32_standalone.py \\
        --data path/to/data.yaml \\
        --weights yolov8s-seg.pt \\
        --epochs 100 \\
        --batch 16 \\
        --imgsz 640

Based on the custom dataset approach to avoid Ultralytics' channel conversion issues.
"""

import argparse
import os
import sys
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import tifffile
import yaml
from torch.utils.data import Dataset, DataLoader
from ultralytics import YOLO

# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from src.utils.logger import get_logger

logger = get_logger(__name__)


# ===================== Dataset =====================


class Float32YOLODataset(Dataset):
    """PyTorch dataset for 16-bit TIFF images with float32 conversion."""

    def __init__(self, images_dir, labels_dir, img_size=640):
        self.images_dir = Path(images_dir)
        self.labels_dir = Path(labels_dir)
        self.img_size = img_size  # note: images are not resized here (see docs)

        # Find images. NOTE: tifffile only reads TIFF; non-TIFF entries in this
        # set would need a different reader in _read_image().
        extensions = {".tif", ".tiff", ".png", ".jpg", ".jpeg", ".bmp"}
        self.paths = sorted(
            p
            for p in self.images_dir.rglob("*")
            if p.is_file() and p.suffix.lower() in extensions
        )

        if not self.paths:
            raise ValueError(f"No images found in {images_dir}")

        logger.info(f"Dataset: {len(self.paths)} images from {images_dir}")

    def __len__(self):
        return len(self.paths)

    def _read_image(self, path: Path) -> np.ndarray:
        """Load image as float32 [0-1] RGB."""
        # Load with tifffile
        img = tifffile.imread(str(path))

        # Convert to float32
        img = img.astype(np.float32)

        # Normalize 16-bit → [0, 1] (heuristic: assumes 16-bit input when
        # values exceed the float range)
        if img.max() > 1.5:
            img = img / 65535.0

        img = np.clip(img, 0.0, 1.0)

        # Grayscale → RGB
        if img.ndim == 2:
            img = np.repeat(img[..., None], 3, axis=2)
        elif img.ndim == 3 and img.shape[2] == 1:
            img = np.repeat(img, 3, axis=2)

        return img  # float32 (H, W, 3) [0, 1]

    def _parse_label(self, path: Path) -> list:
        """Parse a YOLO label file with variable-length rows.

        Each row is kept as its own float32 array: segmentation polygon rows
        differ in length, so a single rectangular array would be ragged.
        """
        if not path.exists():
            return []

        rows = []
        with open(path, "r") as f:
            for line in f:
                vals = line.strip().split()
                if len(vals) >= 5:
                    rows.append(np.asarray([float(v) for v in vals], dtype=np.float32))
        return rows

    def __getitem__(self, idx):
        img_path = self.paths[idx]
        label_path = self.labels_dir / f"{img_path.stem}.txt"

        # Load & convert to tensor (C, H, W)
        img = self._read_image(img_path)
        img_t = torch.from_numpy(img).permute(2, 0, 1).contiguous()

        # Load labels
        labels = self._parse_label(label_path)

        return img_t, labels, str(img_path.name)


# ===================== Collate =====================


def collate_fn(batch):
    """Stack images, keep labels as a list (rows vary in length)."""
    imgs = torch.stack([b[0] for b in batch], dim=0)
    labels = [b[1] for b in batch]
    names = [b[2] for b in batch]
    return imgs, labels, names


# ===================== Training =====================


def get_pytorch_model(ul_model):
    """Extract the PyTorch model and loss from the Ultralytics wrapper."""
    pt_model = None
    loss_fn = None

    # ul_model.model is the task model (e.g. SegmentationModel), which carries
    # .loss/.args; unwrapping one level further reaches the bare nn.Sequential
    # and loses them, so only unwrap while no loss is available.
    if hasattr(ul_model, "model"):
        pt_model = ul_model.model
    if pt_model is not None and not hasattr(pt_model, "loss") and hasattr(pt_model, "model"):
        pt_model = pt_model.model

    # Find loss
    if pt_model is not None and hasattr(pt_model, "loss"):
        loss_fn = pt_model.loss
    elif pt_model is not None and hasattr(pt_model, "compute_loss"):
        loss_fn = pt_model.compute_loss

    if pt_model is None:
        raise RuntimeError("Could not extract PyTorch model")

    return pt_model, loss_fn


def train(args):
    """Main training function."""
    device = args.device
    logger.info(f"Device: {device}")

    # Parse data.yaml
    with open(args.data, "r") as f:
        data_config = yaml.safe_load(f)

    dataset_root = Path(data_config.get("path", Path(args.data).parent))
    train_img = dataset_root / data_config.get("train", "train/images")
    val_img = dataset_root / data_config.get("val", "val/images")
    train_lbl = train_img.parent / "labels"
    val_lbl = val_img.parent / "labels"

    # Load model
    logger.info(f"Loading {args.weights}")
    ul_model = YOLO(args.weights)
    pt_model, loss_fn = get_pytorch_model(ul_model)

    # Configure model args
    from types import SimpleNamespace

    if not hasattr(pt_model, "args"):
        pt_model.args = SimpleNamespace()
    if isinstance(pt_model.args, dict):
        pt_model.args = SimpleNamespace(**pt_model.args)

    # Set segmentation loss args
    pt_model.args.overlap_mask = getattr(pt_model.args, "overlap_mask", True)
    pt_model.args.mask_ratio = getattr(pt_model.args, "mask_ratio", 4)
    pt_model.args.task = "segment"

    pt_model.to(device)
    pt_model.train()

    # Create datasets
    train_ds = Float32YOLODataset(str(train_img), str(train_lbl), args.imgsz)
    val_ds = Float32YOLODataset(str(val_img), str(val_lbl), args.imgsz)

    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch,
        shuffle=True,
        num_workers=4,
        pin_memory=(device == "cuda"),
        collate_fn=collate_fn,
    )
    # Validation loop is not implemented yet; the loader is kept for future use.
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch,
        shuffle=False,
        num_workers=2,
        pin_memory=(device == "cuda"),
        collate_fn=collate_fn,
    )

    # Optimizer
    optimizer = torch.optim.AdamW(pt_model.parameters(), lr=args.lr)

    # Training loop
    os.makedirs(args.save_dir, exist_ok=True)
    best_loss = float("inf")

    for epoch in range(args.epochs):
        t0 = time.time()
        running_loss = 0.0
        num_batches = 0

        for imgs, labels_list, names in train_loader:
            imgs = imgs.to(device)
            optimizer.zero_grad()
            num_batches += 1

            # Forward (simple approach: just use preds)
            preds = pt_model(imgs)

            # Try to compute the loss.
            # Simplest fallback: if preds is a tuple/list, assume the last element is the loss.
            if isinstance(preds, (tuple, list)):
                # YOLO forward often returns (preds, loss) in training mode
                if (
                    len(preds) >= 2
                    and isinstance(preds[-1], dict)
                    and "loss" in preds[-1]
                ):
                    loss = preds[-1]["loss"]
                elif len(preds) >= 2 and isinstance(preds[-1], torch.Tensor):
                    loss = preds[-1]
                else:
                    # Manually compute using loss_fn if available
                    if loss_fn:
                        # This may fail - see logs
                        try:
                            loss_out = loss_fn(preds, labels_list)
                            loss = (
                                loss_out[0]
                                if isinstance(loss_out, (tuple, list))
                                else loss_out
                            )
                        except Exception as e:
                            logger.error(f"Loss computation failed: {e}")
                            logger.error(
                                "Consider using Ultralytics .train() or check model/loss compatibility"
                            )
                            raise
                    else:
                        raise RuntimeError("Cannot determine loss from model output")
            elif isinstance(preds, dict) and "loss" in preds:
                loss = preds["loss"]
            else:
                raise RuntimeError(f"Unexpected preds format: {type(preds)}")

            # Backward
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if (num_batches % 10) == 0:
                logger.info(
                    f"Epoch {epoch+1} Batch {num_batches} Loss: {loss.item():.4f}"
                )

        epoch_loss = running_loss / max(1, num_batches)
        epoch_time = time.time() - t0
        logger.info(
            f"Epoch {epoch+1}/{args.epochs} - Loss: {epoch_loss:.4f}, Time: {epoch_time:.1f}s"
        )

        # Save checkpoint
        ckpt = Path(args.save_dir) / f"epoch{epoch+1}.pt"
        torch.save(
            {
                "epoch": epoch + 1,
                "model_state_dict": pt_model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": epoch_loss,
            },
            ckpt,
        )

        # Save best
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            best_ckpt = Path(args.save_dir) / "best.pt"
            torch.save(pt_model.state_dict(), best_ckpt)
            logger.info(f"New best: {best_ckpt}")

    logger.info("Training complete")


# ===================== Main =====================


def parse_args():
    parser = argparse.ArgumentParser(
        description="Train YOLO on 16-bit TIFF with float32"
    )
    parser.add_argument("--data", type=str, required=True, help="Path to data.yaml")
    parser.add_argument(
        "--weights", type=str, default="yolov8s-seg.pt", help="Pretrained weights"
    )
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument("--batch", type=int, default=16, help="Batch size")
    parser.add_argument("--imgsz", type=int, default=640, help="Image size")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument(
        "--save-dir", type=str, default="runs/train", help="Save directory"
    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    logger.info("=" * 70)
    logger.info("Float32 16-bit TIFF Training - Standalone Script")
    logger.info("=" * 70)
    logger.info(f"Data: {args.data}")
    logger.info(f"Weights: {args.weights}")
    logger.info(f"Epochs: {args.epochs}, Batch: {args.batch}, ImgSz: {args.imgsz}")
    logger.info(f"LR: {args.lr}, Device: {args.device}")
    logger.info("=" * 70)

    train(args)
@@ -60,9 +60,7 @@ class DatabaseManager:
|
|||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# Check if annotations table exists
|
# Check if annotations table exists
|
||||||
cursor.execute(
|
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'")
|
||||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'"
|
|
||||||
)
|
|
||||||
if not cursor.fetchone():
|
if not cursor.fetchone():
|
||||||
# Table doesn't exist yet, no migration needed
|
# Table doesn't exist yet, no migration needed
|
||||||
return
|
return
|
||||||
@@ -203,6 +201,28 @@ class DatabaseManager:
|
|||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
def delete_model(self, model_id: int) -> bool:
|
||||||
|
"""Delete a model from the database.
|
||||||
|
|
||||||
|
Note: detections referencing this model are deleted automatically via
|
||||||
|
the `detections.model_id` foreign key (ON DELETE CASCADE).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: ID of the model to delete.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if a model row was deleted, False otherwise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("DELETE FROM models WHERE id = ?", (model_id,))
|
||||||
|
conn.commit()
|
||||||
|
return cursor.rowcount > 0
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
# ==================== Image Operations ====================
|
# ==================== Image Operations ====================
|
||||||
|
|
||||||
def add_image(
|
def add_image(
|
||||||
@@ -242,9 +262,7 @@ class DatabaseManager:
|
|||||||
return cursor.lastrowid
|
return cursor.lastrowid
|
||||||
except sqlite3.IntegrityError:
|
except sqlite3.IntegrityError:
|
||||||
# Image already exists, return its ID
|
# Image already exists, return its ID
|
||||||
cursor.execute(
|
cursor.execute("SELECT id FROM images WHERE relative_path = ?", (relative_path,))
|
||||||
"SELECT id FROM images WHERE relative_path = ?", (relative_path,)
|
|
||||||
)
|
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
return row["id"] if row else None
|
return row["id"] if row else None
|
||||||
finally:
|
finally:
|
||||||
@@ -255,17 +273,13 @@ class DatabaseManager:
|
|||||||
conn = self.get_connection()
|
conn = self.get_connection()
|
||||||
try:
|
try:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute(
|
cursor.execute("SELECT * FROM images WHERE relative_path = ?", (relative_path,))
|
||||||
"SELECT * FROM images WHERE relative_path = ?", (relative_path,)
|
|
||||||
)
|
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
return dict(row) if row else None
|
return dict(row) if row else None
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def get_or_create_image(
|
def get_or_create_image(self, relative_path: str, filename: str, width: int, height: int) -> int:
|
||||||
self, relative_path: str, filename: str, width: int, height: int
|
|
||||||
) -> int:
|
|
||||||
"""Get existing image or create new one."""
|
"""Get existing image or create new one."""
|
||||||
existing = self.get_image_by_path(relative_path)
|
existing = self.get_image_by_path(relative_path)
|
||||||
if existing:
|
if existing:
|
||||||
@@ -355,16 +369,8 @@ class DatabaseManager:
|
|||||||
bbox[2],
|
bbox[2],
|
||||||
bbox[3],
|
bbox[3],
|
||||||
det["confidence"],
|
det["confidence"],
|
||||||
(
|
(json.dumps(det.get("segmentation_mask")) if det.get("segmentation_mask") else None),
|
||||||
json.dumps(det.get("segmentation_mask"))
|
(json.dumps(det.get("metadata")) if det.get("metadata") else None),
|
||||||
if det.get("segmentation_mask")
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
(
|
|
||||||
json.dumps(det.get("metadata"))
|
|
||||||
if det.get("metadata")
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
@@ -409,12 +415,13 @@ class DatabaseManager:
|
|||||||
if filters:
|
if filters:
|
||||||
conditions = []
|
conditions = []
|
||||||
for key, value in filters.items():
|
for key, value in filters.items():
|
||||||
if (
|
if key.startswith("d.") or key.startswith("i.") or key.startswith("m."):
|
||||||
key.startswith("d.")
|
if "like" in value.lower():
|
||||||
or key.startswith("i.")
|
conditions.append(f"{key} LIKE ?")
|
||||||
or key.startswith("m.")
|
params.append(value.split(" ")[1])
|
||||||
):
|
else:
|
||||||
conditions.append(f"{key} = ?")
|
conditions.append(f"{key} = ?")
|
||||||
|
params.append(value)
|
||||||
else:
|
else:
|
||||||
conditions.append(f"d.{key} = ?")
|
conditions.append(f"d.{key} = ?")
|
||||||
params.append(value)
|
params.append(value)
|
||||||
@@ -442,18 +449,14 @@ class DatabaseManager:
|
|||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def get_detections_for_image(
|
def get_detections_for_image(self, image_id: int, model_id: Optional[int] = None) -> List[Dict]:
|
||||||
self, image_id: int, model_id: Optional[int] = None
|
|
||||||
) -> List[Dict]:
|
|
||||||
"""Get all detections for a specific image."""
|
"""Get all detections for a specific image."""
|
||||||
filters = {"image_id": image_id}
|
filters = {"image_id": image_id}
|
||||||
if model_id:
|
if model_id:
|
||||||
filters["model_id"] = model_id
|
filters["model_id"] = model_id
|
||||||
return self.get_detections(filters)
|
return self.get_detections(filters)
|
||||||
|
|
||||||
def delete_detections_for_image(
|
def delete_detections_for_image(self, image_id: int, model_id: Optional[int] = None) -> int:
|
||||||
self, image_id: int, model_id: Optional[int] = None
|
|
||||||
) -> int:
|
|
||||||
"""Delete detections tied to a specific image and optional model."""
|
"""Delete detections tied to a specific image and optional model."""
|
||||||
conn = self.get_connection()
|
conn = self.get_connection()
|
||||||
try:
|
try:
|
||||||
@@ -481,6 +484,22 @@ class DatabaseManager:
|
|||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
def delete_all_detections(self) -> int:
|
||||||
|
"""Delete all detections from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of rows deleted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("DELETE FROM detections")
|
||||||
|
conn.commit()
|
||||||
|
return cursor.rowcount
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
# ==================== Statistics Operations ====================
|
# ==================== Statistics Operations ====================
|
||||||
|
|
||||||
def get_detection_statistics(
|
def get_detection_statistics(
|
||||||
@@ -524,9 +543,7 @@ class DatabaseManager:
|
|||||||
""",
|
""",
|
||||||
params,
|
params,
|
||||||
)
|
)
|
||||||
class_counts = {
|
class_counts = {row["class_name"]: row["count"] for row in cursor.fetchall()}
|
||||||
row["class_name"]: row["count"] for row in cursor.fetchall()
|
|
||||||
}
|
|
||||||
|
|
||||||
# Average confidence
|
# Average confidence
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
@@ -583,9 +600,7 @@ class DatabaseManager:
|
|||||||
|
|
||||||
# ==================== Export Operations ====================
|
# ==================== Export Operations ====================
|
||||||
|
|
||||||
def export_detections_to_csv(
|
def export_detections_to_csv(self, output_path: str, filters: Optional[Dict] = None) -> bool:
|
||||||
self, output_path: str, filters: Optional[Dict] = None
|
|
||||||
) -> bool:
|
|
||||||
"""Export detections to CSV file."""
|
"""Export detections to CSV file."""
|
||||||
try:
|
try:
|
||||||
detections = self.get_detections(filters)
|
detections = self.get_detections(filters)
|
||||||
@@ -614,9 +629,7 @@ class DatabaseManager:
|
|||||||
for det in detections:
|
for det in detections:
|
||||||
row = {k: det[k] for k in fieldnames if k in det}
|
row = {k: det[k] for k in fieldnames if k in det}
|
||||||
# Convert segmentation mask list to JSON string for CSV
|
# Convert segmentation mask list to JSON string for CSV
|
||||||
if row.get("segmentation_mask") and isinstance(
|
if row.get("segmentation_mask") and isinstance(row["segmentation_mask"], list):
|
||||||
row["segmentation_mask"], list
|
|
||||||
):
|
|
||||||
row["segmentation_mask"] = json.dumps(row["segmentation_mask"])
|
row["segmentation_mask"] = json.dumps(row["segmentation_mask"])
|
||||||
writer.writerow(row)
|
writer.writerow(row)
|
||||||
|
|
||||||
@@ -625,9 +638,7 @@ class DatabaseManager:
|
|||||||
print(f"Error exporting to CSV: {e}")
|
print(f"Error exporting to CSV: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def export_detections_to_json(
|
def export_detections_to_json(self, output_path: str, filters: Optional[Dict] = None) -> bool:
|
||||||
self, output_path: str, filters: Optional[Dict] = None
|
|
||||||
) -> bool:
|
|
||||||
"""Export detections to JSON file."""
|
"""Export detections to JSON file."""
|
||||||
try:
|
try:
|
||||||
detections = self.get_detections(filters)
|
detections = self.get_detections(filters)
|
||||||
@@ -647,6 +658,75 @@ class DatabaseManager:
|
|||||||
|
|
||||||
# ==================== Annotation Operations ====================
|
# ==================== Annotation Operations ====================
|
||||||
|
|
||||||
|
def get_annotated_images_summary(
|
||||||
|
self,
|
||||||
|
name_filter: Optional[str] = None,
|
||||||
|
order_by: str = "filename",
|
||||||
|
order_dir: str = "ASC",
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""Return images that have at least one manual annotation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name_filter: Optional substring filter applied to filename/relative_path.
|
||||||
|
order_by: One of: 'filename', 'relative_path', 'annotation_count', 'added_at'.
|
||||||
|
order_dir: 'ASC' or 'DESC'.
|
||||||
|
limit: Optional max number of rows.
|
||||||
|
offset: Pagination offset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dicts: {id, relative_path, filename, added_at, annotation_count}
|
||||||
|
"""
|
||||||
|
|
||||||
|
allowed_order_by = {
|
||||||
|
"filename": "i.filename",
|
||||||
|
"relative_path": "i.relative_path",
|
||||||
|
"annotation_count": "annotation_count",
|
||||||
|
"added_at": "i.added_at",
|
||||||
|
}
|
||||||
|
order_expr = allowed_order_by.get(order_by, "i.filename")
|
||||||
|
dir_norm = str(order_dir).upper().strip()
|
||||||
|
if dir_norm not in {"ASC", "DESC"}:
|
||||||
|
dir_norm = "ASC"
|
||||||
|
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
params: List[Any] = []
|
||||||
|
where_sql = ""
|
||||||
|
if name_filter:
|
||||||
|
# Case-insensitive substring search.
|
||||||
|
token = f"%{name_filter}%"
|
||||||
|
where_sql = "WHERE (i.filename LIKE ? OR i.relative_path LIKE ?)"
|
||||||
|
params.extend([token, token])
|
||||||
|
|
||||||
|
limit_sql = ""
|
||||||
|
if limit is not None:
|
||||||
|
limit_sql = " LIMIT ? OFFSET ?"
|
||||||
|
params.extend([int(limit), int(offset)])
|
||||||
|
|
||||||
|
query = f"""
|
||||||
|
SELECT
|
||||||
|
i.id,
|
||||||
|
i.relative_path,
|
||||||
|
i.filename,
|
||||||
|
i.added_at,
|
||||||
|
COUNT(a.id) AS annotation_count
|
||||||
|
FROM images i
|
||||||
|
JOIN annotations a ON a.image_id = i.id
|
||||||
|
{where_sql}
|
||||||
|
GROUP BY i.id
|
||||||
|
HAVING annotation_count > 0
|
||||||
|
ORDER BY {order_expr} {dir_norm}
|
||||||
|
{limit_sql}
|
||||||
|
"""
|
||||||
|
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(query, params)
|
||||||
|
return [dict(row) for row in cursor.fetchall()]
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
def add_annotation(
|
def add_annotation(
|
||||||
self,
|
self,
|
||||||
image_id: int,
|
image_id: int,
|
||||||
@@ -785,17 +865,13 @@ class DatabaseManager:
         conn = self.get_connection()
         try:
             cursor = conn.cursor()
-            cursor.execute(
-                "SELECT * FROM object_classes WHERE class_name = ?", (class_name,)
-            )
+            cursor.execute("SELECT * FROM object_classes WHERE class_name = ?", (class_name,))
             row = cursor.fetchone()
             return dict(row) if row else None
         finally:
             conn.close()

-    def add_object_class(
-        self, class_name: str, color: str, description: Optional[str] = None
-    ) -> int:
+    def add_object_class(self, class_name: str, color: str, description: Optional[str] = None) -> int:
         """
         Add a new object class.
@@ -928,8 +1004,7 @@ class DatabaseManager:
             if not split_map[required]:
                 raise ValueError(
                     "Unable to determine %s image directory under %s. Provide it "
-                    "explicitly via the 'splits' argument."
-                    % (required, dataset_root_path)
+                    "explicitly via the 'splits' argument." % (required, dataset_root_path)
                 )

         yaml_splits: Dict[str, str] = {}
@@ -955,11 +1030,7 @@ class DatabaseManager:
         if yaml_splits.get("test"):
             payload["test"] = yaml_splits["test"]

-        output_path_obj = (
-            Path(output_path).expanduser()
-            if output_path
-            else dataset_root_path / "data.yaml"
-        )
+        output_path_obj = Path(output_path).expanduser() if output_path else dataset_root_path / "data.yaml"
         output_path_obj.parent.mkdir(parents=True, exist_ok=True)

         with open(output_path_obj, "w", encoding="utf-8") as handle:
@@ -1019,15 +1090,9 @@ class DatabaseManager:
         for split_name, options in patterns.items():
             for relative in options:
                 candidate = (dataset_root / relative).resolve()
-                if (
-                    candidate.exists()
-                    and candidate.is_dir()
-                    and self._directory_has_images(candidate)
-                ):
+                if candidate.exists() and candidate.is_dir() and self._directory_has_images(candidate):
                     try:
-                        inferred[split_name] = candidate.relative_to(
-                            dataset_root
-                        ).as_posix()
+                        inferred[split_name] = candidate.relative_to(dataset_root).as_posix()
                     except ValueError:
                         inferred[split_name] = candidate.as_posix()
                     break
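
The `relative_to`/`ValueError` pattern in this hunk is the standard way to prefer dataset-relative split paths while still tolerating directories that resolve outside the dataset root (for example, a symlinked `images/` folder). A standalone sketch of the same pattern, with illustrative paths:

```python
from pathlib import Path

def split_entry(candidate: Path, dataset_root: Path) -> str:
    """Prefer a POSIX path relative to the dataset root; fall back to absolute."""
    try:
        # relative_to() raises ValueError when candidate is not inside dataset_root.
        return candidate.relative_to(dataset_root).as_posix()
    except ValueError:
        return candidate.as_posix()

print(split_entry(Path("/data/ds/train/images"), Path("/data/ds")))  # 'train/images'
print(split_entry(Path("/mnt/shared/images"), Path("/data/ds")))     # '/mnt/shared/images'
```
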
@@ -55,10 +55,7 @@ CREATE TABLE IF NOT EXISTS object_classes (

 -- Insert default object classes
 INSERT OR IGNORE INTO object_classes (class_name, color, description) VALUES
-    ('cell', '#FF0000', 'Cell object'),
-    ('nucleus', '#00FF00', 'Cell nucleus'),
-    ('mitochondria', '#0000FF', 'Mitochondria'),
-    ('vesicle', '#FFFF00', 'Vesicle');
+    ('terminal', '#FFFF00', 'Axon terminal');

 -- Annotations table: stores manual annotations
 CREATE TABLE IF NOT EXISTS annotations (
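
One behavioral note on this hunk: `INSERT OR IGNORE` only skips rows that would violate a uniqueness constraint, so databases created under the old schema keep their existing default classes; the new 'terminal' row is added alongside them rather than replacing them. A quick way to inspect what a given database actually contains (the database path is assumed from the default config):

```python
import sqlite3

conn = sqlite3.connect("data/detections.db")  # path assumed from the default config
conn.row_factory = sqlite3.Row
for row in conn.execute("SELECT class_name, color, description FROM object_classes"):
    print(dict(row))
conn.close()
```
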
@@ -1,6 +1,7 @@
-"""
-Main window for the microscopy object detection application.
-"""
+"""Main window for the microscopy object detection application."""
+
+import shutil
+from pathlib import Path

 from PySide6.QtWidgets import (
     QMainWindow,
@@ -20,6 +21,7 @@ from src.database.db_manager import DatabaseManager
 from src.utils.config_manager import ConfigManager
 from src.utils.logger import get_logger
 from src.gui.dialogs.config_dialog import ConfigDialog
+from src.gui.dialogs.delete_model_dialog import DeleteModelDialog
 from src.gui.tabs.detection_tab import DetectionTab
 from src.gui.tabs.training_tab import TrainingTab
 from src.gui.tabs.validation_tab import ValidationTab
@@ -91,6 +93,12 @@ class MainWindow(QMainWindow):
         db_stats_action.triggered.connect(self._show_database_stats)
         tools_menu.addAction(db_stats_action)

+        tools_menu.addSeparator()
+
+        delete_model_action = QAction("Delete &Model…", self)
+        delete_model_action.triggered.connect(self._show_delete_model_dialog)
+        tools_menu.addAction(delete_model_action)
+
         # Help menu
         help_menu = menubar.addMenu("&Help")

@@ -117,10 +125,10 @@ class MainWindow(QMainWindow):

         # Add tabs to widget
         self.tab_widget.addTab(self.detection_tab, "Detection")
+        self.tab_widget.addTab(self.results_tab, "Results")
+        self.tab_widget.addTab(self.annotation_tab, "Annotation")
         self.tab_widget.addTab(self.training_tab, "Training")
         self.tab_widget.addTab(self.validation_tab, "Validation")
-        self.tab_widget.addTab(self.results_tab, "Results")
-        self.tab_widget.addTab(self.annotation_tab, "Annotation (Future)")

         # Connect tab change signal
         self.tab_widget.currentChanged.connect(self._on_tab_changed)
@@ -152,9 +160,7 @@ class MainWindow(QMainWindow):
         """Center window on screen."""
         screen = self.screen().geometry()
         size = self.geometry()
-        self.move(
-            (screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2
-        )
+        self.move((screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2)

     def _restore_window_state(self):
         """Restore window geometry from settings or center window."""
@@ -193,6 +199,10 @@ class MainWindow(QMainWindow):
             self.training_tab.refresh()
             if hasattr(self, "results_tab"):
                 self.results_tab.refresh()
+            if hasattr(self, "annotation_tab"):
+                self.annotation_tab.refresh()
+            if hasattr(self, "validation_tab"):
+                self.validation_tab.refresh()
         except Exception as e:
             logger.error(f"Error applying settings: {e}")

@@ -209,6 +219,14 @@ class MainWindow(QMainWindow):
         logger.debug(f"Switched to tab: {tab_name}")
         self._update_status(f"Viewing: {tab_name}")

+        # Ensure the Annotation tab always shows up-to-date DB-backed lists.
+        try:
+            current_widget = self.tab_widget.widget(index)
+            if hasattr(self, "annotation_tab") and current_widget is self.annotation_tab:
+                self.annotation_tab.refresh()
+        except Exception as exc:
+            logger.debug(f"Failed to refresh annotation tab on selection: {exc}")
+
     def _show_database_stats(self):
         """Show database statistics dialog."""
         try:
@@ -231,10 +249,230 @@ class MainWindow(QMainWindow):

         except Exception as e:
             logger.error(f"Error getting database stats: {e}")
-            QMessageBox.warning(
-                self, "Error", f"Failed to get database statistics:\n{str(e)}"
-            )
+            QMessageBox.warning(self, "Error", f"Failed to get database statistics:\n{str(e)}")
+
+    def _show_delete_model_dialog(self) -> None:
+        """Open the model deletion dialog."""
+        dialog = DeleteModelDialog(self.db_manager, self)
+        if not dialog.exec():
+            return
+
+        model_ids = dialog.selected_model_ids
+        if not model_ids:
+            return
+
+        self._delete_models(model_ids)
+
+    def _delete_models(self, model_ids: list[int]) -> None:
+        """Delete one or more models from the database and remove artifacts from disk."""
+
+        deleted_count = 0
+        removed_paths: list[str] = []
+        remove_errors: list[str] = []
+
+        for model_id in model_ids:
+            model = None
+            try:
+                model = self.db_manager.get_model_by_id(int(model_id))
+            except Exception as exc:
+                logger.error(f"Failed to load model {model_id} before deletion: {exc}")
+
+            if not model:
+                remove_errors.append(f"Model id {model_id} not found in database.")
+                continue
+
+            try:
+                deleted = self.db_manager.delete_model(int(model_id))
+            except Exception as exc:
+                logger.error(f"Failed to delete model {model_id}: {exc}")
+                remove_errors.append(f"Failed to delete model id {model_id} from DB: {exc}")
+                continue
+
+            if not deleted:
+                remove_errors.append(f"Model id {model_id} was not deleted (already removed?).")
+                continue
+
+            deleted_count += 1
+            removed, errors = self._delete_model_artifacts_from_disk(model)
+            removed_paths.extend(removed)
+            remove_errors.extend(errors)
+
+        # Refresh tabs to reflect the deletion(s).
+        try:
+            if hasattr(self, "detection_tab"):
+                self.detection_tab.refresh()
+            if hasattr(self, "results_tab"):
+                self.results_tab.refresh()
+            if hasattr(self, "validation_tab"):
+                self.validation_tab.refresh()
+            if hasattr(self, "training_tab"):
+                self.training_tab.refresh()
+        except Exception as exc:
+            logger.warning(f"Failed to refresh tabs after model deletion: {exc}")
+
+        details: list[str] = []
+        if removed_paths:
+            details.append("Removed from disk:\n" + "\n".join(removed_paths))
+        if remove_errors:
+            details.append("\nDisk cleanup warnings:\n" + "\n".join(remove_errors))
+
+        QMessageBox.information(
+            self,
+            "Delete Model",
+            f"Deleted {deleted_count} model(s) from database." + ("\n\n" + "\n".join(details) if details else ""),
+        )
+
+    def _delete_model(self, model_id: int) -> None:
+        """Delete a model from the database and remove its artifacts from disk."""
+
+        model = None
+        try:
+            model = self.db_manager.get_model_by_id(model_id)
+        except Exception as exc:
+            logger.error(f"Failed to load model {model_id} before deletion: {exc}")
+
+        if not model:
+            QMessageBox.warning(self, "Delete Model", "Selected model was not found in the database.")
+            return
+
+        model_path = str(model.get("model_path") or "")
+
+        try:
+            deleted = self.db_manager.delete_model(model_id)
+        except Exception as exc:
+            logger.error(f"Failed to delete model {model_id}: {exc}")
+            QMessageBox.critical(self, "Delete Model", f"Failed to delete model from database:\n{exc}")
+            return
+
+        if not deleted:
+            QMessageBox.warning(self, "Delete Model", "No model was deleted (it may have already been removed).")
+            return
+
+        removed_paths, remove_errors = self._delete_model_artifacts_from_disk(model)
+
+        # Refresh tabs to reflect the deletion.
+        try:
+            if hasattr(self, "detection_tab"):
+                self.detection_tab.refresh()
+            if hasattr(self, "results_tab"):
+                self.results_tab.refresh()
+            if hasattr(self, "validation_tab"):
+                self.validation_tab.refresh()
+            if hasattr(self, "training_tab"):
+                self.training_tab.refresh()
+        except Exception as exc:
+            logger.warning(f"Failed to refresh tabs after model deletion: {exc}")
+
+        details = []
+        if model_path:
+            details.append(f"Deleted model record for: {model_path}")
+        if removed_paths:
+            details.append("\nRemoved from disk:\n" + "\n".join(removed_paths))
+        if remove_errors:
+            details.append("\nDisk cleanup warnings:\n" + "\n".join(remove_errors))
+
+        QMessageBox.information(
+            self,
+            "Delete Model",
+            "Model deleted from database." + ("\n\n" + "\n".join(details) if details else ""),
+        )
+
+    def _delete_model_artifacts_from_disk(self, model: dict) -> tuple[list[str], list[str]]:
+        """Best-effort removal of model artifacts on disk.
+
+        Strategy:
+        - Remove run directories inferred from:
+          - model.model_path (…/<run>/weights/*.pt => <run>)
+          - training_params.stage_results[].results.save_dir
+          but only if they are under the configured models directory.
+        - If the weights file itself exists and is outside the models directory, delete only the file.
+
+        Returns:
+            (removed_paths, errors)
+        """
+
+        removed: list[str] = []
+        errors: list[str] = []
+
+        models_root = Path(self.config_manager.get_models_directory() or "data/models").expanduser()
+        try:
+            models_root_resolved = models_root.resolve()
+        except Exception:
+            models_root_resolved = models_root
+
+        inferred_dirs: list[Path] = []
+
+        # 1) From model_path
+        model_path_value = model.get("model_path")
+        if model_path_value:
+            try:
+                p = Path(str(model_path_value)).expanduser()
+                p_resolved = p.resolve() if p.exists() else p
+                if p_resolved.is_file():
+                    if p_resolved.parent.name == "weights" and p_resolved.parent.parent.exists():
+                        inferred_dirs.append(p_resolved.parent.parent)
+                    elif p_resolved.parent.exists():
+                        inferred_dirs.append(p_resolved.parent)
+            except Exception:
+                pass
+
+        # 2) From training_params.stage_results[].results.save_dir
+        training_params = model.get("training_params") or {}
+        if isinstance(training_params, dict):
+            stage_results = training_params.get("stage_results")
+            if isinstance(stage_results, list):
+                for stage in stage_results:
+                    results = (stage or {}).get("results")
+                    save_dir = (results or {}).get("save_dir") if isinstance(results, dict) else None
+                    if not save_dir:
+                        continue
+                    try:
+                        d = Path(str(save_dir)).expanduser()
+                        if d.exists() and d.is_dir():
+                            inferred_dirs.append(d)
+                    except Exception:
+                        continue
+
+        # Deduplicate inferred_dirs
+        unique_dirs: list[Path] = []
+        seen: set[str] = set()
+        for d in inferred_dirs:
+            try:
+                key = str(d.resolve())
+            except Exception:
+                key = str(d)
+            if key in seen:
+                continue
+            seen.add(key)
+            unique_dirs.append(d)
+
+        # Delete directories under models_root
+        for d in unique_dirs:
+            try:
+                d_resolved = d.resolve()
+            except Exception:
+                d_resolved = d
+            try:
+                if d_resolved.exists() and d_resolved.is_dir() and d_resolved.is_relative_to(models_root_resolved):
+                    shutil.rmtree(d_resolved)
+                    removed.append(str(d_resolved))
+            except Exception as exc:
+                errors.append(f"Failed to remove directory {d_resolved}: {exc}")
+
+        # If nothing matched (e.g., model_path outside models_root), delete just the file.
+        if model_path_value:
+            try:
+                p = Path(str(model_path_value)).expanduser()
+                if p.exists() and p.is_file():
+                    p_resolved = p.resolve()
+                    if not p_resolved.is_relative_to(models_root_resolved):
+                        p_resolved.unlink()
+                        removed.append(str(p_resolved))
+            except Exception as exc:
+                errors.append(f"Failed to remove model file {model_path_value}: {exc}")
+
+        return removed, errors

     def _show_about(self):
         """Show about dialog."""
         about_text = """
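
The disk cleanup above rests on one guard: a run directory is only passed to `shutil.rmtree` when its resolved path sits under the configured models root. A standalone sketch of that containment check (requires Python 3.9+ for `Path.is_relative_to`; the example paths are illustrative):

```python
import shutil
from pathlib import Path

def remove_if_under_root(target: Path, root: Path) -> bool:
    """Delete target only when it is contained in root after resolving symlinks."""
    target = target.expanduser().resolve()
    root = root.expanduser().resolve()
    if target.is_dir() and target.is_relative_to(root):
        shutil.rmtree(target)
        return True
    return False  # refuse to touch anything outside the root

# remove_if_under_root(Path("data/models/train42"), Path("data/models"))  -> True (deleted)
# remove_if_under_root(Path("/etc"), Path("data/models"))                 -> False (kept)
```

Resolving both paths first matters: a symlink inside the models directory that points elsewhere would otherwise let the cleanup escape the root.
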
@@ -301,6 +539,11 @@ class MainWindow(QMainWindow):
         if hasattr(self, "training_tab"):
             self.training_tab.shutdown()
         if hasattr(self, "annotation_tab"):
+            # Best-effort refresh so DB-backed UI state is consistent at shutdown.
+            try:
+                self.annotation_tab.refresh()
+            except Exception:
+                pass
             self.annotation_tab.save_state()

         logger.info("Application closing")
@@ -13,6 +13,11 @@ from PySide6.QtWidgets import (
     QFileDialog,
     QMessageBox,
     QSplitter,
+    QLineEdit,
+    QTableWidget,
+    QTableWidgetItem,
+    QHeaderView,
+    QAbstractItemView,
 )
 from PySide6.QtCore import Qt, QSettings
 from pathlib import Path
@@ -29,9 +34,7 @@ logger = get_logger(__name__)
 class AnnotationTab(QWidget):
     """Annotation tab for manual image annotation."""

-    def __init__(
-        self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
-    ):
+    def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
         super().__init__(parent)
         self.db_manager = db_manager
         self.config_manager = config_manager
@@ -52,6 +55,32 @@ class AnnotationTab(QWidget):
         self.main_splitter = QSplitter(Qt.Horizontal)
         self.main_splitter.setHandleWidth(10)

+        # { Left-most pane: annotated images list
+        annotated_group = QGroupBox("Annotated Images")
+        annotated_layout = QVBoxLayout()
+
+        filter_row = QHBoxLayout()
+        filter_row.addWidget(QLabel("Filter:"))
+        self.annotated_filter_edit = QLineEdit()
+        self.annotated_filter_edit.setPlaceholderText("Type to filter by image name…")
+        # Discard the new text argument so it is not taken as select_image_id.
+        self.annotated_filter_edit.textChanged.connect(lambda _text: self._refresh_annotated_images_list())
+        filter_row.addWidget(self.annotated_filter_edit, 1)
+        annotated_layout.addLayout(filter_row)
+
+        self.annotated_images_table = QTableWidget(0, 2)
+        self.annotated_images_table.setHorizontalHeaderLabels(["Image", "Annotations"])
+        self.annotated_images_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
+        self.annotated_images_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeToContents)
+        self.annotated_images_table.setSelectionBehavior(QAbstractItemView.SelectRows)
+        self.annotated_images_table.setSelectionMode(QAbstractItemView.SingleSelection)
+        self.annotated_images_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
+        self.annotated_images_table.setSortingEnabled(True)
+        self.annotated_images_table.itemSelectionChanged.connect(self._on_annotated_image_selected)
+        annotated_layout.addWidget(self.annotated_images_table, 1)
+
+        annotated_group.setLayout(annotated_layout)
+        # }
+
         # { Left splitter for image display and zoom info
         self.left_splitter = QSplitter(Qt.Vertical)
         self.left_splitter.setHandleWidth(10)
@@ -62,6 +91,9 @@ class AnnotationTab(QWidget):

         # Use the AnnotationCanvasWidget
         self.annotation_canvas = AnnotationCanvasWidget()
+        # Auto-zoom so newly loaded images fill the available canvas viewport.
+        # (Matches the behavior used in ResultsTab.)
+        self.annotation_canvas.set_auto_fit_to_view(True)
         self.annotation_canvas.zoom_changed.connect(self._on_zoom_changed)
         self.annotation_canvas.annotation_drawn.connect(self._on_annotation_drawn)
         # Selection of existing polylines (when tool is not in drawing mode)
@@ -72,9 +104,7 @@ class AnnotationTab(QWidget):
         self.left_splitter.addWidget(canvas_group)

         # Controls info
-        controls_info = QLabel(
-            "Zoom: Mouse wheel or +/- keys | Drawing: Enable pen and drag mouse"
-        )
+        controls_info = QLabel("Zoom: Mouse wheel or +/- keys | Drawing: Enable pen and drag mouse")
         controls_info.setStyleSheet("QLabel { color: #888; font-style: italic; }")
         self.left_splitter.addWidget(controls_info)
         # }
@@ -85,36 +115,20 @@ class AnnotationTab(QWidget):

         # Annotation tools section
         self.annotation_tools = AnnotationToolsWidget(self.db_manager)
-        self.annotation_tools.polyline_enabled_changed.connect(
-            self.annotation_canvas.set_polyline_enabled
-        )
-        self.annotation_tools.polyline_pen_color_changed.connect(
-            self.annotation_canvas.set_polyline_pen_color
-        )
-        self.annotation_tools.polyline_pen_width_changed.connect(
-            self.annotation_canvas.set_polyline_pen_width
-        )
+        self.annotation_tools.polyline_enabled_changed.connect(self.annotation_canvas.set_polyline_enabled)
+        self.annotation_tools.polyline_pen_color_changed.connect(self.annotation_canvas.set_polyline_pen_color)
+        self.annotation_tools.polyline_pen_width_changed.connect(self.annotation_canvas.set_polyline_pen_width)
         # Show / hide bounding boxes
-        self.annotation_tools.show_bboxes_changed.connect(
-            self.annotation_canvas.set_show_bboxes
-        )
+        self.annotation_tools.show_bboxes_changed.connect(self.annotation_canvas.set_show_bboxes)
         # RDP simplification controls
-        self.annotation_tools.simplify_on_finish_changed.connect(
-            self._on_simplify_on_finish_changed
-        )
-        self.annotation_tools.simplify_epsilon_changed.connect(
-            self._on_simplify_epsilon_changed
-        )
+        self.annotation_tools.simplify_on_finish_changed.connect(self._on_simplify_on_finish_changed)
+        self.annotation_tools.simplify_epsilon_changed.connect(self._on_simplify_epsilon_changed)
         # Class selection and class-color changes
         self.annotation_tools.class_selected.connect(self._on_class_selected)
         self.annotation_tools.class_color_changed.connect(self._on_class_color_changed)
-        self.annotation_tools.clear_annotations_requested.connect(
-            self._on_clear_annotations
-        )
+        self.annotation_tools.clear_annotations_requested.connect(self._on_clear_annotations)
         # Delete selected annotation on canvas
-        self.annotation_tools.delete_selected_annotation_requested.connect(
-            self._on_delete_selected_annotation
-        )
+        self.annotation_tools.delete_selected_annotation_requested.connect(self._on_delete_selected_annotation)
         self.right_splitter.addWidget(self.annotation_tools)

         # Image loading section
@@ -137,12 +151,13 @@ class AnnotationTab(QWidget):
         self.right_splitter.addWidget(load_group)
         # }

-        # Add both splitters to the main horizontal splitter
+        # Add list + both splitters to the main horizontal splitter
+        self.main_splitter.addWidget(annotated_group)
         self.main_splitter.addWidget(self.left_splitter)
         self.main_splitter.addWidget(self.right_splitter)

-        # Set initial sizes: 75% for left (image), 25% for right (controls)
-        self.main_splitter.setSizes([750, 250])
+        # Set initial sizes: list (left), canvas (middle), controls (right)
+        self.main_splitter.setSizes([320, 650, 280])

         layout.addWidget(self.main_splitter)
         self.setLayout(layout)
@@ -150,6 +165,9 @@ class AnnotationTab(QWidget):
         # Restore splitter positions from settings
         self._restore_state()

+        # Populate list on startup.
+        self._refresh_annotated_images_list()
+
     def _load_image(self):
         """Load and display an image file."""
         # Get last opened directory from QSettings
@@ -180,12 +198,24 @@ class AnnotationTab(QWidget):
             self.current_image_path = file_path

             # Store the directory for next time
-            settings.setValue(
-                "annotation_tab/last_directory", str(Path(file_path).parent)
-            )
+            settings.setValue("annotation_tab/last_directory", str(Path(file_path).parent))

             # Get or create image in database
-            relative_path = str(Path(file_path).name)  # Simplified for now
+            repo_root = self.config_manager.get_image_repository_path()
+            relative_path: str
+            try:
+                if repo_root:
+                    repo_root_path = Path(repo_root).expanduser().resolve()
+                    file_resolved = Path(file_path).expanduser().resolve()
+                    if file_resolved.is_relative_to(repo_root_path):
+                        relative_path = file_resolved.relative_to(repo_root_path).as_posix()
+                    else:
+                        # Fallback: store filename only to avoid leaking absolute paths.
+                        relative_path = file_resolved.name
+                else:
+                    relative_path = str(Path(file_path).name)
+            except Exception:
+                relative_path = str(Path(file_path).name)
             self.current_image_id = self.db_manager.get_or_create_image(
                 relative_path,
                 Path(file_path).name,
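
Storing repository-relative POSIX paths, rather than absolute native paths, is what lets the same database travel between machines and operating systems. The hunk above boils down to this pure function (a sketch; `is_relative_to` needs Python 3.9+, and the example paths are illustrative):

```python
from pathlib import Path

def db_relative_path(file_path: str, repo_root: str | None) -> str:
    """Repo-relative POSIX path when possible, bare filename otherwise."""
    try:
        if repo_root:
            root = Path(repo_root).expanduser().resolve()
            file = Path(file_path).expanduser().resolve()
            if file.is_relative_to(root):
                return file.relative_to(root).as_posix()  # forward slashes on any OS
        return Path(file_path).name
    except Exception:
        return Path(file_path).name

print(db_relative_path("/data/repo/plateA/img_001.tif", "/data/repo"))  # plateA/img_001.tif
print(db_relative_path("/tmp/elsewhere.tif", "/data/repo"))             # elsewhere.tif
```
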
@@ -199,6 +229,9 @@ class AnnotationTab(QWidget):
             # Load and display any existing annotations for this image
             self._load_annotations_for_current_image()

+            # Update annotated images list (newly annotated image added/selected).
+            self._refresh_annotated_images_list(select_image_id=self.current_image_id)
+
             # Update info label
             self._update_image_info()

@@ -206,9 +239,7 @@ class AnnotationTab(QWidget):

         except ImageLoadError as e:
             logger.error(f"Failed to load image: {e}")
-            QMessageBox.critical(
-                self, "Error Loading Image", f"Failed to load image:\n{str(e)}"
-            )
+            QMessageBox.critical(self, "Error Loading Image", f"Failed to load image:\n{str(e)}")
         except Exception as e:
             logger.error(f"Unexpected error loading image: {e}")
             QMessageBox.critical(self, "Error", f"Unexpected error:\n{str(e)}")
@@ -296,6 +327,9 @@ class AnnotationTab(QWidget):
             # Reload annotations from DB and redraw (respecting current class filter)
             self._load_annotations_for_current_image()

+            # Update list counts.
+            self._refresh_annotated_images_list(select_image_id=self.current_image_id)
+
         except Exception as e:
             logger.error(f"Failed to save annotation: {e}")
             QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}")
@@ -340,9 +374,7 @@ class AnnotationTab(QWidget):
         if not self.current_image_id:
             return

-        logger.debug(
-            f"Class color changed; reloading annotations for image ID {self.current_image_id}"
-        )
+        logger.debug(f"Class color changed; reloading annotations for image ID {self.current_image_id}")
         self._load_annotations_for_current_image()

     def _on_class_selected(self, class_data):
@@ -355,9 +387,7 @@ class AnnotationTab(QWidget):
         if class_data:
             logger.debug(f"Object class selected: {class_data['class_name']}")
         else:
-            logger.debug(
-                'No class selected ("-- Select Class --"), showing all annotations'
-            )
+            logger.debug('No class selected ("-- Select Class --"), showing all annotations')

         # Changing the class filter invalidates any previous selection
         self.selected_annotation_ids = []
@@ -390,9 +420,7 @@ class AnnotationTab(QWidget):
             question = "Are you sure you want to delete the selected annotation?"
             title = "Delete Annotation"
         else:
-            question = (
-                f"Are you sure you want to delete the {count} selected annotations?"
-            )
+            question = f"Are you sure you want to delete the {count} selected annotations?"
             title = "Delete Annotations"

         reply = QMessageBox.question(
@@ -420,13 +448,11 @@ class AnnotationTab(QWidget):
                 QMessageBox.warning(
                     self,
                     "Partial Failure",
-                    "Some annotations could not be deleted:\n"
-                    + ", ".join(str(a) for a in failed_ids),
+                    "Some annotations could not be deleted:\n" + ", ".join(str(a) for a in failed_ids),
                 )
             else:
                 logger.info(
-                    f"Deleted {count} annotation(s): "
-                    + ", ".join(str(a) for a in self.selected_annotation_ids)
+                    f"Deleted {count} annotation(s): " + ", ".join(str(a) for a in self.selected_annotation_ids)
                 )

             # Clear selection and reload annotations for the current image from DB
@@ -434,6 +460,9 @@ class AnnotationTab(QWidget):
             self.annotation_tools.set_has_selected_annotation(False)
             self._load_annotations_for_current_image()

+            # Update list counts.
+            self._refresh_annotated_images_list(select_image_id=self.current_image_id)
+
         except Exception as e:
             logger.error(f"Failed to delete annotations: {e}")
             QMessageBox.critical(
@@ -456,17 +485,13 @@ class AnnotationTab(QWidget):
             return

         try:
-            self.current_annotations = self.db_manager.get_annotations_for_image(
-                self.current_image_id
-            )
+            self.current_annotations = self.db_manager.get_annotations_for_image(self.current_image_id)
             # New annotations loaded; reset any selection
             self.selected_annotation_ids = []
             self.annotation_tools.set_has_selected_annotation(False)
             self._redraw_annotations_for_current_filter()
         except Exception as e:
-            logger.error(
-                f"Failed to load annotations for image {self.current_image_id}: {e}"
-            )
+            logger.error(f"Failed to load annotations for image {self.current_image_id}: {e}")
             QMessageBox.critical(
                 self,
                 "Error",
@@ -490,10 +515,7 @@ class AnnotationTab(QWidget):
         drawn_count = 0
         for ann in self.current_annotations:
             # Filter by class if one is selected
-            if (
-                selected_class_id is not None
-                and ann.get("class_id") != selected_class_id
-            ):
+            if selected_class_id is not None and ann.get("class_id") != selected_class_id:
                 continue

             if ann.get("segmentation_mask"):
@@ -545,22 +567,176 @@ class AnnotationTab(QWidget):
         settings = QSettings("microscopy_app", "object_detection")

         # Save main splitter state
-        settings.setValue(
-            "annotation_tab/main_splitter_state", self.main_splitter.saveState()
-        )
+        settings.setValue("annotation_tab/main_splitter_state", self.main_splitter.saveState())

         # Save left splitter state
-        settings.setValue(
-            "annotation_tab/left_splitter_state", self.left_splitter.saveState()
-        )
+        settings.setValue("annotation_tab/left_splitter_state", self.left_splitter.saveState())

         # Save right splitter state
-        settings.setValue(
-            "annotation_tab/right_splitter_state", self.right_splitter.saveState()
-        )
+        settings.setValue("annotation_tab/right_splitter_state", self.right_splitter.saveState())

         logger.debug("Saved annotation tab splitter states")

     def refresh(self):
         """Refresh the tab."""
-        pass
+        self._refresh_annotated_images_list(select_image_id=self.current_image_id)
+
+    # ==================== Annotated images list ====================
+
+    def _refresh_annotated_images_list(self, select_image_id: int | None = None) -> None:
+        """Reload annotated-images list from the database."""
+        if not hasattr(self, "annotated_images_table"):
+            return
+
+        # Preserve selection if possible
+        desired_id = select_image_id if select_image_id is not None else self.current_image_id
+
+        name_filter = ""
+        if hasattr(self, "annotated_filter_edit"):
+            name_filter = self.annotated_filter_edit.text().strip()
+
+        try:
+            rows = self.db_manager.get_annotated_images_summary(name_filter=name_filter)
+        except Exception as exc:
+            logger.error(f"Failed to load annotated images summary: {exc}")
+            rows = []
+
+        sorting_enabled = self.annotated_images_table.isSortingEnabled()
+        self.annotated_images_table.setSortingEnabled(False)
+        self.annotated_images_table.blockSignals(True)
+        try:
+            self.annotated_images_table.setRowCount(len(rows))
+            for r, entry in enumerate(rows):
+                image_name = str(entry.get("filename") or "")
+                count = int(entry.get("annotation_count") or 0)
+                rel_path = str(entry.get("relative_path") or "")
+
+                name_item = QTableWidgetItem(image_name)
+                # Tooltip shows full path of the image (best-effort: repository_root + relative_path)
+                full_path = rel_path
+                repo_root = self.config_manager.get_image_repository_path()
+                if repo_root and rel_path and not Path(rel_path).is_absolute():
+                    try:
+                        full_path = str((Path(repo_root) / rel_path).resolve())
+                    except Exception:
+                        full_path = str(Path(repo_root) / rel_path)
+                name_item.setToolTip(full_path)
+                name_item.setData(Qt.UserRole, int(entry.get("id")))
+                name_item.setData(Qt.UserRole + 1, rel_path)
+
+                count_item = QTableWidgetItem()
+                # Use EditRole to ensure numeric sorting.
+                count_item.setData(Qt.EditRole, count)
+                count_item.setData(Qt.UserRole, int(entry.get("id")))
+                count_item.setData(Qt.UserRole + 1, rel_path)
+
+                self.annotated_images_table.setItem(r, 0, name_item)
+                self.annotated_images_table.setItem(r, 1, count_item)
+
+            # Re-select desired row
+            if desired_id is not None:
+                for r in range(self.annotated_images_table.rowCount()):
+                    item = self.annotated_images_table.item(r, 0)
+                    if item and item.data(Qt.UserRole) == desired_id:
+                        self.annotated_images_table.selectRow(r)
+                        break
+        finally:
+            self.annotated_images_table.blockSignals(False)
+            self.annotated_images_table.setSortingEnabled(sorting_enabled)
+
+    def _on_annotated_image_selected(self) -> None:
+        """When user clicks an item in the list, load that image in the annotation canvas."""
+        selected = self.annotated_images_table.selectedItems()
+        if not selected:
+            return
+
+        # Row selection -> take the first column item
+        row = self.annotated_images_table.currentRow()
+        item = self.annotated_images_table.item(row, 0)
+        if not item:
+            return
+
+        image_id = item.data(Qt.UserRole)
+        rel_path = item.data(Qt.UserRole + 1) or ""
+        if not image_id:
+            return
+
+        image_path = self._resolve_image_path_for_relative_path(rel_path)
+        if not image_path:
+            QMessageBox.warning(
+                self,
+                "Image Not Found",
+                "Unable to locate image on disk for:\n"
+                f"{rel_path}\n\n"
+                "Tip: set Settings → Image repository path to the folder containing your images.",
+            )
+            return
+
+        try:
+            self.current_image = Image(image_path)
+            self.current_image_path = image_path
+            self.current_image_id = int(image_id)
+            self.annotation_canvas.load_image(self.current_image)
+            self._load_annotations_for_current_image()
+            self._update_image_info()
+        except ImageLoadError as exc:
+            logger.error(f"Failed to load image '{image_path}': {exc}")
+            QMessageBox.critical(self, "Error Loading Image", f"Failed to load image:\n{exc}")
+        except Exception as exc:
+            logger.error(f"Unexpected error loading image '{image_path}': {exc}")
+            QMessageBox.critical(self, "Error", f"Unexpected error:\n{exc}")
+
+    def _resolve_image_path_for_relative_path(self, relative_path: str) -> str | None:
+        """Best-effort conversion from a DB relative_path to an on-disk file path."""
+
+        rel = (relative_path or "").strip()
+        if not rel:
+            return None
+
+        candidates: list[Path] = []
+
+        # 1) Repository root + relative
+        repo_root = (self.config_manager.get_image_repository_path() or "").strip()
+        if repo_root:
+            candidates.append(Path(repo_root) / rel)
+
+        # 2) If the DB path is absolute, try it directly.
+        candidates.append(Path(rel))
+
+        # 3) Try the directory of the currently loaded image (helps when DB stores only filenames)
+        if self.current_image_path:
+            try:
+                candidates.append(Path(self.current_image_path).expanduser().resolve().parent / Path(rel).name)
+            except Exception:
+                pass
+
+        # 4) Try the last directory used by the annotation file picker
+        try:
+            settings = QSettings("microscopy_app", "object_detection")
+            last_dir = settings.value("annotation_tab/last_directory", None)
+            if last_dir:
+                candidates.append(Path(str(last_dir)) / Path(rel).name)
+        except Exception:
+            pass
+
+        for p in candidates:
+            try:
+                expanded = p.expanduser()
+                if expanded.exists() and expanded.is_file():
+                    return str(expanded.resolve())
+            except Exception:
+                continue
+
+        # 5) Fallback: search by filename within repository root.
+        filename = Path(rel).name
+        if repo_root and filename:
+            root = Path(repo_root).expanduser()
+            try:
+                if root.exists():
+                    for match in root.rglob(filename):
+                        if match.is_file():
+                            return str(match.resolve())
+            except Exception as exc:
+                logger.debug(f"Search for {filename} under {root} failed: {exc}")
+
+        return None
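
Two small Qt details in `_refresh_annotated_images_list` are worth calling out: sorting is disabled and signals are blocked while the rows are rewritten (otherwise each `setItem` can re-sort the table and re-fire selection handlers mid-update), and the count column stores an `int` under `Qt.EditRole` so it sorts numerically rather than lexically. A minimal sketch of the same pattern on a bare `QTableWidget` (illustrative data only):

```python
from PySide6.QtCore import Qt
from PySide6.QtWidgets import QApplication, QTableWidget, QTableWidgetItem

app = QApplication([])
table = QTableWidget(0, 2)
table.setSortingEnabled(True)

rows = [("img_b.tif", 12), ("img_a.tif", 3), ("img_c.tif", 101)]

table.setSortingEnabled(False)   # avoid re-sorting after every setItem
table.blockSignals(True)         # avoid firing selection handlers mid-update
try:
    table.setRowCount(len(rows))
    for r, (name, count) in enumerate(rows):
        table.setItem(r, 0, QTableWidgetItem(name))
        count_item = QTableWidgetItem()
        count_item.setData(Qt.EditRole, count)  # int payload -> numeric sort (3 < 12 < 101)
        table.setItem(r, 1, count_item)
finally:
    table.blockSignals(False)
    table.setSortingEnabled(True)
```

Without the `EditRole` trick, a string-backed column would order the counts "101, 12, 3".
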
@@ -3,7 +3,7 @@ Results tab for browsing stored detections and visualizing overlays.
 """

 from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple

 from PySide6.QtWidgets import (
     QWidget,
@@ -35,9 +35,7 @@ logger = get_logger(__name__)
 class ResultsTab(QWidget):
     """Results tab showing detection history and preview overlays."""

-    def __init__(
-        self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
-    ):
+    def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
         super().__init__(parent)
         self.db_manager = db_manager
         self.config_manager = config_manager
@@ -67,28 +65,32 @@ class ResultsTab(QWidget):
         self.refresh_btn = QPushButton("Refresh")
         self.refresh_btn.clicked.connect(self.refresh)
         controls_layout.addWidget(self.refresh_btn)

+        self.delete_all_btn = QPushButton("Delete All Detections")
+        self.delete_all_btn.setToolTip(
+            "Permanently delete ALL detections from the database.\nThis cannot be undone."
+        )
+        self.delete_all_btn.clicked.connect(self._delete_all_detections)
+        controls_layout.addWidget(self.delete_all_btn)
+
+        self.export_labels_btn = QPushButton("Export Labels")
+        self.export_labels_btn.setToolTip(
+            "Export YOLO .txt labels for the selected image/model run.\n"
+            "Output path is inferred from the image path (images/ -> labels/)."
+        )
+        self.export_labels_btn.clicked.connect(self._export_labels_for_current_selection)
+        controls_layout.addWidget(self.export_labels_btn)
+
         controls_layout.addStretch()
         left_layout.addLayout(controls_layout)

         self.results_table = QTableWidget(0, 5)
-        self.results_table.setHorizontalHeaderLabels(
-            ["Image", "Model", "Detections", "Classes", "Last Updated"]
-        )
-        self.results_table.horizontalHeader().setSectionResizeMode(
-            0, QHeaderView.Stretch
-        )
-        self.results_table.horizontalHeader().setSectionResizeMode(
-            1, QHeaderView.Stretch
-        )
-        self.results_table.horizontalHeader().setSectionResizeMode(
-            2, QHeaderView.ResizeToContents
-        )
-        self.results_table.horizontalHeader().setSectionResizeMode(
-            3, QHeaderView.Stretch
-        )
-        self.results_table.horizontalHeader().setSectionResizeMode(
-            4, QHeaderView.ResizeToContents
-        )
+        self.results_table.setHorizontalHeaderLabels(["Image", "Model", "Detections", "Classes", "Last Updated"])
+        self.results_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
+        self.results_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.Stretch)
+        self.results_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeToContents)
+        self.results_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.Stretch)
+        self.results_table.horizontalHeader().setSectionResizeMode(4, QHeaderView.ResizeToContents)
         self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows)
         self.results_table.setSelectionMode(QAbstractItemView.SingleSelection)
         self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
@@ -106,6 +108,8 @@ class ResultsTab(QWidget):
         preview_layout = QVBoxLayout()

         self.preview_canvas = AnnotationCanvasWidget()
+        # Auto-zoom so newly loaded images fill the available preview viewport.
+        self.preview_canvas.set_auto_fit_to_view(True)
         self.preview_canvas.set_polyline_enabled(False)
         self.preview_canvas.set_show_bboxes(True)
         preview_layout.addWidget(self.preview_canvas)
@@ -119,9 +123,7 @@ class ResultsTab(QWidget):
         self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
         self.show_confidence_checkbox = QCheckBox("Show Confidence")
         self.show_confidence_checkbox.setChecked(False)
-        self.show_confidence_checkbox.stateChanged.connect(
-            self._apply_detection_overlays
-        )
+        self.show_confidence_checkbox.stateChanged.connect(self._apply_detection_overlays)
         toggles_layout.addWidget(self.show_masks_checkbox)
         toggles_layout.addWidget(self.show_bboxes_checkbox)
         toggles_layout.addWidget(self.show_confidence_checkbox)
@@ -144,6 +146,41 @@ class ResultsTab(QWidget):
         layout.addWidget(splitter)
         self.setLayout(layout)

+    def _delete_all_detections(self):
+        """Delete all detections from the database after user confirmation."""
+        confirm = QMessageBox.warning(
+            self,
+            "Delete All Detections",
+            "This will permanently delete ALL detections from the database.\n\n"
+            "This action cannot be undone.\n\n"
+            "Do you want to continue?",
+            QMessageBox.Yes | QMessageBox.No,
+            QMessageBox.No,
+        )
+
+        if confirm != QMessageBox.Yes:
+            return
+
+        try:
+            deleted = self.db_manager.delete_all_detections()
+        except Exception as exc:
+            logger.error(f"Failed to delete all detections: {exc}")
+            QMessageBox.critical(
+                self,
+                "Error",
+                f"Failed to delete detections:\n{exc}",
+            )
+            return
+
+        QMessageBox.information(
+            self,
+            "Delete All Detections",
+            f"Deleted {deleted} detection(s) from the database.",
+        )
+
+        # Reset UI state.
+        self.refresh()
+
     def refresh(self):
         """Refresh the detection list and preview."""
         self._load_detection_summary()
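
`db_manager.delete_all_detections()` is called here but not shown in this diff; the dialog only assumes it returns the number of rows removed. A plausible minimal implementation under that assumption (the `detections` table name is inferred from context; sqlite3's `cursor.rowcount` reports rows affected by the DELETE):

```python
def delete_all_detections(self) -> int:
    """Delete every row from the detections table; return the number deleted."""
    conn = self.get_connection()
    try:
        cursor = conn.cursor()
        cursor.execute("DELETE FROM detections")  # table name assumed
        conn.commit()
        return cursor.rowcount
    finally:
        conn.close()
```
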
@@ -153,6 +190,8 @@ class ResultsTab(QWidget):
         self.current_detections = []
         self.preview_canvas.clear()
         self.summary_label.setText("Select a detection result to preview.")
+        if hasattr(self, "export_labels_btn"):
+            self.export_labels_btn.setEnabled(False)

     def _load_detection_summary(self):
         """Load latest detection summaries grouped by image + model."""
@@ -169,8 +208,7 @@ class ResultsTab(QWidget):
                     "image_id": det["image_id"],
                     "model_id": det["model_id"],
                     "image_path": det.get("image_path"),
-                    "image_filename": det.get("image_filename")
-                    or det.get("image_path"),
+                    "image_filename": det.get("image_filename") or det.get("image_path"),
                     "model_name": det.get("model_name", ""),
                     "model_version": det.get("model_version", ""),
                     "last_detected": det.get("detected_at"),
@@ -183,8 +221,7 @@ class ResultsTab(QWidget):

             entry["count"] += 1
             if det.get("detected_at") and (
-                not entry.get("last_detected")
-                or str(det.get("detected_at")) > str(entry.get("last_detected"))
+                not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
             ):
                 entry["last_detected"] = det.get("detected_at")
             if det.get("class_name"):
@@ -214,9 +251,7 @@ class ResultsTab(QWidget):

         for row, entry in enumerate(self.detection_summary):
             model_label = f"{entry['model_name']} {entry['model_version']}".strip()
-            class_list = (
-                ", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
-            )
+            class_list = ", ".join(sorted(entry["classes"])) if entry["classes"] else "-"

             items = [
                 QTableWidgetItem(entry.get("image_filename", "")),
@@ -276,6 +311,231 @@ class ResultsTab(QWidget):
         self._load_detections_for_selection(entry)
         self._apply_detection_overlays()
         self._update_summary_label(entry)
+        if hasattr(self, "export_labels_btn"):
+            self.export_labels_btn.setEnabled(True)
+
+    def _export_labels_for_current_selection(self):
+        """Export YOLO label file(s) for the currently selected image/model."""
+        if not self.current_selection:
+            QMessageBox.information(self, "Export Labels", "Select a detection result first.")
+            return
+
+        entry = self.current_selection
+
+        image_path_str = self._resolve_image_path(entry)
+        if not image_path_str:
+            QMessageBox.warning(
+                self,
+                "Export Labels",
+                "Unable to locate the image file for this detection; cannot infer labels path.",
+            )
+            return
+
+        # Ensure we have the detections for the selection.
+        if not self.current_detections:
+            self._load_detections_for_selection(entry)
+
+        if not self.current_detections:
+            QMessageBox.information(
+                self,
+                "Export Labels",
+                "No detections found for this image/model selection.",
+            )
+            return
+
+        image_path = Path(image_path_str)
+        try:
+            label_path = self._infer_yolo_label_path(image_path)
+        except Exception as exc:
+            logger.error(f"Failed to infer label path for {image_path}: {exc}")
+            QMessageBox.critical(
+                self,
+                "Export Labels",
+                f"Failed to infer export path for labels:\n{exc}",
+            )
+            return
+
+        class_map = self._build_detection_class_index_map(self.current_detections)
+        if not class_map:
+            QMessageBox.warning(
+                self,
+                "Export Labels",
+                "Unable to build class->index mapping (missing class names).",
+            )
+            return
+
+        lines_written = 0
+        skipped = 0
+        label_path.parent.mkdir(parents=True, exist_ok=True)
+        try:
+            with open(label_path, "w", encoding="utf-8") as handle:
|
print("writing to", label_path)
|
||||||
|
for det in self.current_detections:
|
||||||
|
yolo_line = self._format_detection_as_yolo_line(det, class_map)
|
||||||
|
if not yolo_line:
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
handle.write(yolo_line + "\n")
|
||||||
|
lines_written += 1
|
||||||
|
except OSError as exc:
|
||||||
|
logger.error(f"Failed to write labels file {label_path}: {exc}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
f"Failed to write label file:\n{label_path}\n\n{exc}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
return
|
||||||
|
# Optional: write a classes.txt next to the labels root to make the mapping discoverable.
|
||||||
|
# This is not required by Ultralytics (data.yaml usually holds class names), but helps reuse.
|
||||||
|
try:
|
||||||
|
classes_txt = label_path.parent.parent / "classes.txt"
|
||||||
|
classes_txt.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
inv = {idx: name for name, idx in class_map.items()}
|
||||||
|
with open(classes_txt, "w", encoding="utf-8") as handle:
|
||||||
|
for idx in range(len(inv)):
|
||||||
|
handle.write(f"{inv[idx]}\n")
|
||||||
|
except Exception:
|
||||||
|
# Non-fatal
|
||||||
|
pass
|
||||||
|
|
||||||
|
QMessageBox.information(
|
||||||
|
self,
|
||||||
|
"Export Labels",
|
||||||
|
f"Exported {lines_written} label line(s) to:\n{label_path}\n\nSkipped {skipped} invalid detection(s).",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _infer_yolo_label_path(self, image_path: Path) -> Path:
|
||||||
|
"""Infer a YOLO label path from an image path.
|
||||||
|
|
||||||
|
If the image lives under an `images/` directory (anywhere in the path), we mirror the
|
||||||
|
subpath under a sibling `labels/` directory at the same level.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
/dataset/train/images/sub/img.jpg -> /dataset/train/labels/sub/img.txt
|
||||||
|
"""
|
||||||
|
|
||||||
|
resolved = image_path.expanduser().resolve()
|
||||||
|
|
||||||
|
# Find the nearest ancestor directory named 'images'
|
||||||
|
images_dir: Optional[Path] = None
|
||||||
|
for parent in [resolved.parent, *resolved.parents]:
|
||||||
|
if parent.name.lower() == "images":
|
||||||
|
images_dir = parent
|
||||||
|
break
|
||||||
|
|
||||||
|
if images_dir is not None:
|
||||||
|
rel = resolved.relative_to(images_dir)
|
||||||
|
labels_dir = images_dir.parent / "labels"
|
||||||
|
return (labels_dir / rel).with_suffix(".txt")
|
||||||
|
|
||||||
|
# Fallback: create a local sibling labels folder next to the image.
|
||||||
|
return (resolved.parent / "labels" / resolved.name).with_suffix(".txt")
|
||||||
|
|
||||||
|
def _build_detection_class_index_map(self, detections: List[Dict]) -> Dict[str, int]:
|
||||||
|
"""Build a stable class_name -> YOLO class index mapping.
|
||||||
|
|
||||||
|
Preference order:
|
||||||
|
1) Database object_classes table (alphabetical class_name order)
|
||||||
|
2) Fallback to class_name values present in the detections (alphabetical)
|
||||||
|
"""
|
||||||
|
|
||||||
|
names: List[str] = []
|
||||||
|
try:
|
||||||
|
db_classes = self.db_manager.get_object_classes() or []
|
||||||
|
names = [str(row.get("class_name")) for row in db_classes if row.get("class_name")]
|
||||||
|
except Exception:
|
||||||
|
names = []
|
||||||
|
|
||||||
|
if not names:
|
||||||
|
observed = sorted({str(det.get("class_name")) for det in detections if det.get("class_name")})
|
||||||
|
names = list(observed)
|
||||||
|
|
||||||
|
return {name: idx for idx, name in enumerate(names)}
|
||||||
|
|
||||||
|
def _format_detection_as_yolo_line(self, det: Dict, class_map: Dict[str, int]) -> Optional[str]:
|
||||||
|
"""Convert a detection row to a YOLO label line.
|
||||||
|
|
||||||
|
- If segmentation_mask is present, exports segmentation polygon format:
|
||||||
|
class x1 y1 x2 y2 ...
|
||||||
|
(normalized coordinates)
|
||||||
|
- Otherwise exports bbox format:
|
||||||
|
class x_center y_center width height
|
||||||
|
(normalized coordinates)
|
||||||
|
"""
|
||||||
|
|
||||||
|
class_name = det.get("class_name")
|
||||||
|
if not class_name or class_name not in class_map:
|
||||||
|
return None
|
||||||
|
class_idx = class_map[class_name]
|
||||||
|
|
||||||
|
mask = det.get("segmentation_mask")
|
||||||
|
polygon = self._convert_segmentation_mask_to_polygon(mask)
|
||||||
|
if polygon:
|
||||||
|
coords = " ".join(f"{value:.6f}" for value in polygon)
|
||||||
|
return f"{class_idx} {coords}".strip()
|
||||||
|
|
||||||
|
bbox = self._convert_bbox_to_yolo_xywh(det)
|
||||||
|
if bbox is None:
|
||||||
|
return None
|
||||||
|
x_center, y_center, width, height = bbox
|
||||||
|
return f"{class_idx} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
|
||||||
|
|
||||||
|
def _convert_bbox_to_yolo_xywh(self, det: Dict) -> Optional[Tuple[float, float, float, float]]:
|
||||||
|
"""Convert stored xyxy (normalized) bbox to YOLO xywh (normalized)."""
|
||||||
|
|
||||||
|
x_min = det.get("x_min")
|
||||||
|
y_min = det.get("y_min")
|
||||||
|
x_max = det.get("x_max")
|
||||||
|
y_max = det.get("y_max")
|
||||||
|
if any(v is None for v in (x_min, y_min, x_max, y_max)):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
x_min_f = self._clamp01(float(x_min))
|
||||||
|
y_min_f = self._clamp01(float(y_min))
|
||||||
|
x_max_f = self._clamp01(float(x_max))
|
||||||
|
y_max_f = self._clamp01(float(y_max))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
width = max(0.0, x_max_f - x_min_f)
|
||||||
|
height = max(0.0, y_max_f - y_min_f)
|
||||||
|
if width <= 0.0 or height <= 0.0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
x_center = x_min_f + width / 2.0
|
||||||
|
y_center = y_min_f + height / 2.0
|
||||||
|
return x_center, y_center, width, height
|
||||||
|
|
||||||
|
def _convert_segmentation_mask_to_polygon(self, mask_data) -> List[float]:
|
||||||
|
"""Convert stored segmentation_mask [[x,y], ...] to YOLO polygon coords [x1,y1,...]."""
|
||||||
|
|
||||||
|
if not isinstance(mask_data, list):
|
||||||
|
return []
|
||||||
|
|
||||||
|
coords: List[float] = []
|
||||||
|
for point in mask_data:
|
||||||
|
if not isinstance(point, (list, tuple)) or len(point) < 2:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
x = self._clamp01(float(point[0]))
|
||||||
|
y = self._clamp01(float(point[1]))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
coords.extend([x, y])
|
||||||
|
|
||||||
|
# Need at least 3 points => 6 values.
|
||||||
|
return coords if len(coords) >= 6 else []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _clamp01(value: float) -> float:
|
||||||
|
if value < 0.0:
|
||||||
|
return 0.0
|
||||||
|
if value > 1.0:
|
||||||
|
return 1.0
|
||||||
|
return value
|
||||||
|
|
||||||
def _load_detections_for_selection(self, entry: Dict):
|
def _load_detections_for_selection(self, entry: Dict):
|
||||||
"""Load detection records for the selected image/model pair."""
|
"""Load detection records for the selected image/model pair."""
|
||||||
@@ -3,12 +3,14 @@ Training tab for the microscopy object detection application.
 Handles model training with YOLO.
 """

+import hashlib
+import shutil
 from datetime import datetime
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple

-import numpy as np
 import yaml
+import numpy as np
 from PySide6.QtCore import Qt, QThread, Signal
 from PySide6.QtWidgets import (
     QWidget,
@@ -90,10 +92,7 @@ class TrainingWorker(QThread):
                 },
             }
         ]
-        computed_total = sum(
-            max(0, int((stage.get("params") or {}).get("epochs", 0)))
-            for stage in self.stage_plan
-        )
+        computed_total = sum(max(0, int((stage.get("params") or {}).get("epochs", 0))) for stage in self.stage_plan)
         self.total_epochs = total_epochs if total_epochs else computed_total or epochs
         self._stop_requested = False

@@ -200,9 +199,7 @@
 class TrainingTab(QWidget):
     """Training tab for model training."""

-    def __init__(
-        self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
-    ):
+    def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
         super().__init__(parent)
         self.db_manager = db_manager
         self.config_manager = config_manager
@@ -336,18 +333,14 @@ class TrainingTab(QWidget):
         self.model_version_edit = QLineEdit("v1")
         form_layout.addRow("Version:", self.model_version_edit)

-        default_base_model = self.config_manager.get(
-            "models.default_base_model", "yolov8s-seg.pt"
-        )
+        default_base_model = self.config_manager.get("models.default_base_model", "yolov8s-seg.pt")
        base_model_choices = self.config_manager.get("models.base_model_choices", [])

         self.base_model_combo = QComboBox()
         self.base_model_combo.addItem("Custom path…", "")
         for choice in base_model_choices:
             self.base_model_combo.addItem(choice, choice)
-        self.base_model_combo.currentIndexChanged.connect(
-            self._on_base_model_preset_changed
-        )
+        self.base_model_combo.currentIndexChanged.connect(self._on_base_model_preset_changed)
         form_layout.addRow("Base Model Preset:", self.base_model_combo)

         base_model_layout = QHBoxLayout()
@@ -433,12 +426,8 @@ class TrainingTab(QWidget):
         group_layout = QVBoxLayout()

         self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
-        two_stage_defaults = (
-            training_defaults.get("two_stage", {}) if training_defaults else {}
-        )
-        self.two_stage_checkbox.setChecked(
-            bool(two_stage_defaults.get("enabled", False))
-        )
+        two_stage_defaults = training_defaults.get("two_stage", {}) if training_defaults else {}
+        self.two_stage_checkbox.setChecked(bool(two_stage_defaults.get("enabled", False)))
         self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
         group_layout.addWidget(self.two_stage_checkbox)

@@ -500,9 +489,7 @@ class TrainingTab(QWidget):
         stage2_group.setLayout(stage2_form)
         controls_layout.addWidget(stage2_group)

-        helper_label = QLabel(
-            "When enabled, staged hyperparameters override the global epochs/patience/lr."
-        )
+        helper_label = QLabel("When enabled, staged hyperparameters override the global epochs/patience/lr.")
         helper_label.setWordWrap(True)
         controls_layout.addWidget(helper_label)

@@ -547,9 +534,7 @@ class TrainingTab(QWidget):
             if normalized == preset_value:
                 target_index = idx
                 break
-            if normalized.endswith(f"/{preset_value}") or normalized.endswith(
-                f"\\{preset_value}"
-            ):
+            if normalized.endswith(f"/{preset_value}") or normalized.endswith(f"\\{preset_value}"):
                 target_index = idx
                 break
         self.base_model_combo.blockSignals(True)
@@ -637,9 +622,7 @@ class TrainingTab(QWidget):

     def _browse_dataset(self):
         """Open a file dialog to manually select data.yaml."""
-        start_dir = self.config_manager.get(
-            "training.last_dataset_dir", "data/datasets"
-        )
+        start_dir = self.config_manager.get("training.last_dataset_dir", "data/datasets")
         start_path = Path(start_dir).expanduser()
         if not start_path.exists():
             start_path = Path.cwd()
@@ -675,9 +658,7 @@ class TrainingTab(QWidget):
             return
         except Exception as exc:
             logger.exception("Unexpected error while generating data.yaml")
-            self._display_dataset_error(
-                "Unexpected error while generating data.yaml. Check logs for details."
-            )
+            self._display_dataset_error("Unexpected error while generating data.yaml. Check logs for details.")
             QMessageBox.critical(
                 self,
                 "data.yaml Generation Failed",
@@ -754,13 +735,9 @@ class TrainingTab(QWidget):
         self.selected_dataset = info

         self.dataset_root_label.setText(info["root"])  # type: ignore[arg-type]
-        self.train_count_label.setText(
-            self._format_split_info(info["splits"].get("train"))
-        )
+        self.train_count_label.setText(self._format_split_info(info["splits"].get("train")))
         self.val_count_label.setText(self._format_split_info(info["splits"].get("val")))
-        self.test_count_label.setText(
-            self._format_split_info(info["splits"].get("test"))
-        )
+        self.test_count_label.setText(self._format_split_info(info["splits"].get("test")))
         self.num_classes_label.setText(str(info["num_classes"]))
         class_names = ", ".join(info["class_names"]) or "–"
         self.class_names_label.setText(class_names)
@@ -814,18 +791,12 @@ class TrainingTab(QWidget):
                 if split_path.exists():
                     split_info["count"] = self._count_images(split_path)
                     if split_info["count"] == 0:
-                        warnings.append(
-                            f"No images found for {split_name} split at {split_path}"
-                        )
+                        warnings.append(f"No images found for {split_name} split at {split_path}")
                 else:
-                    warnings.append(
-                        f"{split_name.capitalize()} path does not exist: {split_path}"
-                    )
+                    warnings.append(f"{split_name.capitalize()} path does not exist: {split_path}")
             else:
                 if split_name in ("train", "val"):
-                    warnings.append(
-                        f"{split_name.capitalize()} split missing in data.yaml"
-                    )
+                    warnings.append(f"{split_name.capitalize()} split missing in data.yaml")
             splits[split_name] = split_info

         names_list = self._normalize_class_names(data.get("names"))
@@ -843,9 +814,7 @@ class TrainingTab(QWidget):
         if not names_list and nc_value:
             names_list = [f"class_{idx}" for idx in range(int(nc_value))]
         elif nc_value and len(names_list) not in (0, int(nc_value)):
-            warnings.append(
-                f"Number of class names ({len(names_list)}) does not match nc={nc_value}"
-            )
+            warnings.append(f"Number of class names ({len(names_list)}) does not match nc={nc_value}")

         dataset_name = data.get("name") or base_path.name

@@ -897,16 +866,12 @@ class TrainingTab(QWidget):

         class_index_map = self._build_class_index_map(dataset_info)
         if not class_index_map:
-            self._append_training_log(
-                "Skipping label export: dataset classes do not match database entries."
-            )
+            self._append_training_log("Skipping label export: dataset classes do not match database entries.")
             return

         dataset_root_str = dataset_info.get("root")
         dataset_yaml_path = dataset_info.get("yaml_path")
-        dataset_yaml = (
-            Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
-        )
+        dataset_yaml = Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
         dataset_root: Optional[Path]
         if dataset_root_str:
             dataset_root = Path(dataset_root_str).resolve()
@@ -940,12 +905,17 @@ class TrainingTab(QWidget):
             if stats["registered_images"]:
                 message += f" {stats['registered_images']} image(s) had database-backed annotations."
             if stats["missing_records"]:
-                message += f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
+                message += (
+                    f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
+                )
             split_messages.append(message)

         for msg in split_messages:
             self._append_training_log(msg)

+        if dataset_yaml:
+            self._clear_rgb_cache_for_dataset(dataset_yaml)
+
     def _export_labels_for_split(
         self,
         split_name: str,
@@ -969,9 +939,7 @@ class TrainingTab(QWidget):
                 continue

             processed_images += 1
-            label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(
-                ".txt"
-            )
+            label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(".txt")
             label_path.parent.mkdir(parents=True, exist_ok=True)

             found, annotation_entries = self._fetch_annotations_for_image(
@@ -987,25 +955,23 @@ class TrainingTab(QWidget):
                 for entry in annotation_entries:
                     polygon = entry.get("polygon") or []
                     if polygon:
+                        print(image_file, polygon[:4], polygon[-2:], entry.get("bbox"))
+                        # coords = " ".join(f"{value:.6f}" for value in entry.get("bbox"))
+                        # coords += " "
                         coords = " ".join(f"{value:.6f}" for value in polygon)
                         handle.write(f"{entry['class_idx']} {coords}\n")
                         annotations_written += 1
                     elif entry.get("bbox"):
                         x_center, y_center, width, height = entry["bbox"]
-                        handle.write(
-                            f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
-                        )
+                        handle.write(f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")
                         annotations_written += 1

             total_annotations += annotations_written

         cache_reset_root = labels_dir.parent
         self._invalidate_split_cache(cache_reset_root)

         if processed_images == 0:
-            self._append_training_log(
-                f"[{split_name}] No images found to export labels for."
-            )
+            self._append_training_log(f"[{split_name}] No images found to export labels for.")
             return None

         return {
@@ -1131,6 +1097,10 @@ class TrainingTab(QWidget):
                 xs.append(x_val)
                 ys.append(y_val)

+            if any(np.abs(np.array(coords[:2]) - np.array(coords[-2:])) < 1e-5):
+                print("Closing polygon")
+                coords.extend(coords[:2])
+
             if len(coords) < 6:
                 continue

@@ -1143,6 +1113,11 @@ class TrainingTab(QWidget):
                 + abs((min(ys) if ys else 0.0) - y_min)
                 + abs((max(ys) if ys else 0.0) - y_max)
             )
+            width = max(0.0, x_max - x_min)
+            height = max(0.0, y_max - y_min)
+            x_center = x_min + width / 2.0
+            y_center = y_min + height / 2.0
+            score = (x_center, y_center, width, height)

             candidates.append((score, coords))

@@ -1160,6 +1135,35 @@ class TrainingTab(QWidget):
             return 1.0
         return value

+    def _prepare_dataset_for_training(self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None) -> Path:
+        dataset_info = dataset_info or (
+            self.selected_dataset
+            if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
+            else self._parse_dataset_yaml(dataset_yaml)
+        )
+
+        train_split = dataset_info.get("splits", {}).get("train") or {}
+        images_path_str = train_split.get("path")
+        if not images_path_str:
+            return dataset_yaml
+
+        images_path = Path(images_path_str)
+        if not images_path.exists():
+            return dataset_yaml
+
+        if not self._dataset_requires_rgb_conversion(images_path):
+            return dataset_yaml
+
+        cache_root = self._get_rgb_cache_root(dataset_yaml)
+        rgb_yaml = cache_root / "data.yaml"
+        if rgb_yaml.exists():
+            self._append_training_log(f"Detected grayscale dataset; reusing RGB cache at {cache_root}")
+            return rgb_yaml
+
+        self._append_training_log(f"Detected grayscale dataset; creating RGB cache at {cache_root}")
+        self._build_rgb_dataset(cache_root, dataset_info)
+        return rgb_yaml
+
     def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
         two_stage = params.get("two_stage") or {}
         base_stage = {
@@ -1244,6 +1248,113 @@
                 f" • {stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
             )

+    def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
+        cache_base = Path("data/datasets/_rgb_cache")
+        cache_base.mkdir(parents=True, exist_ok=True)
+        key = hashlib.md5(str(dataset_yaml.parent.resolve()).encode()).hexdigest()[:8]
+        return cache_base / f"{dataset_yaml.parent.name}_{key}"
+
+    def _clear_rgb_cache_for_dataset(self, dataset_yaml: Path):
+        cache_root = self._get_rgb_cache_root(dataset_yaml)
+        if cache_root.exists():
+            try:
+                shutil.rmtree(cache_root)
+                logger.debug(f"Removed RGB cache at {cache_root}")
+            except OSError as exc:
+                logger.warning(f"Failed to remove RGB cache {cache_root}: {exc}")
+
+    def _dataset_requires_rgb_conversion(self, images_dir: Path) -> bool:
+        sample_image = self._find_first_image(images_dir)
+        if not sample_image:
+            return False
+
+        # Do not force an RGB cache for TIFF datasets.
+        # We handle grayscale/16-bit TIFFs via runtime Ultralytics patches that:
+        # - load TIFFs with `tifffile`
+        # - replicate grayscale to 3 channels without quantization
+        # - normalize uint16 correctly during training
+        if sample_image.suffix.lower() in {".tif", ".tiff"}:
+            return False
+        try:
+            img = Image(sample_image)
+            return img.pil_image.mode.upper() != "RGB"
+        except Exception as exc:
+            logger.warning(f"Failed to inspect image {sample_image}: {exc}")
+            return False
+
+    def _find_first_image(self, directory: Path) -> Optional[Path]:
+        if not directory.exists():
+            return None
+        for path in directory.rglob("*"):
+            if path.is_file() and path.suffix.lower() in self.allowed_extensions:
+                return path
+        return None
+
+    def _build_rgb_dataset(self, cache_root: Path, dataset_info: Dict[str, Any]):
+        if cache_root.exists():
+            shutil.rmtree(cache_root)
+        cache_root.mkdir(parents=True, exist_ok=True)
+
+        splits = dataset_info.get("splits", {})
+        for split_name in ("train", "val", "test"):
+            split_entry = splits.get(split_name)
+            if not split_entry:
+                continue
+            images_src = Path(split_entry.get("path", ""))
+            if not images_src.exists():
+                continue
+            images_dst = cache_root / split_name / "images"
+            self._convert_images_to_rgb(images_src, images_dst)
+
+            labels_src = self._infer_labels_dir(images_src)
+            if labels_src.exists():
+                labels_dst = cache_root / split_name / "labels"
+                self._copy_labels(labels_src, labels_dst)
+
+        class_names = dataset_info.get("class_names") or []
+        names_map = {idx: name for idx, name in enumerate(class_names)}
+        num_classes = dataset_info.get("num_classes") or len(class_names)
+
+        yaml_payload: Dict[str, Any] = {
+            "path": cache_root.as_posix(),
+            "names": names_map,
+            "nc": num_classes,
+        }
+
+        for split_name in ("train", "val", "test"):
+            images_dir = cache_root / split_name / "images"
+            if images_dir.exists():
+                yaml_payload[split_name] = f"{split_name}/images"
+
+        with open(cache_root / "data.yaml", "w", encoding="utf-8") as handle:
+            yaml.safe_dump(yaml_payload, handle, sort_keys=False)
+
+    def _convert_images_to_rgb(self, src_dir: Path, dst_dir: Path):
+        for src in src_dir.rglob("*"):
+            if not src.is_file() or src.suffix.lower() not in self.allowed_extensions:
+                continue
+            relative = src.relative_to(src_dir)
+            dst = dst_dir / relative
+            dst.parent.mkdir(parents=True, exist_ok=True)
+            try:
+                img_obj = Image(src)
+                pil_img = img_obj.pil_image
+                if len(pil_img.getbands()) == 1:
+                    rgb_img = img_obj.convert_grayscale_to_rgb_preserve_range()
+                else:
+                    rgb_img = pil_img.convert("RGB")
+                rgb_img.save(dst)
+            except Exception as exc:
+                logger.warning(f"Failed to convert {src} to RGB: {exc}")
+
+    def _copy_labels(self, labels_src: Path, labels_dst: Path):
+        label_files = list(labels_src.rglob("*.txt"))
+        for label_file in label_files:
+            relative = label_file.relative_to(labels_src)
+            dst = labels_dst / relative
+            dst.parent.mkdir(parents=True, exist_ok=True)
+            shutil.copy2(label_file, dst)
+
     def _infer_labels_dir(self, images_dir: Path) -> Path:
         return images_dir.parent / "labels"

@@ -1316,31 +1427,26 @@ class TrainingTab(QWidget):

         dataset_path = Path(dataset_yaml).expanduser()
         if not dataset_path.exists():
-            QMessageBox.warning(
-                self, "Invalid Dataset", "Selected data.yaml file does not exist."
-            )
+            QMessageBox.warning(self, "Invalid Dataset", "Selected data.yaml file does not exist.")
             return

         dataset_info = (
             self.selected_dataset
-            if self.selected_dataset
-            and self.selected_dataset.get("yaml_path") == str(dataset_path)
+            if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_path)
             else self._parse_dataset_yaml(dataset_path)
         )

         self.training_log.clear()
         self._export_labels_from_database(dataset_info)

-        self._append_training_log(
-            "Using Float32 on-the-fly loader for 16-bit TIFF support (no disk caching)"
-        )
+        dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
+        if dataset_to_use != dataset_path:
+            self._append_training_log(f"Using RGB-converted dataset at {dataset_to_use.parent}")

         params = self._collect_training_params()
         stage_plan = self._compose_stage_plan(params)
         params["stage_plan"] = stage_plan
-        total_planned_epochs = (
-            self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
-        )
+        total_planned_epochs = self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
         params["total_planned_epochs"] = total_planned_epochs
         self._active_training_params = params
         self._training_cancelled = False
@@ -1349,9 +1455,7 @@ class TrainingTab(QWidget):
         self._append_training_log("Two-stage fine-tuning schedule:")
         self._log_stage_plan(stage_plan)

-        self._append_training_log(
-            f"Starting training run '{params['run_name']}' using {params['base_model']}"
-        )
+        self._append_training_log(f"Starting training run '{params['run_name']}' using {params['base_model']}")

         self.training_progress_bar.setVisible(True)
         self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
@@ -1359,7 +1463,7 @@ class TrainingTab(QWidget):
         self._set_training_state(True)

         self.training_worker = TrainingWorker(
-            data_yaml=dataset_path.as_posix(),
+            data_yaml=dataset_to_use.as_posix(),
             base_model=params["base_model"],
             epochs=params["epochs"],
             batch=params["batch"],
@@ -1379,9 +1483,7 @@ class TrainingTab(QWidget):
     def _stop_training(self):
         if self.training_worker and self.training_worker.isRunning():
             self._training_cancelled = True
-            self._append_training_log(
-                "Stop requested. Waiting for the current epoch to finish..."
-            )
+            self._append_training_log("Stop requested. Waiting for the current epoch to finish...")
             self.training_worker.stop()
             self.stop_training_button.setEnabled(False)

@@ -1417,9 +1519,7 @@ class TrainingTab(QWidget):

         if worker.isRunning():
             if not worker.wait(wait_timeout_ms):
-                logger.warning(
-                    "Training worker did not finish within %sms", wait_timeout_ms
-                )
+                logger.warning("Training worker did not finish within %sms", wait_timeout_ms)

         worker.deleteLater()

@@ -1436,16 +1536,12 @@ class TrainingTab(QWidget):
         self._set_training_state(False)
         self.training_progress_bar.setVisible(False)

-    def _on_training_progress(
-        self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]
-    ):
+    def _on_training_progress(self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]):
         self.training_progress_bar.setMaximum(total_epochs)
         self.training_progress_bar.setValue(current_epoch)
         parts = [f"Epoch {current_epoch}/{total_epochs}"]
         if metrics:
-            metric_text = ", ".join(
-                f"{key}: {value:.4f}" for key, value in metrics.items()
-            )
+            metric_text = ", ".join(f"{key}: {value:.4f}" for key, value in metrics.items())
             parts.append(metric_text)
         self._append_training_log(" | ".join(parts))

@@ -1472,9 +1568,7 @@ class TrainingTab(QWidget):
                 f"Model trained but not registered: {exc}",
             )
         else:
-            QMessageBox.information(
-                self, "Training Complete", "Training finished successfully."
-            )
+            QMessageBox.information(self, "Training Complete", "Training finished successfully.")

     def _on_training_error(self, message: str):
         self._cleanup_training_worker()
@@ -1520,9 +1614,7 @@ class TrainingTab(QWidget):
             metrics=results.get("metrics"),
         )

-        self._append_training_log(
-            f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}"
-        )
+        self._append_training_log(f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}")
         self._active_training_params = None

     def _set_training_state(self, is_training: bool):
@@ -1565,9 +1657,7 @@ class TrainingTab(QWidget):

     def _browse_save_dir(self):
         start_path = self.save_dir_edit.text().strip() or "data/models"
-        directory = QFileDialog.getExistingDirectory(
-            self, "Select Save Directory", start_path
-        )
+        directory = QFileDialog.getExistingDirectory(self, "Select Save Directory", start_path)
         if directory:
             self.save_dir_edit.setText(directory)
@@ -2,45 +2,554 @@
 Validation tab for the microscopy object detection application.
 """

-from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Sequence, Tuple
+
+from PySide6.QtCore import Qt, QSize
+from PySide6.QtGui import QPainter, QPixmap
+from PySide6.QtWidgets import (
+    QWidget,
+    QVBoxLayout,
+    QLabel,
+    QGroupBox,
+    QHBoxLayout,
+    QPushButton,
+    QComboBox,
+    QFormLayout,
+    QScrollArea,
+    QGridLayout,
+    QFrame,
+    QTableWidget,
+    QTableWidgetItem,
+    QHeaderView,
+    QSplitter,
+    QListWidget,
+    QListWidgetItem,
+    QAbstractItemView,
+    QGraphicsView,
+    QGraphicsScene,
+    QGraphicsPixmapItem,
+)
+
 from src.database.db_manager import DatabaseManager
 from src.utils.config_manager import ConfigManager
+from src.utils.logger import get_logger
+
+
+logger = get_logger(__name__)
+
+
+@dataclass(frozen=True)
+class _PlotItem:
+    label: str
+    path: Path
+
+
+class _ZoomableImageView(QGraphicsView):
+    """Zoomable image viewer.
+
+    - Mouse wheel: zoom in/out
+    - Left mouse drag: pan (ScrollHandDrag)
+    """
+
+    def __init__(self, parent: Optional[QWidget] = None):
+        super().__init__(parent)
+        self._scene = QGraphicsScene(self)
+        self.setScene(self._scene)
+        self._pixmap_item = QGraphicsPixmapItem()
+        self._scene.addItem(self._pixmap_item)
+
+        # QGraphicsView render hints are QPainter.RenderHints.
+        self.setRenderHints(self.renderHints() | QPainter.RenderHint.SmoothPixmapTransform)
+        self.setDragMode(QGraphicsView.DragMode.ScrollHandDrag)
+        self.setTransformationAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
+        self.setResizeAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
+
+        self._has_pixmap = False
+
+    def clear(self) -> None:
+        self._pixmap_item.setPixmap(QPixmap())
+        self._scene.setSceneRect(0, 0, 1, 1)
+        self.resetTransform()
+        self._has_pixmap = False
+
+    def set_pixmap(self, pixmap: QPixmap, *, fit: bool = True) -> None:
+        self._pixmap_item.setPixmap(pixmap)
+        self._scene.setSceneRect(pixmap.rect())
+        self._has_pixmap = not pixmap.isNull()
+        self.resetTransform()
+        if fit and self._has_pixmap:
+            self.fitInView(self._pixmap_item, Qt.AspectRatioMode.KeepAspectRatio)
+
+    def wheelEvent(self, event) -> None:  # type: ignore[override]
+        if not self._has_pixmap:
+            return
+        zoom_in_factor = 1.25
+        zoom_out_factor = 1.0 / zoom_in_factor
+        factor = zoom_in_factor if event.angleDelta().y() > 0 else zoom_out_factor
+        self.scale(factor, factor)


 class ValidationTab(QWidget):
-    """Validation tab placeholder."""
+    """Validation tab that shows stored validation metrics + plots for a selected model."""

-    def __init__(
-        self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
-    ):
+    def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
         super().__init__(parent)
         self.db_manager = db_manager
         self.config_manager = config_manager

+        self._models: List[Dict[str, Any]] = []
+        self._selected_model_id: Optional[int] = None
+        self._plot_widgets: List[QWidget] = []
+        self._plot_items: List[_PlotItem] = []
+
         self._setup_ui()
+        self.refresh()

     def _setup_ui(self):
         """Setup user interface."""
-        layout = QVBoxLayout()
+        layout = QVBoxLayout(self)

-        group = QGroupBox("Validation")
-        group_layout = QVBoxLayout()
-        label = QLabel(
-            "Validation functionality will be implemented here.\n\n"
-            "Features:\n"
-            "- Model validation\n"
-            "- Metrics visualization\n"
-            "- Confusion matrix\n"
-            "- Precision-Recall curves"
-        )
-        group_layout.addWidget(label)
-        group.setLayout(group_layout)
+        # ===== Header controls =====
+        header = QGroupBox("Validation")
+        header_layout = QVBoxLayout()
+        header_row = QHBoxLayout()

-        layout.addWidget(group)
-        layout.addStretch()
-        self.setLayout(layout)
+        header_row.addWidget(QLabel("Select model:"))
+
+        self.model_combo = QComboBox()
+        self.model_combo.setMinimumWidth(420)
+        self.model_combo.currentIndexChanged.connect(self._on_model_selected)
+        header_row.addWidget(self.model_combo, 1)
+
+        self.refresh_btn = QPushButton("Refresh")
+        self.refresh_btn.clicked.connect(self.refresh)
+        header_row.addWidget(self.refresh_btn)
+        header_row.addStretch()
+
+        header_layout.addLayout(header_row)
+        self.header_status = QLabel("No models loaded.")
+        self.header_status.setWordWrap(True)
+        header_layout.addWidget(self.header_status)
+        header.setLayout(header_layout)
+        layout.addWidget(header)
+
+        # ===== Metrics =====
+        metrics_group = QGroupBox("Validation Metrics")
+        metrics_layout = QVBoxLayout()
+
+        self.metrics_form = QFormLayout()
+        self.metric_labels: Dict[str, QLabel] = {}
+        for key in ("mAP50", "mAP50-95", "precision", "recall", "fitness"):
+            value_label = QLabel("–")
+            value_label.setTextInteractionFlags(Qt.TextSelectableByMouse)
+            self.metric_labels[key] = value_label
+            self.metrics_form.addRow(f"{key}:", value_label)
+        metrics_layout.addLayout(self.metrics_form)
+
+        self.per_class_table = QTableWidget(0, 3)
+        self.per_class_table.setHorizontalHeaderLabels(["Class", "AP", "AP50"])
+        self.per_class_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
+        self.per_class_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeToContents)
+        self.per_class_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeToContents)
+        self.per_class_table.setEditTriggers(QTableWidget.NoEditTriggers)
+        self.per_class_table.setMinimumHeight(160)
+        metrics_layout.addWidget(QLabel("Per-class metrics (if available):"))
+        metrics_layout.addWidget(self.per_class_table)
+
+        metrics_group.setLayout(metrics_layout)
+        layout.addWidget(metrics_group)
+
+        # ===== Plots =====
+        plots_group = QGroupBox("Validation Plots")
+        plots_layout = QVBoxLayout()
+
+        self.plots_status = QLabel("Select a model to see validation plots.")
+        self.plots_status.setWordWrap(True)
+        plots_layout.addWidget(self.plots_status)
+
+        self.plots_splitter = QSplitter(Qt.Orientation.Horizontal)
+
+        # Left: selected image viewer
+        left_widget = QWidget()
+        left_layout = QVBoxLayout(left_widget)
+        left_layout.setContentsMargins(0, 0, 0, 0)
+
+        self.selected_plot_title = QLabel("No image selected.")
+        self.selected_plot_title.setWordWrap(True)
+        self.selected_plot_title.setTextInteractionFlags(Qt.TextSelectableByMouse)
+        left_layout.addWidget(self.selected_plot_title)
+
+        self.plot_view = _ZoomableImageView()
+        self.plot_view.setMinimumHeight(360)
+        left_layout.addWidget(self.plot_view, 1)
+
+        self.selected_plot_path = QLabel("")
+        self.selected_plot_path.setWordWrap(True)
+        self.selected_plot_path.setStyleSheet("color: #888;")
+        self.selected_plot_path.setTextInteractionFlags(Qt.TextSelectableByMouse)
+        left_layout.addWidget(self.selected_plot_path)
+
+        # Right: scrollable list
+        right_widget = QWidget()
+        right_layout = QVBoxLayout(right_widget)
+        right_layout.setContentsMargins(0, 0, 0, 0)
+        right_layout.addWidget(QLabel("Images:"))
+
+        self.plots_list = QListWidget()
+        self.plots_list.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection)
+        self.plots_list.setIconSize(QSize(160, 160))
+        self.plots_list.itemSelectionChanged.connect(self._on_plot_item_selected)
+        right_layout.addWidget(self.plots_list, 1)
+
+        self.plots_splitter.addWidget(left_widget)
+        self.plots_splitter.addWidget(right_widget)
+        self.plots_splitter.setStretchFactor(0, 3)
+        self.plots_splitter.setStretchFactor(1, 1)
+        plots_layout.addWidget(self.plots_splitter, 1)
+
+        plots_group.setLayout(plots_layout)
+        layout.addWidget(plots_group, 1)
+
+        layout.addStretch(0)
+
+        self._clear_metrics()
+        self._clear_plots()
+
+    # ==================== Public API ====================
+
     def refresh(self):
         """Refresh the tab."""
+        self._load_models()
+        self._populate_model_combo()
+        self._restore_or_select_default_model()
+
+    # ==================== Internal: models ====================
+
+    def _load_models(self) -> None:
+        try:
+            self._models = self.db_manager.get_models() or []
+        except Exception as exc:
+            logger.error("Failed to load models: %s", exc)
+            self._models = []
+
+    def _populate_model_combo(self) -> None:
+        self.model_combo.blockSignals(True)
+        self.model_combo.clear()
+        self.model_combo.addItem("Select a model…", None)
+
+        for model in self._models:
+            model_id = model.get("id")
+            name = (model.get("model_name") or "").strip()
+            version = (model.get("model_version") or "").strip()
+            created_at = model.get("created_at")
+            label = f"{name} {version}".strip()
+            if created_at:
+                label = f"{label} ({created_at})"
+            self.model_combo.addItem(label, model_id)
+
+        self.model_combo.blockSignals(False)
+
+        if self._models:
+            self.header_status.setText(f"Loaded {len(self._models)} model(s).")
+        else:
+            self.header_status.setText("No models found. Train a model first.")
+
+    def _restore_or_select_default_model(self) -> None:
+        if not self._models:
+            self._selected_model_id = None
+            self._clear_metrics()
+            self._clear_plots()
+            return
+
+        # Keep selection if still present.
+        if self._selected_model_id is not None:
+            for idx in range(1, self.model_combo.count()):
+                if self.model_combo.itemData(idx) == self._selected_model_id:
+                    self.model_combo.setCurrentIndex(idx)
+                    return
+
+        # Otherwise select the newest model (top of get_models ORDER BY created_at DESC).
+        first_model_id = self.model_combo.itemData(1) if self.model_combo.count() > 1 else None
+        if first_model_id is not None:
+            self.model_combo.setCurrentIndex(1)
+
+    def _on_model_selected(self, index: int) -> None:
+        model_id = self.model_combo.itemData(index)
+        if not model_id:
+            self._selected_model_id = None
+            self._clear_metrics()
+            self._clear_plots()
+            self.plots_status.setText("Select a model to see validation plots.")
+            return
+
+        self._selected_model_id = int(model_id)
+        model = self._get_model_by_id(self._selected_model_id)
+        if not model:
+            self._clear_metrics()
+            self._clear_plots()
+            self.plots_status.setText("Selected model not found.")
+            return
+
+        self._render_metrics(model)
+        self._render_plots(model)
+
+    def _get_model_by_id(self, model_id: int) -> Optional[Dict[str, Any]]:
+        for model in self._models:
+            if model.get("id") == model_id:
+                return model
+        try:
+            return self.db_manager.get_model_by_id(model_id)
+        except Exception:
+            return None
+
+    # ==================== Internal: metrics ====================
+
+    def _clear_metrics(self) -> None:
+        for label in self.metric_labels.values():
+            label.setText("–")
+        self.per_class_table.setRowCount(0)
+
+    def _render_metrics(self, model: Dict[str, Any]) -> None:
+        self._clear_metrics()
+
+        metrics: Dict[str, Any] = model.get("metrics") or {}
+        # Training tab stores metrics under results['metrics'] in training results payload.
+        if isinstance(metrics, dict) and "metrics" in metrics and isinstance(metrics.get("metrics"), dict):
+            metrics = metrics.get("metrics") or {}
+
+        def set_metric(key: str, value: Any) -> None:
+            if key not in self.metric_labels:
+                return
+            if value is None:
+                self.metric_labels[key].setText("–")
+                return
+            try:
+                self.metric_labels[key].setText(f"{float(value):.4f}")
+            except Exception:
+                self.metric_labels[key].setText(str(value))
+
+        set_metric("mAP50", metrics.get("mAP50"))
+        set_metric("mAP50-95", metrics.get("mAP50-95") or metrics.get("mAP50_95"))
+        set_metric("precision", metrics.get("precision"))
+        set_metric("recall", metrics.get("recall"))
+        set_metric("fitness", metrics.get("fitness"))
+
+        # Optional per-class metrics
+        class_metrics = metrics.get("class_metrics") if isinstance(metrics, dict) else None
+        if isinstance(class_metrics, dict) and class_metrics:
+            items = sorted(class_metrics.items(), key=lambda kv: str(kv[0]))
+            self.per_class_table.setRowCount(len(items))
+            for row, (cls_name, cls_stats) in enumerate(items):
+                ap = (cls_stats or {}).get("ap")
+                ap50 = (cls_stats or {}).get("ap50")
+                self.per_class_table.setItem(row, 0, QTableWidgetItem(str(cls_name)))
+                self.per_class_table.setItem(row, 1, QTableWidgetItem(self._format_float(ap)))
+                self.per_class_table.setItem(row, 2, QTableWidgetItem(self._format_float(ap50)))
+        else:
+            self.per_class_table.setRowCount(0)
+
+    @staticmethod
+    def _format_float(value: Any) -> str:
+        if value is None:
+            return "–"
+        try:
+            return f"{float(value):.4f}"
+        except Exception:
+            return str(value)
+
+    # ==================== Internal: plots ====================
+
+    def _clear_plots(self) -> None:
+        # Remove legacy grid widgets (from the initial implementation).
+        for widget in self._plot_widgets:
+            widget.setParent(None)
+            widget.deleteLater()
+        self._plot_widgets = []
+
+        self._plot_items = []
+
+        if hasattr(self, "plots_list"):
+            self.plots_list.blockSignals(True)
+            self.plots_list.clear()
+            self.plots_list.blockSignals(False)
+
+        if hasattr(self, "plot_view"):
+            self.plot_view.clear()
+        if hasattr(self, "selected_plot_title"):
+            self.selected_plot_title.setText("No image selected.")
+        if hasattr(self, "selected_plot_path"):
+            self.selected_plot_path.setText("")
+
+    def _render_plots(self, model: Dict[str, Any]) -> None:
+        self._clear_plots()
+
+        plot_dirs = self._infer_run_directories(model)
+        plot_items = self._discover_plot_items(plot_dirs)
+
+        if not plot_items:
+            dirs_text = "\n".join(str(p) for p in plot_dirs if p)
+            self.plots_status.setText(
+                "No validation plot images found for this model.\n\n"
+                "Searched directories:\n" + (dirs_text or "(none)")
+            )
+            return
+
+        self._plot_items = list(plot_items)
+        self.plots_status.setText(f"Found {len(plot_items)} plot image(s). Select one to view/zoom.")
+
+        self.plots_list.blockSignals(True)
+        self.plots_list.clear()
+        for idx, item in enumerate(self._plot_items):
+            qitem = QListWidgetItem(item.label)
+            qitem.setData(Qt.ItemDataRole.UserRole, idx)
+
+            pix = QPixmap(str(item.path))
+            if not pix.isNull():
+                thumb = pix.scaled(
+                    self.plots_list.iconSize(),
+                    Qt.AspectRatioMode.KeepAspectRatio,
+                    Qt.TransformationMode.SmoothTransformation,
+                )
+                qitem.setIcon(thumb)
+            self.plots_list.addItem(qitem)
+        self.plots_list.blockSignals(False)
+
+        if self.plots_list.count() > 0:
+            self.plots_list.setCurrentRow(0)
+
+    def _on_plot_item_selected(self) -> None:
+        if not self._plot_items:
+            return
+
+        selected = self.plots_list.selectedItems()
+        if not selected:
+            return
+
+        idx = selected[0].data(Qt.ItemDataRole.UserRole)
+        try:
+            idx_int = int(idx)
+        except Exception:
+            return
+        if idx_int < 0 or idx_int >= len(self._plot_items):
+            return
+
+        plot = self._plot_items[idx_int]
+        self.selected_plot_title.setText(plot.label)
+        self.selected_plot_path.setText(str(plot.path))
+
+        pix = QPixmap(str(plot.path))
+        if pix.isNull():
+            self.plot_view.clear()
+            return
+        self.plot_view.set_pixmap(pix, fit=True)
+
+    def _infer_run_directories(self, model: Dict[str, Any]) -> List[Path]:
+        dirs: List[Path] = []
+
+        # 1) Infer from model_path: .../<run>/weights/best.pt -> <run>
+        model_path = model.get("model_path")
+        if model_path:
+            try:
+                p = Path(str(model_path)).expanduser()
+                if p.name.lower().endswith(".pt"):
+                    # If it lives under weights/, use parent.parent.
+                    if p.parent.name == "weights" and p.parent.parent.exists():
+                        dirs.append(p.parent.parent)
+                    elif p.parent.exists():
+                        dirs.append(p.parent)
+            except Exception:
                 pass

+        # 2) Look at training_params.stage_results[].results.save_dir
+        training_params = model.get("training_params") or {}
+        stage_results = None
+        if isinstance(training_params, dict):
+            stage_results = training_params.get("stage_results")
+        if isinstance(stage_results, list):
+            for stage in stage_results:
+                results = (stage or {}).get("results")
+                save_dir = (results or {}).get("save_dir") if isinstance(results, dict) else None
+                if save_dir:
+                    try:
+                        save_path = Path(str(save_dir)).expanduser()
+                        if save_path.exists():
+                            dirs.append(save_path)
+                    except Exception:
+                        continue
+
+        # Deduplicate while preserving order.
+        unique: List[Path] = []
+        seen: set[str] = set()
+        for d in dirs:
+            try:
+                resolved = str(d.resolve())
+            except Exception:
+                resolved = str(d)
+            if resolved not in seen and d.exists() and d.is_dir():
+                seen.add(resolved)
+                unique.append(d)
+        return unique
+
+    def _discover_plot_items(self, directories: Sequence[Path]) -> List[_PlotItem]:
+        # Prefer canonical Ultralytics filenames first, then fall back to any png/jpg.
+        preferred_names = [
+            "results.png",
+            "results.jpg",
+            "confusion_matrix.png",
+            "confusion_matrix_normalized.png",
+            "labels.jpg",
+            "labels.png",
+            "BoxPR_curve.png",
+            "BoxP_curve.png",
+            "BoxR_curve.png",
|
||||||
|
"BoxF1_curve.png",
|
||||||
|
"MaskPR_curve.png",
|
||||||
|
"MaskP_curve.png",
|
||||||
|
"MaskR_curve.png",
|
||||||
|
"MaskF1_curve.png",
|
||||||
|
"val_batch0_pred.jpg",
|
||||||
|
"val_batch0_labels.jpg",
|
||||||
|
]
|
||||||
|
|
||||||
|
found: List[_PlotItem] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
|
||||||
|
for d in directories:
|
||||||
|
# 1) Preferred
|
||||||
|
for name in preferred_names:
|
||||||
|
p = d / name
|
||||||
|
if p.exists() and p.is_file():
|
||||||
|
key = str(p)
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
found.append(_PlotItem(label=f"{name} (from {d.name})", path=p))
|
||||||
|
|
||||||
|
# 2) Curated globs
|
||||||
|
for pattern in ("train_batch*.jpg", "val_batch*.jpg", "*curve*.png"):
|
||||||
|
for p in sorted(d.glob(pattern)):
|
||||||
|
if not p.is_file():
|
||||||
|
continue
|
||||||
|
key = str(p)
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
found.append(_PlotItem(label=f"{p.name} (from {d.name})", path=p))
|
||||||
|
|
||||||
|
# 3) Fallback: any top-level png/jpg (excluding weights dir contents)
|
||||||
|
for ext in ("*.png", "*.jpg", "*.jpeg", "*.webp"):
|
||||||
|
for p in sorted(d.glob(ext)):
|
||||||
|
if not p.is_file():
|
||||||
|
continue
|
||||||
|
key = str(p)
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
found.append(_PlotItem(label=f"{p.name} (from {d.name})", path=p))
|
||||||
|
|
||||||
|
# Keep list bounded to avoid UI overload for huge runs.
|
||||||
|
return found[:60]
|
||||||
|
|||||||
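Note: the three-pass discovery order above (canonical Ultralytics filenames, then curated globs, then any top-level image) can be summarized in a standalone sketch; the name `discover_plots` is illustrative only, and nothing beyond `pathlib` is assumed:

```python
from pathlib import Path

PREFERRED = ["results.png", "confusion_matrix.png"]  # abbreviated list

def discover_plots(run_dir: Path, limit: int = 60) -> list[Path]:
    """Collect plot images: preferred names, then curated globs, then any image."""
    seen: set[str] = set()
    out: list[Path] = []

    def take(p: Path) -> None:
        if p.is_file() and str(p) not in seen:
            seen.add(str(p))
            out.append(p)

    for name in PREFERRED:
        take(run_dir / name)
    for pattern in ("train_batch*.jpg", "val_batch*.jpg", "*curve*.png"):
        for p in sorted(run_dir.glob(pattern)):
            take(p)
    for ext in ("*.png", "*.jpg", "*.jpeg", "*.webp"):
        for p in sorted(run_dir.glob(ext)):
            take(p)
    return out[:limit]
```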
@@ -18,7 +18,7 @@ from PySide6.QtGui import (
     QPaintEvent,
     QPolygonF,
 )
-from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect
+from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect, QTimer
 from typing import Any, Dict, List, Optional, Tuple

 from src.utils.image import Image, ImageLoadError
@@ -79,9 +79,7 @@ def rdp(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float,
     return [start, end]


-def simplify_polyline(
-    points: List[Tuple[float, float]], epsilon: float
-) -> List[Tuple[float, float]]:
+def simplify_polyline(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float, float]]:
     """
     Simplify a polyline with RDP while preserving closure semantics.
@@ -145,6 +143,10 @@ class AnnotationCanvasWidget(QWidget):
         self.zoom_step = 0.1
         self.zoom_wheel_step = 0.15

+        # Auto-fit behavior (opt-in): when enabled, newly loaded images (and resizes)
+        # will scale to fill the available viewport while preserving aspect ratio.
+        self._auto_fit_to_view: bool = False
+
         # Drawing / interaction state
         self.is_drawing = False
         self.polyline_enabled = False
@@ -175,6 +177,35 @@ class AnnotationCanvasWidget(QWidget):

         self._setup_ui()

+    def set_auto_fit_to_view(self, enabled: bool):
+        """Enable/disable automatic zoom-to-fit behavior."""
+        self._auto_fit_to_view = bool(enabled)
+        if self._auto_fit_to_view and self.original_pixmap is not None:
+            QTimer.singleShot(0, self.fit_to_view)
+
+    def fit_to_view(self, padding_px: int = 6):
+        """Zoom the image so it fits the scroll area's viewport (aspect preserved)."""
+        if self.original_pixmap is None:
+            return
+
+        viewport = self.scroll_area.viewport().size()
+        available_w = max(1, int(viewport.width()) - int(padding_px))
+        available_h = max(1, int(viewport.height()) - int(padding_px))
+
+        img_w = max(1, int(self.original_pixmap.width()))
+        img_h = max(1, int(self.original_pixmap.height()))
+
+        scale_w = available_w / img_w
+        scale_h = available_h / img_h
+        new_scale = min(scale_w, scale_h)
+        new_scale = max(self.zoom_min, min(self.zoom_max, float(new_scale)))
+
+        if abs(new_scale - self.zoom_scale) < 1e-4:
+            return
+
+        self.zoom_scale = new_scale
+        self._apply_zoom()
+
     def _setup_ui(self):
         """Setup user interface."""
         layout = QVBoxLayout()
@@ -187,9 +218,7 @@ class AnnotationCanvasWidget(QWidget):

         self.canvas_label = QLabel("No image loaded")
         self.canvas_label.setAlignment(Qt.AlignCenter)
-        self.canvas_label.setStyleSheet(
-            "QLabel { background-color: #2b2b2b; color: #888; }"
-        )
+        self.canvas_label.setStyleSheet("QLabel { background-color: #2b2b2b; color: #888; }")
         self.canvas_label.setScaledContents(False)
         self.canvas_label.setMouseTracking(True)
@@ -212,9 +241,18 @@ class AnnotationCanvasWidget(QWidget):
         self.zoom_scale = 1.0
         self.clear_annotations()
         self._display_image()
-        logger.debug(
-            f"Loaded image into annotation canvas: {image.width}x{image.height}"
-        )
+        # Defer fit-to-view until the widget has a valid viewport size.
+        if self._auto_fit_to_view:
+            QTimer.singleShot(0, self.fit_to_view)
+
+        logger.debug(f"Loaded image into annotation canvas: {image.width}x{image.height}")
+
+    def resizeEvent(self, event):
+        """Optionally keep the image fitted when the widget is resized."""
+        super().resizeEvent(event)
+        if self._auto_fit_to_view and self.original_pixmap is not None:
+            QTimer.singleShot(0, self.fit_to_view)

     def clear(self):
         """Clear the displayed image and all annotations."""
@@ -250,12 +288,10 @@ class AnnotationCanvasWidget(QWidget):
         # Get image data in a format compatible with Qt
         if self.current_image.channels in (3, 4):
             image_data = self.current_image.get_rgb()
-            height, width = image_data.shape[:2]
         else:
-            image_data = self.current_image.get_grayscale()
-            height, width = image_data.shape
+            image_data = self.current_image.get_qt_rgb()

-        image_data = np.ascontiguousarray(image_data)
+        height, width = image_data.shape[:2]
         bytes_per_line = image_data.strides[0]

         qimage = QImage(
@@ -263,7 +299,7 @@ class AnnotationCanvasWidget(QWidget):
             width,
             height,
             bytes_per_line,
-            self.current_image.qtimage_format,
+            QImage.Format_RGBX32FPx4,  # self.current_image.qtimage_format,
         ).copy()  # Copy so Qt owns the buffer even after numpy array goes out of scope

         self.original_pixmap = QPixmap.fromImage(qimage)
@@ -291,22 +327,14 @@ class AnnotationCanvasWidget(QWidget):
             scaled_width,
             scaled_height,
             Qt.KeepAspectRatio,
-            (
-                Qt.SmoothTransformation
-                if self.zoom_scale >= 1.0
-                else Qt.FastTransformation
-            ),
+            (Qt.SmoothTransformation if self.zoom_scale >= 1.0 else Qt.FastTransformation),
         )

         scaled_annotations = self.annotation_pixmap.scaled(
             scaled_width,
             scaled_height,
             Qt.KeepAspectRatio,
-            (
-                Qt.SmoothTransformation
-                if self.zoom_scale >= 1.0
-                else Qt.FastTransformation
-            ),
+            (Qt.SmoothTransformation if self.zoom_scale >= 1.0 else Qt.FastTransformation),
         )

         # Composite image and annotations
@@ -392,16 +420,11 @@ class AnnotationCanvasWidget(QWidget):
         y = (pos.y() - offset_y) / self.zoom_scale

         # Check bounds
-        if (
-            0 <= x < self.original_pixmap.width()
-            and 0 <= y < self.original_pixmap.height()
-        ):
+        if 0 <= x < self.original_pixmap.width() and 0 <= y < self.original_pixmap.height():
             return (int(x), int(y))
         return None

-    def _find_polyline_at(
-        self, img_x: float, img_y: float, threshold_px: float = 5.0
-    ) -> Optional[int]:
+    def _find_polyline_at(self, img_x: float, img_y: float, threshold_px: float = 5.0) -> Optional[int]:
         """
         Find index of polyline whose geometry is within threshold_px of (img_x, img_y).
         Returns the index in self.polylines, or None if none is close enough.
@@ -423,9 +446,7 @@ class AnnotationCanvasWidget(QWidget):

         # Precise distance to all segments
         for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]):
-            d = perpendicular_distance(
-                (img_x, img_y), (float(x1), float(y1)), (float(x2), float(y2))
-            )
+            d = perpendicular_distance((img_x, img_y), (float(x1), float(y1)), (float(x2), float(y2)))
             if d < best_dist:
                 best_dist = d
                 best_index = idx
@@ -626,11 +647,7 @@ class AnnotationCanvasWidget(QWidget):

     def mouseMoveEvent(self, event: QMouseEvent):
         """Handle mouse move events for drawing."""
-        if (
-            not self.is_drawing
-            or not self.polyline_enabled
-            or self.annotation_pixmap is None
-        ):
+        if not self.is_drawing or not self.polyline_enabled or self.annotation_pixmap is None:
             super().mouseMoveEvent(event)
             return
@@ -690,15 +707,10 @@ class AnnotationCanvasWidget(QWidget):

         if len(simplified) >= 2:
             # Store polyline and redraw all annotations
-            self._add_polyline(
-                simplified, self.polyline_pen_color, self.polyline_pen_width
-            )
+            self._add_polyline(simplified, self.polyline_pen_color, self.polyline_pen_width)

             # Convert to normalized coordinates for metadata + signal
-            normalized_stroke = [
-                self._image_to_normalized_coords(int(x), int(y))
-                for (x, y) in simplified
-            ]
+            normalized_stroke = [self._image_to_normalized_coords(int(x), int(y)) for (x, y) in simplified]
             self.all_strokes.append(
                 {
                     "points": normalized_stroke,
@@ -711,8 +723,7 @@ class AnnotationCanvasWidget(QWidget):
             # Emit signal with normalized coordinates
             self.annotation_drawn.emit(normalized_stroke)
-            logger.debug(
-                f"Completed stroke with {len(simplified)} points "
-                f"(normalized len={len(normalized_stroke)})"
-            )
+            logger.debug(f"Completed stroke with {len(simplified)} points (normalized len={len(normalized_stroke)})")

         self.current_stroke = []
@@ -752,9 +763,7 @@ class AnnotationCanvasWidget(QWidget):

         # Store polyline as [y_norm, x_norm] to match DB convention and
         # the expectations of draw_saved_polyline().
-        normalized_polyline = [
-            [y / img_height, x / img_width] for (x, y) in polyline
-        ]
+        normalized_polyline = [[y / img_height, x / img_width] for (x, y) in polyline]

         logger.debug(
             f"Polyline {idx}: {len(polyline)} points, "
@@ -774,7 +783,7 @@ class AnnotationCanvasWidget(QWidget):
         self,
         polyline: List[List[float]],
         color: str,
-        width: int = 3,
+        width: int = 1,
         annotation_id: Optional[int] = None,
     ):
         """
@@ -812,17 +821,13 @@ class AnnotationCanvasWidget(QWidget):

         # Store and redraw using common pipeline
         pen_color = QColor(color)
-        pen_color.setAlpha(128)  # Add semi-transparency
+        pen_color.setAlpha(255)  # Fully opaque
         self._add_polyline(img_coords, pen_color, width, annotation_id=annotation_id)

         # Store in all_strokes for consistency (uses normalized coordinates)
-        self.all_strokes.append(
-            {"points": polyline, "color": color, "alpha": 128, "width": width}
-        )
+        self.all_strokes.append({"points": polyline, "color": color, "alpha": 255, "width": width})

-        logger.debug(
-            f"Drew saved polyline with {len(polyline)} points in color {color}"
-        )
+        logger.debug(f"Drew saved polyline with {len(polyline)} points in color {color}")

     def draw_saved_bbox(
         self,
@@ -846,9 +851,7 @@ class AnnotationCanvasWidget(QWidget):
             return

         if len(bbox) != 4:
-            logger.warning(
-                f"Invalid bounding box format: expected 4 values, got {len(bbox)}"
-            )
+            logger.warning(f"Invalid bounding box format: expected 4 values, got {len(bbox)}")
             return

         # Convert normalized coordinates to image coordinates (for logging/debug)
@@ -869,15 +872,11 @@ class AnnotationCanvasWidget(QWidget):
         # in _redraw_annotations() together with all polylines.
         pen_color = QColor(color)
         pen_color.setAlpha(128)  # Add semi-transparency
-        self.bboxes.append(
-            [float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)]
-        )
+        self.bboxes.append([float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)])
         self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label})

         # Store in all_strokes for consistency
-        self.all_strokes.append(
-            {"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label}
-        )
+        self.all_strokes.append({"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label})

         # Redraw overlay (polylines + all bounding boxes)
         self._redraw_annotations()
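Note: the `fit_to_view()` method added above reduces to a small pure computation. A minimal sketch with a worked example (the helper name and zoom limits are illustrative, not part of the widget):

```python
def compute_fit_scale(viewport_w: int, viewport_h: int,
                      img_w: int, img_h: int,
                      zoom_min: float = 0.1, zoom_max: float = 8.0,
                      padding_px: int = 6) -> float:
    """Largest zoom that fits the image in the viewport, clamped to the zoom range."""
    avail_w = max(1, viewport_w - padding_px)
    avail_h = max(1, viewport_h - padding_px)
    scale = min(avail_w / max(1, img_w), avail_h / max(1, img_h))
    return max(zoom_min, min(zoom_max, scale))

# An 800x600 viewport showing a 2048x1024 image fits at min(794/2048, 594/1024) ≈ 0.388.
print(compute_fit_scale(800, 600, 2048, 1024))
```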
@@ -1,18 +1,21 @@
-"""
-YOLO model wrapper for the microscopy object detection application.
-Provides a clean interface to YOLOv8 for training, validation, and inference.
+"""YOLO model wrapper for the microscopy object detection application.
+
+Notes on 16-bit TIFF support:
+- Ultralytics training defaults assume 8-bit images and normalize by dividing by 255.
+- This project can patch Ultralytics at runtime to decode TIFFs via `tifffile` and
+  normalize `uint16` correctly.
+
+See [`apply_ultralytics_16bit_tiff_patches()`](src/utils/ultralytics_16bit_patch.py:1).
 """

-from ultralytics import YOLO
 from pathlib import Path
 from typing import Optional, List, Dict, Callable, Any
 import torch
 import tempfile
 import os
-import numpy as np
-from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range
+from src.utils.image import Image
 from src.utils.logger import get_logger
-from src.utils.train_ultralytics_float import train_with_float32_loader
+from src.utils.ultralytics_16bit_patch import apply_ultralytics_16bit_tiff_patches


 logger = get_logger(__name__)
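Note: the module docstring above only references the runtime patch. As a rough sketch of the decode-and-normalize idea (this is not the actual contents of `src/utils/ultralytics_16bit_patch.py`):

```python
# Hypothetical sketch of the patch idea: decode TIFFs with tifffile and
# normalize to float32 [0, 1] instead of treating every image as uint8/255.
import numpy as np
import tifffile

def read_tiff_as_float32_rgb(path: str) -> np.ndarray:
    """Decode a TIFF, normalize to float32 [0, 1], replicate grayscale to RGB."""
    arr = tifffile.imread(path)
    if arr.dtype == np.uint16:
        arr = arr.astype(np.float32) / 65535.0
    elif arr.dtype == np.uint8:
        arr = arr.astype(np.float32) / 255.0
    else:
        arr = np.clip(arr.astype(np.float32), 0.0, 1.0)
    if arr.ndim == 2:  # grayscale -> 3 channels
        arr = np.stack([arr] * 3, axis=-1)
    return np.ascontiguousarray(arr)
```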
@@ -33,6 +36,9 @@ class YOLOWrapper:
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f"YOLOWrapper initialized with device: {self.device}")

+        # Apply Ultralytics runtime patches early (before first import/instantiation of YOLO datasets/trainers).
+        apply_ultralytics_16bit_tiff_patches()
+
     def load_model(self) -> bool:
         """
         Load YOLO model from path.
@@ -42,6 +48,9 @@ class YOLOWrapper:
         """
         try:
             logger.info(f"Loading YOLO model from {self.model_path}")
+            # Import YOLO lazily to ensure runtime patches are applied first.
+            from ultralytics import YOLO
+
             self.model = YOLO(self.model_path)
             self.model.to(self.device)
             logger.info("Model loaded successfully")
@@ -61,11 +70,10 @@ class YOLOWrapper:
         name: str = "custom_model",
         resume: bool = False,
         callbacks: Optional[Dict[str, Callable]] = None,
-        use_float32_loader: bool = True,
         **kwargs,
     ) -> Dict[str, Any]:
         """
-        Train the YOLO model with optional float32 loader for 16-bit TIFFs.
+        Train the YOLO model.

         Args:
             data_yaml: Path to data.yaml configuration file
@@ -77,43 +85,30 @@ class YOLOWrapper:
             name: Name for the training run
             resume: Resume training from last checkpoint
             callbacks: Optional Ultralytics callback dictionary
-            use_float32_loader: Use custom Float32Dataset for 16-bit TIFFs (default: True)
             **kwargs: Additional training arguments

         Returns:
             Dictionary with training results
         """
-        if 1:
-            logger.info(f"Starting training: {name}")
-            logger.info(
-                f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
-            )
-
-            # Check if dataset has 16-bit TIFFs and use float32 loader
-            if use_float32_loader:
-                logger.info("Using Float32Dataset loader for 16-bit TIFF support")
-                return train_with_float32_loader(
-                    model_path=self.model_path,
-                    data_yaml=data_yaml,
-                    epochs=epochs,
-                    imgsz=imgsz,
-                    batch=batch,
-                    patience=patience,
-                    save_dir=save_dir,
-                    name=name,
-                    callbacks=callbacks,
-                    device=self.device,
-                    resume=resume,
-                    **kwargs,
-                )
-        else:
-            # Standard training (old behavior)
-            if self.model is None:
-                if not self.load_model():
-                    raise RuntimeError(
-                        f"Failed to load model from {self.model_path}"
-                    )
+        if self.model is None:
+            if not self.load_model():
+                raise RuntimeError(f"Failed to load model from {self.model_path}")
+
+        try:
+            logger.info(f"Starting training: {name}")
+            logger.info(f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}")
+
+            # Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255.
+            # Users can override by passing explicit kwargs.
+            kwargs.setdefault("mosaic", 0.0)
+            kwargs.setdefault("mixup", 0.0)
+            kwargs.setdefault("cutmix", 0.0)
+            kwargs.setdefault("copy_paste", 0.0)
+            kwargs.setdefault("hsv_h", 0.0)
+            kwargs.setdefault("hsv_s", 0.0)
+            kwargs.setdefault("hsv_v", 0.0)
+
+            # Train the model
             results = self.model.train(
                 data=data_yaml,
                 epochs=epochs,
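Note: because the 16-bit-safe defaults above use `setdefault()`, they apply only when the caller omits the key. A hypothetical usage sketch (the constructor argument is assumed, not confirmed by this diff):

```python
# Hypothetical usage: explicit kwargs win over the 16-bit-safe defaults.
wrapper = YOLOWrapper("data/models/yolov8s-seg.pt")
results = wrapper.train(
    data_yaml="data/datasets/data.yaml",
    epochs=100,
    batch=16,
    imgsz=1024,
    mosaic=1.0,  # re-enables mosaic despite the default of 0.0 set via setdefault()
)
```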
@@ -130,9 +125,9 @@ class YOLOWrapper:
             logger.info("Training completed successfully")
             return self._format_training_results(results)

-        # except Exception as e:
-        #     logger.error(f"Error during training: {e}")
-        #     raise
+        except Exception as e:
+            logger.error(f"Error during training: {e}")
+            raise

     def validate(self, data_yaml: str, split: str = "val", **kwargs) -> Dict[str, Any]:
         """
@@ -152,9 +147,7 @@ class YOLOWrapper:

         try:
             logger.info(f"Starting validation on {split} split")
-            results = self.model.val(
-                data=data_yaml, split=split, device=self.device, **kwargs
-            )
+            results = self.model.val(data=data_yaml, split=split, device=self.device, **kwargs)

             logger.info("Validation completed successfully")
             return self._format_validation_results(results)
@@ -193,17 +186,18 @@ class YOLOWrapper:
             raise RuntimeError(f"Failed to load model from {self.model_path}")

         prepared_source, cleanup_path = self._prepare_source(source)
+        imgsz = 1088
         try:
-            logger.info(f"Running inference on {source}")
+            logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
             results = self.model.predict(
-                source=prepared_source,
+                source=source,
                 conf=conf,
                 iou=iou,
                 save=save,
                 save_txt=save_txt,
                 save_conf=save_conf,
                 device=self.device,
+                imgsz=imgsz,
                 **kwargs,
             )
@@ -215,20 +209,13 @@ class YOLOWrapper:
             logger.error(f"Error during inference: {e}")
             raise
         finally:
-            # Clean up temporary files (only for non-16-bit images)
-            # 16-bit TIFFs return numpy arrays directly, so cleanup_path is None
-            if cleanup_path:
+            if 0:  # cleanup_path:
                 try:
                     os.remove(cleanup_path)
-                    logger.debug(f"Cleaned up temporary file: {cleanup_path}")
                 except OSError as cleanup_error:
-                    logger.warning(
-                        f"Failed to delete temporary file {cleanup_path}: {cleanup_error}"
-                    )
+                    logger.warning(f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}")

-    def export(
-        self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
-    ) -> str:
+    def export(self, format: str = "onnx", output_path: Optional[str] = None, **kwargs) -> str:
         """
         Export model to different format.
@@ -255,14 +242,7 @@ class YOLOWrapper:
             raise

     def _prepare_source(self, source):
-        """Convert single-channel images to RGB for inference.
-
-        For 16-bit TIFF files, this will:
-        1. Load using tifffile
-        2. Normalize to float32 [0-1] (NO uint8 conversion to avoid data loss)
-        3. Replicate grayscale → RGB (3 channels)
-        4. Pass directly as numpy array to YOLO
-        """
+        """Convert single-channel images to RGB temporarily for inference."""
         cleanup_path = None

         if isinstance(source, (str, Path)):
@@ -270,60 +250,13 @@ class YOLOWrapper:
             if source_path.is_file():
                 try:
                     img_obj = Image(source_path)
-
-                    # Check if it's a 16-bit TIFF file
-                    is_16bit_tiff = (
-                        source_path.suffix.lower() in [".tif", ".tiff"]
-                        and img_obj.dtype == np.uint16
-                    )
-
-                    if is_16bit_tiff:
-                        # Process 16-bit TIFF: normalize to float32 [0-1]
-                        # NO uint8 conversion - pass float32 directly to avoid data loss
-                        normalized_float = img_obj.to_normalized_float32()
-
-                        # Convert grayscale to RGB by replicating channels
-                        if len(normalized_float.shape) == 2:
-                            # Grayscale: H,W → H,W,3
-                            rgb_float = np.stack([normalized_float] * 3, axis=-1)
-                        elif (
-                            len(normalized_float.shape) == 3
-                            and normalized_float.shape[2] == 1
-                        ):
-                            # Grayscale with channel dim: H,W,1 → H,W,3
-                            rgb_float = np.repeat(normalized_float, 3, axis=2)
-                        else:
-                            # Already multi-channel
-                            rgb_float = normalized_float
-
-                        # Ensure contiguous array and float32
-                        rgb_float = np.ascontiguousarray(rgb_float, dtype=np.float32)
-
-                        logger.info(
-                            f"Loaded 16-bit TIFF {source_path} as float32 [0-1] RGB "
-                            f"(shape: {rgb_float.shape}, dtype: {rgb_float.dtype}, "
-                            f"range: [{rgb_float.min():.4f}, {rgb_float.max():.4f}])"
-                        )
-
-                        # Return numpy array directly - YOLO can handle it
-                        return rgb_float, cleanup_path
-                    else:
-                        # Standard processing for other images
-                        pil_img = img_obj.pil_image
-                        if len(pil_img.getbands()) == 1:
-                            rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img)
-                        else:
-                            rgb_img = pil_img.convert("RGB")
-
                     suffix = source_path.suffix or ".png"
                     tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
                     tmp_path = tmp.name
                     tmp.close()
-                    rgb_img.save(tmp_path)
+                    img_obj.save(tmp_path)
                     cleanup_path = tmp_path
-                    logger.info(
-                        f"Converted image {source_path} to RGB for inference at {tmp_path}"
-                    )
+                    logger.info(f"Converted image {source_path} to RGB for inference at {tmp_path}")
                     return tmp_path, cleanup_path
                 except Exception as convert_error:
                     logger.warning(
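Note: the temporary-file handling in `_prepare_source()` and the `finally` block above follows the standard `NamedTemporaryFile(delete=False)` pattern. A minimal sketch of that pattern, independent of this codebase:

```python
# Create a named temp file, hand its path to another consumer, remove it afterwards.
import os
import tempfile

tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
tmp_path = tmp.name
tmp.close()  # close the handle so other code (or processes) can open the path
try:
    pass  # consumer reads/writes tmp_path here
finally:
    os.remove(tmp_path)
```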
@@ -336,9 +269,7 @@ class YOLOWrapper:
         """Format training results into dictionary."""
         try:
             # Get the results dict
-            results_dict = (
-                results.results_dict if hasattr(results, "results_dict") else {}
-            )
+            results_dict = results.results_dict if hasattr(results, "results_dict") else {}

             formatted = {
                 "success": True,
@@ -371,9 +302,7 @@ class YOLOWrapper:
                 "mAP50-95": float(box_metrics.map),
                 "precision": float(box_metrics.mp),
                 "recall": float(box_metrics.mr),
-                "fitness": (
-                    float(results.fitness) if hasattr(results, "fitness") else 0.0
-                ),
+                "fitness": (float(results.fitness) if hasattr(results, "fitness") else 0.0),
             }

             # Add per-class metrics if available
@@ -383,11 +312,7 @@ class YOLOWrapper:
                 if idx < len(box_metrics.ap):
                     class_metrics[name] = {
                         "ap": float(box_metrics.ap[idx]),
-                        "ap50": (
-                            float(box_metrics.ap50[idx])
-                            if hasattr(box_metrics, "ap50")
-                            else 0.0
-                        ),
+                        "ap50": (float(box_metrics.ap50[idx]) if hasattr(box_metrics, "ap50") else 0.0),
                     }
             formatted["class_metrics"] = class_metrics
@@ -420,21 +345,15 @@ class YOLOWrapper:
                     "class_id": int(boxes.cls[i]),
                     "class_name": result.names[int(boxes.cls[i])],
                     "confidence": float(boxes.conf[i]),
-                    "bbox_normalized": [
-                        float(v) for v in xyxyn
-                    ],  # [x_min, y_min, x_max, y_max]
-                    "bbox_absolute": [
-                        float(v) for v in boxes.xyxy[i].cpu().numpy()
-                    ],  # Absolute pixels
+                    "bbox_normalized": [float(v) for v in xyxyn],  # [x_min, y_min, x_max, y_max]
+                    "bbox_absolute": [float(v) for v in boxes.xyxy[i].cpu().numpy()],  # Absolute pixels
                 }

                 # Extract segmentation mask if available
                 if has_masks:
                     try:
                         # Get the mask for this detection
-                        mask_data = result.masks.xy[
-                            i
-                        ]  # Polygon coordinates in absolute pixels
+                        mask_data = result.masks.xy[i]  # Polygon coordinates in absolute pixels

                         # Convert to normalized coordinates
                         if len(mask_data) > 0:
@@ -447,9 +366,7 @@ class YOLOWrapper:
                     else:
                         detection["segmentation_mask"] = None
                 except Exception as mask_error:
-                    logger.warning(
-                        f"Error extracting mask for detection {i}: {mask_error}"
-                    )
+                    logger.warning(f"Error extracting mask for detection {i}: {mask_error}")
                     detection["segmentation_mask"] = None
             else:
                 detection["segmentation_mask"] = None
@@ -463,9 +380,7 @@ class YOLOWrapper:
         return []

     @staticmethod
-    def convert_bbox_format(
-        bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
-    ) -> List[float]:
+    def convert_bbox_format(bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy") -> List[float]:
         """
         Convert bounding box between formats.
@@ -54,7 +54,7 @@ class ConfigManager:
                 "models_directory": "data/models",
                 "base_model_choices": [
                     "yolov8s-seg.pt",
-                    "yolov11s-seg.pt",
+                    "yolo11s-seg.pt",
                 ],
             },
             "training": {
@@ -225,6 +225,4 @@ class ConfigManager:

     def get_allowed_extensions(self) -> list:
         """Get list of allowed image file extensions."""
-        return self.get(
-            "image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS
-        )
+        return self.get("image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS)
src/utils/create_mask_from_detection.py (new file, +103 lines)
@@ -0,0 +1,103 @@
import numpy as np

from pathlib import Path
from skimage.draw import polygon
from tifffile import TiffFile

from src.database.db_manager import DatabaseManager


def read_image(image_path: Path) -> tuple[np.ndarray, dict]:
    metadata = {}
    with TiffFile(image_path) as tif:
        image = tif.asarray()
        metadata = tif.imagej_metadata
    return image, metadata


def main():
    polygon_vertices = np.array([[10, 10], [50, 10], [50, 50], [10, 50]])
    image = np.zeros((100, 100), dtype=np.uint8)
    rr, cc = polygon(polygon_vertices[:, 0], polygon_vertices[:, 1])
    image[rr, cc] = 255


if __name__ == "__main__":
    db = DatabaseManager()
    model_name = "c17"
    model_id = db.get_models(filters={"model_name": model_name})[0]["id"]
    print(f"Model name {model_name}, id {model_id}")
    detections = db.get_detections(filters={"model_id": model_id})

    file_stems = set()

    for detection in detections:
        file_stems.add(detection["image_filename"].split("_")[0])

    print("Files:", file_stems)

    for stem in file_stems:
        print(stem)
        detections = db.get_detections(filters={"model_id": model_id, "i.filename": f"LIKE %{stem}%"})
        annotations = []
        for detection in detections:
            source_path = Path(detection["metadata"]["source_path"])
            image, metadata = read_image(source_path)

            offset = np.array(list(map(int, metadata["tile_section"].split(","))))[::-1]
            scale = np.array(list(map(int, metadata["patch_size"].split(","))))[::-1]
            # tile_size = np.array(list(map(int, metadata["tile_size"].split(","))))
            segmentation = np.array(detection["segmentation_mask"])  # * tile_size

            # segmentation = (segmentation + offset * tile_size) / (tile_size * scale)
            segmentation = (segmentation + offset) / scale

            yolo_annotation = f"{detection['metadata']['class_id']} " + " ".join(
                [f"{x:.6f} {y:.6f}" for x, y in segmentation]
            )
            annotations.append(yolo_annotation)

            print(
                " ",
                detection["model_name"],
                detection["image_id"],
                detection["image_filename"],
                source_path,
                metadata["label_path"],
            )

        label_path = metadata["label_path"]
        print(" ", label_path)
        with open(label_path, "w") as f:
            f.write("\n".join(annotations))

    exit()

    for detection in detections:
        print(detection["model_name"], detection["image_id"], detection["image_filename"])

    print(detections[0])

    # polygon_vertices = np.array([[10, 10], [50, 10], [50, 50], [10, 50]])
    # image = np.zeros((100, 100), dtype=np.uint8)
    # rr, cc = polygon(polygon_vertices[:, 0], polygon_vertices[:, 1])
    # image[rr, cc] = 255
    # import matplotlib.pyplot as plt
    # plt.imshow(image, cmap='gray')
    # plt.show()
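Note: the label lines written above follow the YOLO segmentation format, a class id followed by normalized x/y pairs. A worked micro-example (the values are illustrative):

```python
# YOLO segmentation label line for class id 0 and three normalized vertices.
seg = [(0.10, 0.20), (0.50, 0.20), (0.50, 0.60)]
line = "0 " + " ".join(f"{x:.6f} {y:.6f}" for x, y in seg)
print(line)  # -> "0 0.100000 0.200000 0.500000 0.200000 0.500000 0.600000"
```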
@@ -6,17 +6,55 @@ import cv2
 import numpy as np
 from pathlib import Path
 from typing import Optional, Tuple, Union
-from PIL import Image as PILImage
-import tifffile

 from src.utils.logger import get_logger
 from src.utils.file_utils import validate_file_path, is_image_file

 from PySide6.QtGui import QImage

+from tifffile import imread, imwrite
+
 logger = get_logger(__name__)


+def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
+    """
+    Convert a grayscale image to a pseudo-RGB image using a gamma correction.
+
+    Args:
+        arr: Input grayscale image as numpy array
+
+    Returns:
+        Pseudo-RGB image as numpy array
+    """
+    if arr.ndim != 2:
+        raise ValueError("Input array must be a grayscale image with shape (H, W)")
+
+    a1 = arr.copy().astype(np.float32)
+    a1 -= np.percentile(a1, 2)
+    a1[a1 < 0] = 0
+    p999 = np.percentile(a1, 99.9)
+    a1[a1 > p999] = p999
+    a1 /= a1.max()
+
+    a2 = a1.copy()
+    a2 = a2**gamma
+    a2 /= a2.max()
+
+    a3 = a1.copy()
+    p9999 = np.percentile(a3, 99.99)
+    a3[a3 > p9999] = p9999
+    a3 /= a3.max()
+
+    # Channel order is (C, H, W): raw stretch, gamma-corrected, hard-clipped.
+    out = np.stack([a1, a2, a3], axis=0)
+
+    return out
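Note: the stretch in `get_pseudo_rgb()` above is a percentile clip followed by normalization and gamma. A minimal self-contained sketch on synthetic uint16 data (the percentiles and gamma mirror the function's defaults):

```python
import numpy as np

arr = np.random.default_rng(0).gamma(2.0, 2000.0, (64, 64)).astype(np.uint16)
a = arr.astype(np.float32)
a -= np.percentile(a, 2)        # clip the dark tail
a[a < 0] = 0
p = np.percentile(a, 99.9)      # clip bright outliers
a[a > p] = p
a /= a.max()                    # normalize to [0, 1]
g = a ** 0.5                    # gamma 0.5 brightens mid-tones
g /= g.max()
```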
@@ -55,7 +93,6 @@ class Image:
         """
         self.path = Path(image_path)
         self._data: Optional[np.ndarray] = None
-        self._pil_image: Optional[PILImage.Image] = None
         self._width: int = 0
         self._height: int = 0
         self._channels: int = 0
@@ -81,75 +118,34 @@ class Image:
         if not is_image_file(str(self.path), self.SUPPORTED_EXTENSIONS):
             ext = self.path.suffix.lower()
             raise ImageLoadError(
-                f"Unsupported image format: {ext}. "
-                f"Supported formats: {', '.join(self.SUPPORTED_EXTENSIONS)}"
+                f"Unsupported image format: {ext}. " f"Supported formats: {', '.join(self.SUPPORTED_EXTENSIONS)}"
             )

         try:
-            # Check if it's a TIFF file - use tifffile for better support
             if self.path.suffix.lower() in [".tif", ".tiff"]:
-                self._data = tifffile.imread(str(self.path))
-
-                if self._data is None:
-                    raise ImageLoadError(
-                        f"Failed to load TIFF with tifffile: {self.path}"
-                    )
-
-                # Extract metadata
-                self._height, self._width = (
-                    self._data.shape[:2]
-                    if len(self._data.shape) >= 2
-                    else (self._data.shape[0], 1)
-                )
-                self._channels = (
-                    self._data.shape[2] if len(self._data.shape) == 3 else 1
-                )
-                self._format = self.path.suffix.lower().lstrip(".")
-                self._size_bytes = self.path.stat().st_size
-                self._dtype = self._data.dtype
-
-                # Load PIL version for compatibility
-                if self._channels == 1:
-                    # Grayscale
-                    self._pil_image = PILImage.fromarray(self._data)
-                else:
-                    # Multi-channel (RGB or RGBA)
-                    self._pil_image = PILImage.fromarray(self._data)
-
-                logger.info(
-                    f"Successfully loaded TIFF image: {self.path.name} "
-                    f"({self._width}x{self._height}, {self._channels} channels, "
-                    f"dtype={self._dtype}, {self._format.upper()})"
-                )
+                self._data = imread(str(self.path))
             else:
-                # Load with OpenCV (returns BGR format) for non-TIFF images
+                # raise NotImplementedError("RGB is not implemented")
+                # Load with OpenCV (returns BGR format)
                 self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)

-                if self._data is None:
-                    raise ImageLoadError(
-                        f"Failed to load image with OpenCV: {self.path}"
-                    )
+            if self._data is None:
+                raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")

-                # Extract metadata
-                self._height, self._width = self._data.shape[:2]
-                self._channels = (
-                    self._data.shape[2] if len(self._data.shape) == 3 else 1
-                )
-                self._format = self.path.suffix.lower().lstrip(".")
-                self._size_bytes = self.path.stat().st_size
-                self._dtype = self._data.dtype
+            # Extract metadata; data is kept as (H, W) for grayscale or (C, H, W) otherwise.
+            if len(self._data.shape) == 2:
+                self._height, self._width = self._data.shape[:2]
+                self._channels = 1
+            else:
+                self._height, self._width = self._data.shape[1:]
+                self._channels = self._data.shape[0]
+            self._format = self.path.suffix.lower().lstrip(".")
+            self._size_bytes = self.path.stat().st_size
+            self._dtype = self._data.dtype

-                # Load PIL version for compatibility (convert BGR to RGB)
-                if self._channels == 3:
-                    rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
-                    self._pil_image = PILImage.fromarray(rgb_data)
-                elif self._channels == 4:
-                    rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
-                    self._pil_image = PILImage.fromarray(rgba_data)
-                else:
-                    # Grayscale
-                    self._pil_image = PILImage.fromarray(self._data)
-
             logger.info(
                 f"Successfully loaded image: {self.path.name} "
                 f"({self._width}x{self._height}, {self._channels} channels, "
@@ -172,18 +168,6 @@ class Image:
             raise ImageLoadError("Image data not available")
         return self._data

-    @property
-    def pil_image(self) -> PILImage.Image:
-        """
-        Get image data as PIL Image (RGB or grayscale).
-
-        Returns:
-            PIL Image object
-        """
-        if self._pil_image is None:
-            raise ImageLoadError("PIL image not available")
-        return self._pil_image
-
     @property
     def width(self) -> int:
         """Get image width in pixels."""
@@ -228,6 +212,7 @@ class Image:
     @property
     def dtype(self) -> np.dtype:
         """Get the data type of the image array."""
+
         if self._dtype is None:
             raise ImageLoadError("Image dtype not available")
         return self._dtype
@@ -247,8 +232,10 @@ class Image:
         elif self._channels == 1:
             if self._dtype == np.uint16:
                 return QImage.Format_Grayscale16
-            else:
+            elif self._dtype == np.uint8:
                 return QImage.Format_Grayscale8
+            elif self._dtype == np.float32:
+                return QImage.Format_BGR30
         else:
             raise ImageLoadError(f"Unsupported number of channels: {self._channels}")
@@ -259,12 +246,36 @@ class Image:
         Returns:
             Image data in RGB format as numpy array
         """
-        if self._channels == 3:
-            return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
+        if self.channels == 1:
+            img = get_pseudo_rgb(self.data)
+            self._dtype = img.dtype
+            return img, True
+
+        elif self._channels == 3:
+            return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB), False
         elif self._channels == 4:
-            return cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
+            return cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA), False
         else:
-            return self._data
+            raise NotImplementedError
+
+    def get_qt_rgb(self) -> np.ndarray:
+        # Pseudo-RGB data is kept as (C, H, W); repack into H x W x 4 float32 for Qt.
+        _img, pseudo = self.get_rgb()
+
+        if pseudo:
+            img = np.zeros((self.height, self.width, 4), dtype=np.float32)
+            img[..., 0] = _img[0]  # R channel
+            img[..., 1] = _img[1]  # G channel
+            img[..., 2] = _img[2]  # B channel
+            img[..., 3] = 1.0  # A = 1.0 (opaque)
+
+            return np.ascontiguousarray(img)
+        else:
+            return np.ascontiguousarray(_img)

     def get_grayscale(self) -> np.ndarray:
         """
@@ -318,49 +329,26 @@ class Image:
         """
         return self._channels >= 3

-    def to_normalized_float32(self) -> np.ndarray:
-        """
-        Convert image data to normalized float32 in range [0, 1].
-
-        For 16-bit images, this properly scales the full dynamic range.
-        For 8-bit images, divides by 255.
-        Already float images are clipped to [0, 1].
-
-        Returns:
-            Normalized image data as float32 numpy array [0, 1]
-        """
-        data = self._data.astype(np.float32)
-
-        if self._dtype == np.uint16:
-            # 16-bit: normalize by max value (65535)
-            data = data / 65535.0
-        elif self._dtype == np.uint8:
-            # 8-bit: normalize by 255
-            data = data / 255.0
-        elif np.issubdtype(self._dtype, np.floating):
-            # Already float, just clip to [0, 1]
-            data = np.clip(data, 0.0, 1.0)
-        else:
-            # Other integer types: use dtype info
-            if np.issubdtype(self._dtype, np.integer):
-                max_val = np.iinfo(self._dtype).max
-                data = data / float(max_val)
-            else:
-                # Unknown type: attempt min-max normalization
-                min_val = data.min()
-                max_val = data.max()
-                if max_val > min_val:
-                    data = (data - min_val) / (max_val - min_val)
-                else:
-                    data = np.zeros_like(data)
-
-        return np.clip(data, 0.0, 1.0)
+    def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None:
+        if self.channels == 1:
+            if pseudo_rgb:
+                img = get_pseudo_rgb(self.data)
+            else:
+                img = np.repeat(self.data[:, :, None], 3, axis=2)
+        else:
+            raise NotImplementedError("Only grayscale images are supported for now.")
+
+        imwrite(path, data=img)

     def __repr__(self) -> str:
         """String representation of the Image object."""
         return (
             f"Image(path='{self.path.name}', "
-            f"shape=({self._width}x{self._height}x{self._channels}), "
+            # Display as HxWxC to match the conventional NumPy shape semantics.
+            f"shape=({self._height}x{self._width}x{self._channels}), "
             f"format={self._format}, "
             f"size={self.size_mb:.2f}MB)"
         )
@@ -370,38 +358,13 @@ class Image:
         return self.__repr__()


-def convert_grayscale_to_rgb_preserve_range(
-    pil_image: PILImage.Image,
-) -> PILImage.Image:
-    """Convert a single-channel PIL image to RGB while preserving dynamic range.
-
-    Args:
-        pil_image: Single-channel PIL image (e.g., 16-bit grayscale).
-
-    Returns:
-        PIL Image in RGB mode with intensities normalized to 0-255.
-    """
-    if pil_image.mode == "RGB":
-        return pil_image
-
-    grayscale = np.array(pil_image)
-    if grayscale.ndim == 3:
-        grayscale = grayscale[:, :, 0]
-
-    original_dtype = grayscale.dtype
-    grayscale = grayscale.astype(np.float32)
-
-    if grayscale.size == 0:
-        return PILImage.new("RGB", pil_image.size, color=(0, 0, 0))
-
-    if np.issubdtype(original_dtype, np.integer):
-        denom = float(max(np.iinfo(original_dtype).max, 1))
-    else:
-        max_val = float(grayscale.max())
-        denom = max(max_val, 1.0)
-
-    grayscale = np.clip(grayscale / denom, 0.0, 1.0)
-    grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8)
-    rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
-    return PILImage.fromarray(rgb_arr, mode="RGB")
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--path", type=str, required=True)
+    args = parser.parse_args()
+
+    img = Image(args.path)
+    img.save(args.path + "test.tif")
+    print(img)
@@ -12,23 +12,38 @@ class UT:
     Operetta files along with rois drawn in ImageJ
     """

-    def __init__(self, roifile_fn: Path):
+    def __init__(self, roifile_fn: Path, no_labels: bool):
         self.roifile_fn = roifile_fn
-        self.rois = ImagejRoi.fromfile(self.roifile_fn)
-        self.stem = self.roifile_fn.stem.strip("-RoiSet")
+        print("is file", self.roifile_fn.is_file())
+        self.rois = None
+        if no_labels:
+            self.rois = ImagejRoi.fromfile(self.roifile_fn)
+            print(self.roifile_fn.stem)
+            print(self.roifile_fn.parent.parts[-1])
+            if "Roi-" in self.roifile_fn.stem:
+                self.stem = self.roifile_fn.stem.split("Roi-")[1]
+            else:
+                self.stem = self.roifile_fn.parent.parts[-1]
+
+        else:
+            self.roifile_fn = roifile_fn / roifile_fn.parts[-1]
+            self.stem = self.roifile_fn.stem
+
+        print(self.roifile_fn)
+
+        print(self.stem)
         self.image, self.image_props = self._load_images()

     def _load_images(self):
         """Loading sequence of tif files
         array sequence is CZYX
         """
-        print(self.roifile_fn.parent, self.stem)
-        fns = list(self.roifile_fn.parent.glob(f"{self.stem}*.tif*"))
+        print("Loading images:", self.roifile_fn.parent, self.stem)
+        fns = list(self.roifile_fn.parent.glob(f"{self.stem.lower()}*.tif*"))
         stems = [fn.stem.split(self.stem)[-1] for fn in fns]
         n_ch = len(set([stem.split("-ch")[-1].split("t")[0] for stem in stems]))
         n_p = len(set([stem.split("-")[0] for stem in stems]))
         n_t = len(set([stem.split("t")[1] for stem in stems]))
-        print(n_ch, n_p, n_t)

         with TiffFile(fns[0]) as tif:
             img = tif.asarray()
@@ -42,6 +57,7 @@ class UT:
             "height": h,
             "dtype": dtype,
         }
+        print("Image props", self.image_props)

         image_stack = np.zeros((n_ch, n_p, w, h), dtype=dtype)
         for fn in fns:
@@ -49,7 +65,7 @@ class UT:
                 img = tif.asarray()
             stem = fn.stem.split(self.stem)[-1]
             ch = int(stem.split("-ch")[-1].split("t")[0])
-            p = int(stem.split("-")[0].lstrip("p"))
+            p = int(stem.split("-")[0].split("p")[1])
             t = int(stem.split("t")[1])
             print(fn.stem, "ch", ch, "p", p, "t", t)
             image_stack[ch - 1, p - 1] = img
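The index parsing above is easiest to follow on a concrete stem. Assuming a hypothetical per-plane suffix such as `p03-ch2t01` (what remains after the shared image stem is stripped), the three `split` chains recover the plane, channel and timepoint:

```python
stem = "p03-ch2t01"  # hypothetical suffix; real Operetta names may carry more fields

ch = int(stem.split("-ch")[-1].split("t")[0])  # "2t01" -> "2"  -> 2
p = int(stem.split("-")[0].split("p")[1])      # "p03"  -> "03" -> 3
t = int(stem.split("t")[1])                    # "01"   -> 1
print(ch, p, t)  # 2 3 1
```

The `split("p")[1]` form also tolerates characters before the `p` marker, which `lstrip("p")` would leave in place.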
@@ -82,10 +98,19 @@ class UT:
     ):
         """Export rois to a file"""
        with open(path / subfolder / f"{self.stem}.txt", "w") as f:
-            for roi in self.rois:
-                # TODO add image coordinates normalization
-                coords = ""
-                for x, y in roi.subpixel_coordinates:
+            for i, roi in enumerate(self.rois):
+                rc = roi.subpixel_coordinates
+                if rc is None:
+                    print(f"No coordinates: {self.roifile_fn}, element {i}, out of {len(self.rois)}")
+                    continue
+                xmn, ymn = rc.min(axis=0)
+                xmx, ymx = rc.max(axis=0)
+                xc = (xmn + xmx) / 2
+                yc = (ymn + ymx) / 2
+                bw = xmx - xmn
+                bh = ymx - ymn
+                coords = f"{xc/self.width} {yc/self.height} {bw/self.width} {bh/self.height} "
+                for x, y in rc:
                     coords += f"{x/self.width} {y/self.height} "
                 f.write(f"{class_index} {coords}\n")
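Each exported line is now a YOLO-style segmentation record: class index, a normalized bounding box derived from the polygon extremes, then the normalized polygon vertices. A toy example of the layout (hypothetical ROI on a 1000×1000 px image):

```python
import numpy as np

rc = np.array([[100, 200], [300, 200], [300, 400]], dtype=float)  # hypothetical vertices (px)
width = height = 1000

(xmn, ymn), (xmx, ymx) = rc.min(axis=0), rc.max(axis=0)
xc, yc = (xmn + xmx) / 2, (ymn + ymx) / 2
bw, bh = xmx - xmn, ymx - ymn

line = f"0 {xc/width} {yc/height} {bw/width} {bh/height} " + " ".join(
    f"{x/width} {y/height}" for x, y in rc
)
print(line)  # 0 0.2 0.3 0.2 0.2 0.1 0.2 0.3 0.2 0.3 0.4
```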
@@ -104,6 +129,7 @@ class UT:
             self.image = np.max(self.image[channel], axis=0)
             print(self.image.shape)

+        print(path / subfolder / f"{self.stem}.tif")
         with TiffWriter(path / subfolder / f"{self.stem}.tif") as tif:
             tif.write(self.image)
@@ -112,11 +138,31 @@ if __name__ == "__main__":
     import argparse

     parser = argparse.ArgumentParser()
-    parser.add_argument("input", type=Path)
-    parser.add_argument("output", type=Path)
+    parser.add_argument("-i", "--input", nargs="*", type=Path)
+    parser.add_argument("-o", "--output", type=Path)
+    parser.add_argument(
+        "--no-labels",
+        action="store_false",
+        help="Source does not have labels, export only images",
+    )
     args = parser.parse_args()

-    for rfn in args.input.glob("*.zip"):
-        ut = UT(rfn)
-        ut.export_rois(args.output, class_index=0)
-        ut.export_image(args.output, plane_mode="max projection", channel=0)
+    # print(args)
+
+    for path in args.input:
+        print("Path:", path)
+        if not args.no_labels:
+            print("No labels")
+            ut = UT(path, args.no_labels)
+            ut.export_image(args.output, plane_mode="max projection", channel=0)
+
+        else:
+            for rfn in Path(path).glob("*.zip"):
+                # if Path(path).suffix == ".zip":
+                print("Roi FN:", rfn)
+                ut = UT(rfn, args.no_labels)
+                ut.export_rois(args.output, class_index=0)
+                ut.export_image(args.output, plane_mode="max projection", channel=0)

+    print()
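A hedged API-level sketch of the default (labels-present) path, with hypothetical directories; each ROI zip is paired with its tif series, then labels and a max-projection image are written out:

```python
from pathlib import Path

# hypothetical locations; UT is the class defined above
out_dir = Path("data/export")
for rfn in Path("data/raw/plate1").glob("*.zip"):
    ut = UT(rfn, no_labels=True)  # True is the parser's default for --no-labels
    ut.export_rois(out_dir, class_index=0)
    ut.export_image(out_dir, plane_mode="max projection", channel=0)
```

Because `--no-labels` uses `action="store_false"`, `args.no_labels` is True by default and becomes False when the flag is passed, which is why the labelled branch runs in the `else`.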
368	src/utils/image_splitter.py	Normal file
@@ -0,0 +1,368 @@
import numpy as np

from pathlib import Path
from tifffile import imread, imwrite
from shapely.geometry import LineString
from copy import deepcopy
from scipy.ndimage import zoom


# debug
from src.utils.image import Image
from show_yolo_seg import draw_annotations

import pylab as plt
import cv2


class Label:
    def __init__(self, yolo_annotation: str):
        class_id, bbox, polygon = self.parse_yolo_annotation(yolo_annotation)
        self.class_id = class_id
        self.bbox = bbox
        self.polygon = polygon

    def parse_yolo_annotation(self, yolo_annotation: str):
        class_id, *coords = yolo_annotation.split()
        class_id = int(class_id)
        bbox = np.array(coords[:4], dtype=np.float32)
        polygon = np.array(coords[4:], dtype=np.float32).reshape(-1, 2) if len(coords) > 4 else None
        # Close the polygon if its first and last vertices differ.
        if polygon is not None and not np.isclose(polygon[0], polygon[-1]).all():
            polygon = np.vstack([polygon, polygon[0]])
        return class_id, bbox, polygon

    def offset_label(
        self,
        img_w,
        img_h,
        distance: float = 1.0,
        cap_style: int = 2,
        join_style: int = 2,
    ):
        if self.polygon is None:
            # No polygon: just grow the normalized bbox, clamped to [0, 1].
            self.bbox = np.array(
                [
                    self.bbox[0] - distance if self.bbox[0] - distance > 0 else 0,
                    self.bbox[1] - distance if self.bbox[1] - distance > 0 else 0,
                    self.bbox[2] + distance if self.bbox[2] + distance < 1 else 1,
                    self.bbox[3] + distance if self.bbox[3] + distance < 1 else 1,
                ],
                dtype=np.float32,
            )
            return self.bbox

        def coords_are_normalized(coords):
            # If every coordinate is between 0 and 1 (inclusive-ish), assume normalized
            return float(np.max(coords)) <= 1.001

        def poly_to_pts(coords, img_w, img_h):
            # coords: [x1 y1 x2 y2 ...] either normalized or absolute
            coords = np.asarray(coords).flatten()
            # if coords_are_normalized(coords):
            coords = [coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))]
            pts = np.array(coords, dtype=np.int32).reshape(-1, 2)
            return pts

        pts = poly_to_pts(self.polygon, img_w, img_h)
        line = LineString(pts)
        # Buffer distance in pixels
        buffered = line.buffer(distance=distance, cap_style=cap_style, join_style=join_style)
        self.polygon = np.array(buffered.exterior.coords, dtype=np.float32) / (img_w, img_h)
        xmn, ymn = self.polygon.min(axis=0)
        xmx, ymx = self.polygon.max(axis=0)
        xc = (xmn + xmx) / 2
        yc = (ymn + ymx) / 2
        bw = xmx - xmn
        bh = ymx - ymn
        self.bbox = np.array([xc, yc, bw, bh], dtype=np.float32)

        return self.bbox, self.polygon

    def translate(self, x, y, scale_x, scale_y):
        self.bbox[0] -= x
        self.bbox[0] *= scale_x
        self.bbox[1] -= y
        self.bbox[1] *= scale_y
        self.bbox[2] *= scale_x
        self.bbox[3] *= scale_y
        if self.polygon is not None:
            self.polygon[:, 0] -= x
            self.polygon[:, 0] *= scale_x
            self.polygon[:, 1] -= y
            self.polygon[:, 1] *= scale_y

    def in_range(self, hrange, wrange):
        xc, yc, w, h = self.bbox
        x1 = xc - w / 2
        y1 = yc - h / 2
        x2 = xc + w / 2
        y2 = yc + h / 2
        truth_val = (
            xc >= wrange[0]
            and x1 <= wrange[1]
            and x2 >= wrange[0]
            and x2 <= wrange[1]
            and y1 >= hrange[0]
            and y1 <= hrange[1]
            and y2 >= hrange[0]
            and y2 <= hrange[1]
        )

        print(x1, x2, wrange, y1, y2, hrange, truth_val)
        return truth_val

    def to_string(self, bbox: list = None, polygon: list = None):
        coords = ""
        if bbox is None:
            bbox = self.bbox
            # coords += " ".join([f"{x:.6f}" for x in bbox])
        if polygon is None:
            polygon = self.polygon
        if polygon is not None:
            coords += " " + " ".join([f"{x:.6f} {y:.6f}" for x, y in polygon])
        return f"{self.class_id} {coords}"

    def __str__(self):
        return f"Class: {self.class_id}, BBox: {self.bbox}, Polygon: {self.polygon}"


class YoloLabelReader:
    def __init__(self, label_path: Path):
        self.label_path = label_path
        self.labels = self._read_labels()

    def _read_labels(self):
        with open(self.label_path, "r") as f:
            labels = [Label(line) for line in f.readlines() if line.strip()]

        return labels

    def get_labels(self, hrange, wrange):
        """hrange and wrange are tuples of (start, end) normalized to [0, 1]"""
        labels = []
        # print(hrange, wrange)
        for lbl in self.labels:
            # print(lbl)
            if lbl.in_range(hrange, wrange):
                labels.append(lbl)
        return labels if len(labels) > 0 else None

    def __getitem__(self, index):
        return self.labels[index]

    def __len__(self):
        return len(self.labels)

    def __iter__(self):
        return iter(self.labels)


class ImageSplitter:
    def __init__(self, image_path: Path, label_path: Path):
        self.image = imread(image_path)
        self.image_path = image_path
        self.label_path = label_path
        if not label_path.exists():
            print(f"Label file {label_path} not found")
            self.labels = None
        else:
            self.labels = YoloLabelReader(label_path)

    def split_into_tiles(self, patch_size: tuple = (2, 2)):
        """Split image into patches of size patch_size"""
        hstep, wstep = (
            self.image.shape[0] // patch_size[0],
            self.image.shape[1] // patch_size[1],
        )
        h, w = self.image.shape[:2]

        for i in range(patch_size[0]):
            for j in range(patch_size[1]):
                metadata = {
                    "image_path": str(self.image_path),
                    "label_path": str(self.label_path),
                    "tile_section": f"{i}, {j}",
                    "tile_size": f"{hstep}, {wstep}",
                    "patch_size": f"{patch_size[0]}, {patch_size[1]}",
                }
                tile_reference = f"i{i}j{j}"
                hrange = (i * hstep / h, (i + 1) * hstep / h)
                wrange = (j * wstep / w, (j + 1) * wstep / w)
                tile = self.image[i * hstep : (i + 1) * hstep, j * wstep : (j + 1) * wstep]

                labels = None
                if self.labels is not None:
                    labels = deepcopy(self.labels.get_labels(hrange, wrange))
                    print(id(labels))

                if labels is not None:
                    print(hrange[0], wrange[0])
                    for l in labels:
                        print(l.bbox)
                    # Re-express label coordinates in the tile's normalized frame.
                    [l.translate(wrange[0], hrange[0], 2, 2) for l in labels]
                    print("translated")
                    for l in labels:
                        print(l.bbox)

                # print(labels)
                yield tile_reference, tile, labels, metadata

    def split_respective_to_label(self, padding: int = 67):
        if self.labels is None:
            raise ValueError("No labels found. Only images having labels can be split.")

        for i, label in enumerate(self.labels):
            tile_reference = f"_lbl-{i+1:02d}"
            # print(label.bbox)
            metadata = {"image_path": str(self.image_path), "label_path": str(self.label_path), "label_index": str(i)}

            xc_norm, yc_norm, w_norm, h_norm = label.bbox  # normalized coords
            xc, yc, w, h = [
                int(np.round(f))
                for f in [
                    xc_norm * self.image.shape[1],
                    yc_norm * self.image.shape[0],
                    w_norm * self.image.shape[1],
                    h_norm * self.image.shape[0],
                ]
            ]  # image coords

            # print("img coords:", xc, yc, w, h)
            pad_xneg = padding + 1  # int(w / 2) + padding
            pad_xpos = padding  # int(w / 2) + padding
            pad_yneg = padding + 1  # int(h / 2) + padding
            pad_ypos = padding  # int(h / 2) + padding
            if xc - pad_xneg < 0:
                pad_xneg = xc
            if pad_xpos + xc > self.image.shape[1]:
                pad_xpos = self.image.shape[1] - xc
            if yc - pad_yneg < 0:
                pad_yneg = yc
            if pad_ypos + yc > self.image.shape[0]:
                pad_ypos = self.image.shape[0] - yc

            # print("pads:", pad_xneg, pad_xpos, pad_yneg, pad_ypos)

            tile = self.image[
                yc - pad_yneg : yc + pad_ypos,
                xc - pad_xneg : xc + pad_xpos,
            ]
            ny, nx = tile.shape
            x_offset = pad_xneg
            y_offset = pad_yneg

            # print("tile shape:", tile.shape)

            yolo_annotation = f"{label.class_id} "  # {x_offset/nx} {y_offset/ny} {h_norm} {w_norm} "
            yolo_annotation += " ".join(
                [
                    f"{(x*self.image.shape[1]-(xc - x_offset))/nx:.6f} {(y*self.image.shape[0]-(yc-y_offset))/ny:.6f}"
                    for x, y in label.polygon
                ]
            )
            print(yolo_annotation)
            new_label = Label(yolo_annotation=yolo_annotation)

            yield tile_reference, tile, [new_label], metadata


def main(args):

    if args.output:
        args.output.mkdir(exist_ok=True, parents=True)
        (args.output / "images").mkdir(exist_ok=True)
        (args.output / "images-zoomed").mkdir(exist_ok=True)
        (args.output / "labels").mkdir(exist_ok=True)

    for image_path in (args.input / "images").glob("*.tif"):
        data = ImageSplitter(
            image_path=image_path,
            label_path=(args.input / "labels" / image_path.stem).with_suffix(".txt"),
        )

        if args.split_around_label:
            data = data.split_respective_to_label(padding=args.padding)
        else:
            data = data.split_into_tiles(patch_size=args.patch_size)

        for tile_reference, tile, labels, metadata in data:
            print()
            print(tile_reference, tile.shape, labels, metadata)  # len(labels) if labels else None

            # { debug
            debug = False
            if debug:
                plt.figure(figsize=(10, 10 * tile.shape[0] / tile.shape[1]))
                if labels is None:
                    plt.imshow(tile, cmap="gray")
                    plt.axis("off")
                    plt.title(f"{image_path.name} ({tile_reference})")
                    plt.show()
                    continue

                print(labels[0].bbox)
                # Draw annotations
                out = draw_annotations(
                    cv2.cvtColor((tile / tile.max() * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR),
                    [l.to_string() for l in labels],
                    alpha=0.1,
                )

                # Convert BGR -> RGB for matplotlib display
                out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
                plt.imshow(out_rgb)
                plt.axis("off")
                plt.title(f"{image_path.name} ({tile_reference})")
                plt.show()
            # } debug

            if args.output:
                # imwrite(args.output / "images" / f"{image_path.stem}_{tile_reference}.tif", tile, metadata=metadata)
                scale = 5
                tile_zoomed = zoom(tile, zoom=scale)
                metadata["scale"] = scale
                imwrite(
                    args.output / "images" / f"{image_path.stem}_{tile_reference}.tif",
                    tile_zoomed,
                    metadata=metadata,
                    imagej=True,
                )

                if labels is not None:
                    with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f:
                        for label in labels:
                            # label.offset_label(tile.shape[1], tile.shape[0])
                            f.write(label.to_string() + "\n")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", type=Path)
    parser.add_argument("-o", "--output", type=Path)
    parser.add_argument(
        "-p",
        "--patch-size",
        nargs=2,
        type=int,
        default=[2, 2],
        help="Number of patches along height and width, rows and columns, respectively",
    )
    parser.add_argument(
        "-sal",
        "--split-around-label",
        action="store_true",
        help="If enabled, the image will be split around the label and for each label, a separate image will be created.",
    )
    parser.add_argument(
        "--padding",
        type=int,
        default=67,
        help="Padding around the label when splitting around the label.",
    )
    args = parser.parse_args()

    main(args)
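A minimal sketch of driving the splitter from Python rather than the CLI (paths hypothetical); each yielded tuple is a tile plus its labels re-expressed in the tile's normalized frame:

```python
from pathlib import Path

splitter = ImageSplitter(
    image_path=Path("data/datasets/images/sample.tif"),  # hypothetical paths
    label_path=Path("data/datasets/labels/sample.txt"),
)

for ref, tile, labels, meta in splitter.split_into_tiles(patch_size=(2, 2)):
    print(ref, tile.shape, 0 if labels is None else len(labels))
```

Note that `translate(wrange[0], hrange[0], 2, 2)` hard-codes the 2× rescale implied by a 2×2 grid; other `patch_size` values would need matching scale factors.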
1	src/utils/show_yolo_seg.py	Symbolic link
@@ -0,0 +1 @@
../../tests/show_yolo_seg.py
@@ -1,561 +0,0 @@
"""
Custom YOLO training with on-the-fly float32 conversion for 16-bit grayscale images.

This module provides a custom dataset class and training function that:
1. Load 16-bit TIFF images directly with tifffile (no PIL/cv2)
2. Convert to float32 [0-1] on-the-fly (no data loss)
3. Replicate grayscale to 3-channel RGB in memory
4. Use custom training loop to bypass Ultralytics' dataset infrastructure
5. No disk caching required
"""

import numpy as np
import tifffile
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from typing import Optional, Dict, Any, List, Tuple
from ultralytics import YOLO
import yaml
import time

from src.utils.logger import get_logger

logger = get_logger(__name__)


class Float32YOLODataset(Dataset):
    """
    Custom PyTorch dataset for YOLO that loads 16-bit grayscale TIFFs as float32 RGB.

    This dataset:
    - Loads with tifffile (not PIL/cv2)
    - Converts uint16 → float32 [0-1] (preserves full dynamic range)
    - Replicates grayscale to 3 channels
    - Returns torch tensors in (C, H, W) format
    """

    def __init__(self, images_dir: str, labels_dir: str, img_size: int = 640):
        """
        Initialize dataset.

        Args:
            images_dir: Directory containing images
            labels_dir: Directory containing YOLO label files (.txt)
            img_size: Target image size (for reference, actual resizing done by model)
        """
        self.images_dir = Path(images_dir)
        self.labels_dir = Path(labels_dir)
        self.img_size = img_size

        # Find all image files
        extensions = {".tif", ".tiff", ".png", ".jpg", ".jpeg", ".bmp"}
        self.image_paths = sorted(
            [
                p
                for p in self.images_dir.rglob("*")
                if p.is_file() and p.suffix.lower() in extensions
            ]
        )

        if not self.image_paths:
            raise ValueError(f"No images found in {images_dir}")

        logger.info(
            f"Float32YOLODataset initialized with {len(self.image_paths)} images from {images_dir}"
        )

    def __len__(self):
        return len(self.image_paths)

    def _read_image(self, img_path: Path) -> np.ndarray:
        """
        Read image and convert to float32 [0-1] RGB.

        Returns:
            numpy array, shape (H, W, 3), dtype float32, range [0, 1]
        """
        # Load image with tifffile
        img = tifffile.imread(str(img_path))

        # Convert to float32
        img = img.astype(np.float32)

        # Normalize if 16-bit (values > 1.5 indicates uint16)
        if img.max() > 1.5:
            img = img / 65535.0

        # Ensure [0, 1] range
        img = np.clip(img, 0.0, 1.0)

        # Convert grayscale to RGB
        if img.ndim == 2:
            # H,W → H,W,3
            img = np.repeat(img[..., None], 3, axis=2)
        elif img.ndim == 3 and img.shape[2] == 1:
            # H,W,1 → H,W,3
            img = np.repeat(img, 3, axis=2)

        return img  # float32 (H, W, 3) in [0, 1]

    def _parse_label(self, label_path: Path) -> List[np.ndarray]:
        """
        Parse YOLO label file with variable-length rows (segmentation polygons).

        Returns:
            List of numpy arrays, one per annotation
        """
        if not label_path.exists():
            return []

        labels = []
        try:
            with open(label_path, "r") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    # Parse space-separated values
                    values = line.split()
                    if len(values) >= 5:  # At minimum: class_id x y w h
                        labels.append(
                            np.array([float(v) for v in values], dtype=np.float32)
                        )
        except Exception as e:
            logger.warning(f"Error parsing label {label_path}: {e}")
            return []

        return labels

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, List[np.ndarray], str]:
        """
        Get a single training sample.

        Returns:
            Tuple of (image_tensor, labels, filename)
            - image_tensor: shape (3, H, W), dtype float32, range [0, 1]
            - labels: list of numpy arrays with YOLO format labels (variable length for segmentation)
            - filename: image filename
        """
        img_path = self.image_paths[idx]
        label_path = self.labels_dir / f"{img_path.stem}.txt"

        # Load image as float32 RGB
        img = self._read_image(img_path)

        # Convert to tensor: (H, W, 3) → (3, H, W)
        img_tensor = torch.from_numpy(img).permute(2, 0, 1).contiguous()

        # Load labels (list of variable-length arrays for segmentation)
        labels = self._parse_label(label_path)

        return img_tensor, labels, img_path.name


def collate_fn(
    batch: List[Tuple[torch.Tensor, List[np.ndarray], str]],
) -> Tuple[torch.Tensor, List[List[np.ndarray]], List[str]]:
    """
    Collate function for DataLoader.

    Args:
        batch: List of (img_tensor, labels_list, filename) tuples
            where labels_list is a list of variable-length numpy arrays

    Returns:
        Tuple of (stacked_images, list_of_labels_lists, list_of_filenames)
    """
    imgs = [b[0] for b in batch]
    labels = [b[1] for b in batch]  # Each element is a list of arrays
    names = [b[2] for b in batch]

    # Stack images - requires same H,W
    # For different sizes, implement letterbox/resize in dataset
    imgs_batch = torch.stack(imgs, dim=0)

    return imgs_batch, labels, names


def train_with_float32_loader(
    model_path: str,
    data_yaml: str,
    epochs: int = 100,
    imgsz: int = 640,
    batch: int = 16,
    patience: int = 50,
    save_dir: str = "data/models",
    name: str = "custom_model",
    callbacks: Optional[Dict] = None,
    **kwargs,
) -> Dict[str, Any]:
    """
    Train YOLO model with custom Float32 dataset for 16-bit TIFF support.

    Uses a custom training loop to bypass Ultralytics' dataset pipeline,
    avoiding channel conversion issues.

    Args:
        model_path: Path to base model weights (.pt file)
        data_yaml: Path to dataset YAML configuration
        epochs: Number of training epochs
        imgsz: Input image size
        batch: Batch size
        patience: Early stopping patience
        save_dir: Directory to save trained model
        name: Name for the training run
        callbacks: Optional callback dictionary (for progress reporting)
        **kwargs: Additional training arguments (lr0, freeze, device, etc.)

    Returns:
        Dict with training results including model paths and metrics
    """
    try:
        logger.info(f"Starting Float32 custom training: {name}")
        logger.info(
            f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
        )

        # Parse data.yaml to get dataset paths
        with open(data_yaml, "r") as f:
            data_config = yaml.safe_load(f)

        dataset_root = Path(data_config.get("path", Path(data_yaml).parent))
        train_images = dataset_root / data_config.get("train", "train/images")
        val_images = dataset_root / data_config.get("val", "val/images")

        # Infer label directories
        train_labels = train_images.parent / "labels"
        val_labels = val_images.parent / "labels"

        logger.info(f"Train images: {train_images}")
        logger.info(f"Train labels: {train_labels}")
        logger.info(f"Val images: {val_images}")
        logger.info(f"Val labels: {val_labels}")

        # Create datasets
        train_dataset = Float32YOLODataset(
            str(train_images), str(train_labels), img_size=imgsz
        )
        val_dataset = Float32YOLODataset(
            str(val_images), str(val_labels), img_size=imgsz
        )

        # Create data loaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            collate_fn=collate_fn,
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=batch,
            shuffle=False,
            num_workers=2,
            pin_memory=True,
            collate_fn=collate_fn,
        )

        # Load model
        logger.info(f"Loading model from {model_path}")
        ul_model = YOLO(model_path)

        # Get PyTorch model
        pt_model, loss_fn = _get_pytorch_model(ul_model)

        # Setup device
        device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")

        # Configure model args for loss function
        from types import SimpleNamespace

        # Required args for segmentation loss
        required_args = {
            "overlap_mask": True,
            "mask_ratio": 4,
            "task": "segment",
            "single_cls": False,
            "box": 7.5,
            "cls": 0.5,
            "dfl": 1.5,
        }

        if not hasattr(pt_model, "args"):
            # No args - create SimpleNamespace
            pt_model.args = SimpleNamespace(**required_args)
        elif isinstance(pt_model.args, dict):
            # Args is dict - MUST convert to SimpleNamespace for attribute access
            # The loss function uses model.args.overlap_mask (attribute access)
            merged = {**pt_model.args, **required_args}
            pt_model.args = SimpleNamespace(**merged)
            logger.info(
                "Converted model.args from dict to SimpleNamespace for loss function compatibility"
            )
        else:
            # Args is SimpleNamespace or other - set attributes
            for key, value in required_args.items():
                if not hasattr(pt_model.args, key):
                    setattr(pt_model.args, key, value)

        pt_model.to(device)
        pt_model.train()

        logger.info(f"Training on device: {device}")
        logger.info(f"PyTorch model type: {type(pt_model)}")
        logger.info("Model args configured for segmentation loss")

        # Setup optimizer
        lr0 = kwargs.get("lr0", 0.01)
        optimizer = torch.optim.AdamW(pt_model.parameters(), lr=lr0)

        # Training loop
        save_path = Path(save_dir) / name
        save_path.mkdir(parents=True, exist_ok=True)
        weights_dir = save_path / "weights"
        weights_dir.mkdir(exist_ok=True)

        best_loss = float("inf")
        patience_counter = 0

        for epoch in range(epochs):
            epoch_start = time.time()
            running_loss = 0.0
            num_batches = 0

            logger.info(f"Epoch {epoch+1}/{epochs} starting...")

            for batch_idx, (imgs, labels_list, names) in enumerate(train_loader):
                imgs = imgs.to(device)  # (B, 3, H, W) float32

                optimizer.zero_grad()

                # Forward pass
                try:
                    preds = pt_model(imgs)
                except Exception as e:
                    # Try with labels
                    preds = pt_model(imgs, labels_list)

                # Compute loss
                # For Ultralytics models, the easiest approach is to construct a batch dict
                # and call the model in training mode which returns preds + loss
                batch_dict = {
                    "img": imgs,  # Already on device
                    "batch_idx": (
                        torch.cat(
                            [
                                torch.full((len(lab),), i, dtype=torch.long)
                                for i, lab in enumerate(labels_list)
                            ]
                        ).to(device)
                        if any(len(lab) > 0 for lab in labels_list)
                        else torch.tensor([], dtype=torch.long, device=device)
                    ),
                    "cls": (
                        torch.cat(
                            [
                                torch.from_numpy(lab[:, 0:1])
                                for lab in labels_list
                                if len(lab) > 0
                            ]
                        ).to(device)
                        if any(len(lab) > 0 for lab in labels_list)
                        else torch.tensor([], dtype=torch.float32, device=device)
                    ),
                    "bboxes": (
                        torch.cat(
                            [
                                torch.from_numpy(lab[:, 1:5])
                                for lab in labels_list
                                if len(lab) > 0
                            ]
                        ).to(device)
                        if any(len(lab) > 0 for lab in labels_list)
                        else torch.tensor([], dtype=torch.float32, device=device)
                    ),
                    "ori_shape": (imgs.shape[2], imgs.shape[3]),  # H, W
                    "resized_shape": (imgs.shape[2], imgs.shape[3]),
                }

                # Add masks if segmentation labels exist
                if any(len(lab) > 5 for lab in labels_list if len(lab) > 0):
                    masks = []
                    for lab in labels_list:
                        if len(lab) > 0 and lab.shape[1] > 5:
                            # Has segmentation points
                            masks.append(torch.from_numpy(lab[:, 5:]))
                    if masks:
                        batch_dict["masks"] = masks

                # Call model loss (it will compute loss internally)
                try:
                    loss_output = pt_model.loss(batch_dict, preds)
                    if isinstance(loss_output, (tuple, list)):
                        loss = loss_output[0]
                    else:
                        loss = loss_output
                except Exception as e:
                    logger.error(f"Model loss computation failed: {e}")
                    # Last resort: maybe preds is already a dict with 'loss'
                    if isinstance(preds, dict) and "loss" in preds:
                        loss = preds["loss"]
                    else:
                        raise RuntimeError(f"Cannot compute loss: {e}")

                # Backward pass
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                num_batches += 1

                # Report progress via callback
                if callbacks and "on_fit_epoch_end" in callbacks:
                    # Create a mock trainer object for callback
                    class MockTrainer:
                        def __init__(self, epoch):
                            self.epoch = epoch
                            self.loss_items = [loss.item()]

                    callbacks["on_fit_epoch_end"](MockTrainer(epoch))

            epoch_loss = running_loss / max(1, num_batches)
            epoch_time = time.time() - epoch_start

            logger.info(
                f"Epoch {epoch+1}/{epochs} completed. Avg Loss: {epoch_loss:.4f}, Time: {epoch_time:.1f}s"
            )

            # Save checkpoint
            ckpt_path = weights_dir / f"epoch{epoch+1}.pt"
            torch.save(
                {
                    "epoch": epoch + 1,
                    "model_state_dict": pt_model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": epoch_loss,
                },
                ckpt_path,
            )

            # Save as last.pt
            last_path = weights_dir / "last.pt"
            torch.save(pt_model.state_dict(), last_path)

            # Check for best model
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                patience_counter = 0
                best_path = weights_dir / "best.pt"
                torch.save(pt_model.state_dict(), best_path)
                logger.info(f"New best model saved: {best_path}")
            else:
                patience_counter += 1

            # Early stopping
            if patience_counter >= patience:
                logger.info(f"Early stopping triggered after {epoch+1} epochs")
                break

        logger.info("Training completed successfully")

        # Format results
        return {
            "success": True,
            "final_epoch": epoch + 1,
            "metrics": {
                "final_loss": epoch_loss,
                "best_loss": best_loss,
            },
            "best_model_path": str(weights_dir / "best.pt"),
            "last_model_path": str(weights_dir / "last.pt"),
            "save_dir": str(save_path),
        }

    except Exception as e:
        logger.error(f"Error during Float32 training: {e}")
        import traceback

        logger.error(traceback.format_exc())
        raise


def _get_pytorch_model(ul_model: YOLO) -> Tuple[torch.nn.Module, Optional[callable]]:
    """
    Extract PyTorch model and loss function from Ultralytics YOLO wrapper.

    Args:
        ul_model: Ultralytics YOLO model wrapper

    Returns:
        Tuple of (pytorch_model, loss_function)
    """
    # Try to get the underlying PyTorch model
    candidates = []

    # Direct model attribute
    if hasattr(ul_model, "model"):
        candidates.append(ul_model.model)

    # Sometimes nested
    if hasattr(ul_model, "model") and hasattr(ul_model.model, "model"):
        candidates.append(ul_model.model.model)

    # The wrapper itself
    if isinstance(ul_model, torch.nn.Module):
        candidates.append(ul_model)

    # Find a valid model
    pt_model = None
    loss_fn = None

    for candidate in candidates:
        if candidate is None or not isinstance(candidate, torch.nn.Module):
            continue

        pt_model = candidate

        # Try to find loss function
        if hasattr(candidate, "loss") and callable(getattr(candidate, "loss")):
            loss_fn = getattr(candidate, "loss")
        elif hasattr(candidate, "compute_loss") and callable(
            getattr(candidate, "compute_loss")
        ):
            loss_fn = getattr(candidate, "compute_loss")

        break

    if pt_model is None:
        raise RuntimeError("Could not extract PyTorch model from Ultralytics wrapper")

    logger.info(f"Extracted PyTorch model: {type(pt_model)}")
    logger.info(
        f"Loss function: {type(loss_fn) if loss_fn else 'None (will attempt fallback)'}"
    )

    return pt_model, loss_fn


# Compatibility function (kept for backwards compatibility)
def train_float32(model: YOLO, data_yaml: str, **train_kwargs) -> Any:
    """
    Train YOLO model with Float32YOLODataset (alternative API).

    Args:
        model: Initialized YOLO model instance
        data_yaml: Path to dataset YAML
        **train_kwargs: Training parameters

    Returns:
        Training results dict
    """
    return train_with_float32_loader(
        model_path=(
            model.model_path if hasattr(model, "model_path") else "yolov8s-seg.pt"
        ),
        data_yaml=data_yaml,
        **train_kwargs,
    )
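For reference, the removed entry point took a dataset YAML plus standard hyperparameters and returned a results dict; a minimal call (paths hypothetical) looked like:

```python
results = train_with_float32_loader(
    model_path="yolov8s-seg.pt",
    data_yaml="data/datasets/data.yaml",  # hypothetical dataset config
    epochs=100,
    imgsz=640,
    batch=16,
)
print(results["best_model_path"])
```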
157	src/utils/ultralytics_16bit_patch.py	Normal file
@@ -0,0 +1,157 @@
"""Ultralytics runtime patches for 16-bit TIFF training.

Goals:
- Use `tifffile` to decode `.tif/.tiff` reliably (OpenCV can silently drop bit-depth depending on codec).
- Preserve 16-bit data through the dataloader as `uint16` tensors.
- Fix Ultralytics trainer normalization (default divides by 255) to scale `uint16` correctly.
- Avoid uint8-forcing augmentations by recommending/setting hyp values (handled by caller).

This module is intended to be imported/called **before** instantiating/using YOLO.
"""

from __future__ import annotations

from typing import Optional

from src.utils.logger import get_logger

logger = get_logger(__name__)


def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
    """Apply runtime monkey-patches to Ultralytics to better support 16-bit TIFFs.

    This function is safe to call multiple times.

    Args:
        force: If True, re-apply patches even if already applied.
    """

    # Import inside function to ensure patching occurs before YOLO model/dataset is created.
    import os

    import cv2
    import numpy as np

    # import tifffile
    import torch
    from src.utils.image import Image

    from ultralytics.utils import patches as ul_patches

    already_patched = getattr(ul_patches.imread, "__name__", "") == "tifffile_imread"
    if already_patched and not force:
        return

    _original_imread = ul_patches.imread

    def tifffile_imread(filename: str, flags: int = cv2.IMREAD_COLOR, pseudo_rgb: bool = True) -> Optional[np.ndarray]:
        """Replacement for [`ultralytics.utils.patches.imread()`](venv/lib/python3.12/site-packages/ultralytics/utils/patches.py:20).

        - For `.tif/.tiff`, uses `tifffile.imread()` and preserves dtype (e.g. uint16).
        - For other formats, falls back to Ultralytics' original implementation.
        - Always returns HWC (3 dims). For grayscale, returns (H, W, 1) or (H, W, 3) depending on requested flags.
        """
        # print("here")
        # return _original_imread(filename, flags)
        ext = os.path.splitext(filename)[1].lower()
        if ext in (".tif", ".tiff"):
            arr = Image(filename).get_qt_rgb()[:, :, :3]

            # Normalize common shapes:
            # - (H, W) -> (H, W, 1)
            # - (C, H, W) -> (H, W, C) (heuristic)
            if arr is None:
                return None
            if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[0] < arr.shape[1]:
                arr = np.transpose(arr, (1, 2, 0))
            if arr.ndim == 2:
                arr = arr[..., None]

            # Ensure contiguous array for downstream OpenCV ops.
            # logger.info(f"Loading with monkey-patched imread: {filename}")
            arr = arr.astype(np.float32)
            arr /= arr.max()
            arr *= 2**8 - 1
            arr = arr.astype(np.uint8)
            # print(arr.shape, arr.dtype, any(np.isnan(arr).flatten()), np.where(np.isnan(arr)), arr.min(), arr.max())
            return np.ascontiguousarray(arr)

        # logger.info(f"Loading with original imread: {filename}")
        return _original_imread(filename, flags)

    # Patch the canonical reference.
    ul_patches.imread = tifffile_imread

    # Patch common module-level imports (some Ultralytics modules do `from ... import imread`).
    # Importing these modules is safe and helps ensure the patched function is used.
    try:
        import ultralytics.data.base as _ul_base

        _ul_base.imread = tifffile_imread
    except Exception:
        pass
    try:
        import ultralytics.data.loaders as _ul_loaders

        _ul_loaders.imread = tifffile_imread
    except Exception:
        pass

    # Patch trainer normalization: default divides by 255 regardless of input dtype.
    from ultralytics.models.yolo.detect import train as detect_train

    _orig_preprocess_batch = detect_train.DetectionTrainer.preprocess_batch

    def preprocess_batch_16bit(self, batch: dict) -> dict:  # type: ignore[override]
        # Start from upstream behavior to keep device placement + multiscale identical,
        # but replace the 255 division with dtype-aware scaling.
        # logger.info(f"Preprocessing batch with monkey-patched preprocess_batch")
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")

        img = batch.get("img")
        if isinstance(img, torch.Tensor):
            # Decide scaling denom based on dtype (avoid expensive reductions if possible).
            if img.dtype == torch.uint8:
                denom = 255.0
            elif img.dtype == torch.uint16:
                denom = 65535.0
            elif img.dtype.is_floating_point:
                # Assume already in 0-1 range if float.
                denom = 1.0
            else:
                # Generic integer fallback.
                try:
                    denom = float(torch.iinfo(img.dtype).max)
                except Exception:
                    denom = 255.0

            batch["img"] = img.float() / denom

        # Multi-scale branch copied from upstream to avoid re-introducing `/255` scaling.
        if getattr(self.args, "multi_scale", False):
            import math
            import random

            import torch.nn as nn

            imgs = batch["img"]
            sz = (
                random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
                // self.stride
                * self.stride
            )
            sf = sz / max(imgs.shape[2:])
            if sf != 1:
                ns = [math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]]
                imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
            batch["img"] = imgs

        return batch

    detect_train.DetectionTrainer.preprocess_batch = preprocess_batch_16bit

    # Tag function to make it easier to detect patch state.
    setattr(detect_train.DetectionTrainer.preprocess_batch, "_ultralytics_16bit_patch", True)
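Since the patches must be in place before Ultralytics builds its dataloaders and trainer, the intended call order is patch first, then construct and train; a minimal sketch (dataset path and hyperparameters hypothetical):

```python
from ultralytics import YOLO

from src.utils.ultralytics_16bit_patch import apply_ultralytics_16bit_tiff_patches

apply_ultralytics_16bit_tiff_patches()  # idempotent; pass force=True to re-apply
model = YOLO("yolov8s-seg.pt")
model.train(data="data/datasets/data.yaml", imgsz=640, epochs=10)
```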
@@ -17,6 +17,9 @@ import matplotlib.pyplot as plt
 import argparse
 from pathlib import Path
 import random
+from shapely.geometry import LineString
+
+from src.utils.image import Image


 def parse_label_line(line):
@@ -52,36 +55,55 @@ def yolo_bbox_to_xyxy(coords, img_w, img_h):

 def poly_to_pts(coords, img_w, img_h):
     # coords: [x1 y1 x2 y2 ...] either normalized or absolute
-    if coords_are_normalized(coords):
-        coords = [
-            coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))
-        ]
+    if coords_are_normalized(coords[4:]):
+        coords = [coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))]
     pts = np.array(coords, dtype=np.int32).reshape(-1, 2)
     return pts


 def random_color_for_class(cls):
     random.seed(cls)  # deterministic per class
-    return tuple(int(x) for x in np.array([random.randint(0, 255) for _ in range(3)]))
+    return (
+        0,
+        0,
+        255,
+    )  # tuple(int(x) for x in np.array([random.randint(0, 255) for _ in range(3)]))


 def draw_annotations(img, labels, alpha=0.4, draw_bbox_for_poly=True):
     # img: BGR numpy array
     overlay = img.copy()
     h, w = img.shape[:2]
-    for cls, coords in labels:
+    for line in labels:
+        if isinstance(line, str):
+            cls, coords = parse_label_line(line)
+        if isinstance(line, tuple):
+            cls, coords = line
+
         if not coords:
             continue
         # polygon case (>=6 coordinates)
         if len(coords) >= 6:
-            pts = poly_to_pts(coords, w, h)
             color = random_color_for_class(cls)
+
+            x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords[:4], w, h)
+            print(x1, y1, x2, y2)
+            cv2.rectangle(img, (x1, y1), (x2, y2), color, 1)
+
+            pts = poly_to_pts(coords[4:], w, h)
+            # line = LineString(pts)
+            # # Buffer distance in pixels
+            # buffered = line.buffer(3, cap_style=2, join_style=2)
+            # coords = np.array(buffered.exterior.coords, dtype=np.int32)
+            # cv2.fillPoly(overlay, [coords], color=(255, 255, 255))
+
             # fill on overlay
             cv2.fillPoly(overlay, [pts], color)
             # outline on base image
-            cv2.polylines(img, [pts], isClosed=True, color=color, thickness=2)
+            cv2.polylines(img, [pts], isClosed=True, color=color, thickness=1)
             # put class text at first point
             x, y = int(pts[0, 0]), int(pts[0, 1]) - 6
-            cv2.putText(
-                img,
-                str(cls),
+            if 0:
+                cv2.putText(
+                    img,
+                    str(cls),
@@ -92,9 +114,7 @@ def draw_annotations(img, labels, alpha=0.4, draw_bbox_for_poly=True):
-                2,
-                cv2.LINE_AA,
-            )
+                    2,
+                    cv2.LINE_AA,
+                )
-        if draw_bbox_for_poly:
-            x, y, w_box, h_box = cv2.boundingRect(pts)
-            cv2.rectangle(img, (x, y), (x + w_box, y + h_box), color, 1)
         # YOLO bbox case (4 coords)
         elif len(coords) == 4:
             x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords, w, h)
@@ -133,21 +153,21 @@ def load_labels_file(label_path):


 def main():
-    parser = argparse.ArgumentParser(
-        description="Show YOLO segmentation / polygon annotations"
-    )
+    parser = argparse.ArgumentParser(description="Show YOLO segmentation / polygon annotations")
     parser.add_argument("image", type=str, help="Path to image file")
-    parser.add_argument("labels", type=str, help="Path to YOLO label file (polygons)")
-    parser.add_argument(
-        "--alpha", type=float, default=0.4, help="Polygon fill alpha (0..1)"
-    )
-    parser.add_argument(
-        "--no-bbox", action="store_true", help="Don't draw bounding boxes for polygons"
-    )
+    parser.add_argument("--labels", type=str, help="Path to YOLO label file (polygons)")
+    parser.add_argument("--alpha", type=float, default=0.4, help="Polygon fill alpha (0..1)")
+    parser.add_argument("--no-bbox", action="store_true", help="Don't draw bounding boxes for polygons")
     args = parser.parse_args()

+    print(args)
+
     img_path = Path(args.image)
-    lbl_path = Path(args.labels)
+    if args.labels:
+        lbl_path = Path(args.labels)
+    else:
+        lbl_path = img_path.with_suffix(".txt")
+        lbl_path = Path(str(lbl_path).replace("images", "labels"))
+
     if not img_path.exists():
         print("Image not found:", img_path)
@@ -156,7 +176,9 @@ def main():
         print("Label file not found:", lbl_path)
         sys.exit(1)

-    img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
+    # img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
+    img = (Image(img_path).get_qt_rgb() * 255).astype(np.uint8)
+
     if img is None:
         print("Could not load image:", img_path)
         sys.exit(1)
@@ -165,15 +187,42 @@ def main():
     if not labels:
         print("No labels parsed from", lbl_path)
         # continue and just show image
-    out = draw_annotations(
-        img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox)
-    )
+    out = draw_annotations(img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox))
+    out_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
+    if 0:
+        plt.imshow(out_rgb.transpose(1, 0, 2))
+    else:
+        plt.imshow(out_rgb)
+
+    for label in labels:
+        lclass, coords = label
+        # print(lclass, coords)
+        bbox = coords[:4]
+        # print("bbox", bbox)
+        bbox = np.array(bbox) * np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
+        yc, xc, h, w = bbox
+        # print("bbox", bbox)
+
+        # polyline = np.array(coords[4:]).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
+        polyline = np.array(coords).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
+        # print("pl", coords[4:])
+        # print("pl", polyline)

-    # Convert BGR -> RGB for matplotlib display
-    out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
-    plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
-    plt.imshow(out_rgb)
-    plt.axis("off")
+        # Convert BGR -> RGB for matplotlib display
+        # out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
+        # out_rgb = Image()
+        plt.plot(polyline[:, 0], polyline[:, 1], "y", linewidth=2)
+        if 0:
+            plt.plot(
+                [yc - h / 2, yc - h / 2, yc + h / 2, yc + h / 2, yc - h / 2],
+                [xc - w / 2, xc + w / 2, xc + w / 2, xc - w / 2, xc - w / 2],
+                "r",
+                linewidth=2,
+            )
+
+    # plt.axis("off")
     plt.title(f"{img_path.name} ({lbl_path.name})")
     plt.show()
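With `--labels` optional, the viewer infers the label file by swapping `images` for `labels` in the image path. The overlay helper can also be reused programmatically; a hedged sketch (inputs hypothetical, label line is raw YOLO-seg text, and the return/scale conventions of `get_qt_rgb` and `draw_annotations` are assumed from their usage above):

```python
import cv2
import numpy as np

from src.utils.image import Image

img = (Image("data/datasets/images/sample.tif").get_qt_rgb() * 255).astype(np.uint8)
out = draw_annotations(img.copy(), ["0 0.5 0.5 0.2 0.2 0.4 0.4 0.6 0.4 0.6 0.6"], alpha=0.4)
cv2.imwrite("overlay.png", out)
```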
@@ -1,109 +0,0 @@
#!/usr/bin/env python3
"""
Test script for 16-bit TIFF loading and normalization.
"""

import numpy as np
import tifffile
from pathlib import Path
import tempfile
import sys
import os

# Add parent directory to path to import modules
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.utils.image import Image


def create_test_16bit_tiff(output_path: str) -> str:
    """Create a test 16-bit grayscale TIFF file.

    Args:
        output_path: Path where to save the test TIFF

    Returns:
        Path to the created TIFF file
    """
    # Create a 16-bit grayscale test image (100x100)
    # with values ranging from 0 to 65535 (full 16-bit range)
    height, width = 100, 100

    # Create a gradient pattern
    test_data = np.zeros((height, width), dtype=np.uint16)
    for i in range(height):
        for j in range(width):
            # Create a diagonal gradient
            test_data[i, j] = int((i + j) / (height + width - 2) * 65535)

    # Save as TIFF
    tifffile.imwrite(output_path, test_data)
    print(f"Created test 16-bit TIFF: {output_path}")
    print(f" Shape: {test_data.shape}")
    print(f" Dtype: {test_data.dtype}")
    print(f" Min value: {test_data.min()}")
    print(f" Max value: {test_data.max()}")

    return output_path


def test_image_loading():
    """Test loading 16-bit TIFF with the Image class."""
    print("\n=== Testing Image Loading ===")

    # Create temporary test file
    with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as tmp:
        test_path = tmp.name

    try:
        # Create test image
        create_test_16bit_tiff(test_path)

        # Load with Image class
        print("\nLoading with Image class...")
        img = Image(test_path)

        print("Successfully loaded image:")
        print(f" Width: {img.width}")
        print(f" Height: {img.height}")
        print(f" Channels: {img.channels}")
        print(f" Dtype: {img.dtype}")
        print(f" Format: {img.format}")

        # Test normalization
        print("\nTesting normalization to float32 [0-1]...")
        normalized = img.to_normalized_float32()

        print("Normalized image:")
        print(f" Shape: {normalized.shape}")
        print(f" Dtype: {normalized.dtype}")
        print(f" Min value: {normalized.min():.6f}")
        print(f" Max value: {normalized.max():.6f}")
        print(f" Mean value: {normalized.mean():.6f}")

        # Verify normalization
        assert normalized.dtype == np.float32, "Dtype should be float32"
        assert 0.0 <= normalized.min() <= normalized.max() <= 1.0, "Values should be in [0, 1]"

        print("\n✓ All tests passed!")
        return True

    except Exception as e:
        print(f"\n✗ Test failed with error: {e}")
        import traceback

        traceback.print_exc()
        return False

    finally:
        # Cleanup
        if os.path.exists(test_path):
            os.remove(test_path)
            print(f"\nCleaned up test file: {test_path}")


if __name__ == "__main__":
    success = test_image_loading()
    sys.exit(0 if success else 1)
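The test above relies on `Image.to_normalized_float32()` mapping the full uint16 range onto float32 [0, 1]. As a rough sketch of what such a conversion does, standalone and independent of the project's `Image` class (the free function here is hypothetical):

```python
import numpy as np
import tifffile

def to_normalized_float32(path: str) -> np.ndarray:
    """Read a TIFF and return float32 data scaled to [0, 1]."""
    data = tifffile.imread(path)
    if data.dtype == np.uint16:
        return data.astype(np.float32) / 65535.0  # full 16-bit range -> [0, 1]
    if data.dtype == np.uint8:
        return data.astype(np.float32) / 255.0
    return data.astype(np.float32)  # assume float inputs are already in range
```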
@@ -1,211 +0,0 @@
"""
Test script for Float32 on-the-fly loading for 16-bit TIFFs.

This test verifies that:
1. Float32YOLODataset can load 16-bit TIFF files
2. Images are converted to float32 [0-1] in memory
3. Grayscale is replicated to 3 channels (RGB)
4. No disk caching is used
5. Full 16-bit precision is preserved
"""

import sys
import tempfile

import numpy as np
import tifffile
import torch  # used by the dtype and precision checks in test_float32_dataset
import yaml
from pathlib import Path


def create_test_dataset():
    """Create a minimal test dataset with 16-bit TIFF images."""
    temp_dir = Path(tempfile.mkdtemp())
    dataset_dir = temp_dir / "test_dataset"

    # Create directory structure
    train_images = dataset_dir / "train" / "images"
    train_labels = dataset_dir / "train" / "labels"
    train_images.mkdir(parents=True, exist_ok=True)
    train_labels.mkdir(parents=True, exist_ok=True)

    # Create a 16-bit TIFF test image
    img_16bit = np.random.randint(0, 65536, (100, 100), dtype=np.uint16)
    img_path = train_images / "test_image.tif"
    tifffile.imwrite(str(img_path), img_16bit)

    # Create a dummy label file
    label_path = train_labels / "test_image.txt"
    with open(label_path, "w") as f:
        f.write("0 0.5 0.5 0.2 0.2\n")  # class_id x_center y_center width height

    # Create data.yaml
    data_yaml = {
        "path": str(dataset_dir),
        "train": "train/images",
        "val": "train/images",  # Use same for val in test
        "names": {0: "object"},
        "nc": 1,
    }
    yaml_path = dataset_dir / "data.yaml"
    with open(yaml_path, "w") as f:
        yaml.safe_dump(data_yaml, f)

    print(f"✓ Created test dataset at: {dataset_dir}")
    print(f" - Image: {img_path} (shape={img_16bit.shape}, dtype={img_16bit.dtype})")
    print(f" - Min value: {img_16bit.min()}, Max value: {img_16bit.max()}")
    print(f" - data.yaml: {yaml_path}")

    return dataset_dir, img_path, img_16bit


def test_float32_dataset():
    """Test the Float32YOLODataset class directly."""
    print("\n=== Testing Float32YOLODataset ===\n")

    try:
        from src.utils.train_ultralytics_float import Float32YOLODataset

        print("✓ Successfully imported Float32YOLODataset")
    except ImportError as e:
        print(f"✗ Failed to import Float32YOLODataset: {e}")
        return False

    # Create test dataset
    dataset_dir, img_path, original_img = create_test_dataset()

    try:
        # Initialize the dataset
        print("\nInitializing Float32YOLODataset...")
        dataset = Float32YOLODataset(
            images_dir=str(dataset_dir / "train" / "images"),
            labels_dir=str(dataset_dir / "train" / "labels"),
            img_size=640,
        )
        print(f"✓ Float32YOLODataset initialized with {len(dataset)} images")

        # Get an item
        if len(dataset) > 0:
            print("\nGetting first item...")
            img_tensor, labels, filename = dataset[0]

            print("✓ Item retrieved successfully")
            print(f" - Image tensor shape: {img_tensor.shape}")
            print(f" - Image tensor dtype: {img_tensor.dtype}")
            print(f" - Value range: [{img_tensor.min():.6f}, {img_tensor.max():.6f}]")
            print(f" - Filename: {filename}")
            print(f" - Labels: {len(labels)} annotations")
            if labels:
                print(f" - First label shape: {labels[0].shape if len(labels) > 0 else 'N/A'}")

            # Verify it's float32
            if img_tensor.dtype == torch.float32:
                print("✓ Correct dtype: float32")
            else:
                print(f"✗ Wrong dtype: {img_tensor.dtype} (expected float32)")
                return False

            # Verify it's 3-channel in correct format (C, H, W)
            if len(img_tensor.shape) == 3 and img_tensor.shape[0] == 3:
                print(f"✓ Correct format: (C, H, W) = {img_tensor.shape} with 3 channels")
            else:
                print(f"✗ Wrong shape: {img_tensor.shape} (expected (3, H, W))")
                return False

            # Verify it's in [0, 1] range
            if 0.0 <= img_tensor.min() and img_tensor.max() <= 1.0:
                print("✓ Values in correct range: [0, 1]")
            else:
                print(f"✗ Values out of range: [{img_tensor.min()}, {img_tensor.max()}]")
                return False

            # Verify precision (should have many unique values)
            unique_values = len(torch.unique(img_tensor))
            print(f" - Unique values: {unique_values}")
            if unique_values > 256:
                print(f"✓ High precision maintained ({unique_values} > 256 levels)")
            else:
                print(f"⚠ Low precision: only {unique_values} unique values")

            print("\n✓ All Float32YOLODataset tests passed!")
            return True
        else:
            print("✗ No items in dataset")
            return False

    except Exception as e:
        print(f"✗ Error during testing: {e}")
        import traceback

        traceback.print_exc()
        return False


def test_integration():
    """Test integration with train_with_float32_loader."""
    print("\n=== Testing Integration with train_with_float32_loader ===\n")

    # Create test dataset
    dataset_dir, img_path, original_img = create_test_dataset()
    data_yaml = dataset_dir / "data.yaml"

    print(f"\nTest dataset ready at: {data_yaml}")
    print("\nTo test full training, run:")
    print("  from src.utils.train_ultralytics_float import train_with_float32_loader")
    print("  results = train_with_float32_loader(")
    print("      model_path='yolov8n-seg.pt',")
    print(f"      data_yaml='{data_yaml}',")
    print("      epochs=1,")
    print("      batch=1,")
    print("      imgsz=640")
    print("  )")
    print("\nThis will use a custom training loop with Float32YOLODataset")

    return True


def main():
    """Run all tests."""
    print("=" * 70)
    print("Float32 Training Loader Test Suite")
    print("=" * 70)

    results = []

    # Test 1: Float32YOLODataset
    results.append(("Float32YOLODataset", test_float32_dataset()))

    # Test 2: Integration check
    results.append(("Integration Check", test_integration()))

    # Summary
    print("\n" + "=" * 70)
    print("Test Summary")
    print("=" * 70)
    for test_name, passed in results:
        status = "✓ PASSED" if passed else "✗ FAILED"
        print(f"{status}: {test_name}")

    all_passed = all(passed for _, passed in results)
    print("=" * 70)
    if all_passed:
        print("✓ All tests passed!")
    else:
        print("✗ Some tests failed")
    print("=" * 70)

    return all_passed


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)
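If `Float32YOLODataset` follows the constructor and return signature exercised above, batching it with a standard PyTorch `DataLoader` needs a custom collate function, since each item is an `(image_tensor, labels, filename)` triple with a variable number of labels. A hedged sketch (the import path and constructor are taken from the test; the collate behavior is an assumption):

```python
import torch
from torch.utils.data import DataLoader
from src.utils.train_ultralytics_float import Float32YOLODataset

def collate(batch):
    """Stack fixed-size images; keep variable-length labels and filenames as lists."""
    imgs, labels, names = zip(*batch)
    return torch.stack(imgs, 0), list(labels), list(names)

loader = DataLoader(
    Float32YOLODataset(images_dir="train/images", labels_dir="train/labels", img_size=640),
    batch_size=4,
    shuffle=True,
    collate_fn=collate,
)
```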
File diff suppressed because it is too large
@@ -1,142 +0,0 @@
#!/usr/bin/env python3
"""
Test script for training dataset preparation with 16-bit TIFFs.
"""

import numpy as np
import tifffile
from pathlib import Path
import tempfile
import sys
import os
import shutil

# Add parent directory to path to import modules
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.utils.image import Image


def test_float32_3ch_conversion():
    """Test conversion of 16-bit TIFF to 16-bit RGB PNG."""
    print("\n=== Testing 16-bit RGB PNG Conversion ===")

    # Create temporary directory structure
    with tempfile.TemporaryDirectory() as tmpdir:
        tmpdir = Path(tmpdir)
        src_dir = tmpdir / "original"
        dst_dir = tmpdir / "converted"
        src_dir.mkdir()
        dst_dir.mkdir()

        # Create test 16-bit TIFF
        test_data = np.zeros((100, 100), dtype=np.uint16)
        for i in range(100):
            for j in range(100):
                test_data[i, j] = int((i + j) / 198 * 65535)

        test_file = src_dir / "test_16bit.tif"
        tifffile.imwrite(test_file, test_data)
        print(f"Created test 16-bit TIFF: {test_file}")
        print(f" Shape: {test_data.shape}")
        print(f" Dtype: {test_data.dtype}")
        print(f" Range: [{test_data.min()}, {test_data.max()}]")

        # Simulate the conversion process (matching training_tab.py)
        print("\nConverting to 16-bit RGB PNG using PIL merge...")
        img_obj = Image(test_file)
        from PIL import Image as PILImage

        # Get uint16 data
        uint16_data = img_obj.data

        # Use PIL's merge method with 'I;16' channels (proper way for 16-bit RGB)
        if len(uint16_data.shape) == 2:
            # Grayscale - replicate to RGB
            r_img = PILImage.fromarray(uint16_data, mode="I;16")
            g_img = PILImage.fromarray(uint16_data, mode="I;16")
            b_img = PILImage.fromarray(uint16_data, mode="I;16")
        else:
            r_img = PILImage.fromarray(uint16_data[:, :, 0], mode="I;16")
            g_img = PILImage.fromarray(
                uint16_data[:, :, 1] if uint16_data.shape[2] > 1 else uint16_data[:, :, 0],
                mode="I;16",
            )
            b_img = PILImage.fromarray(
                uint16_data[:, :, 2] if uint16_data.shape[2] > 2 else uint16_data[:, :, 0],
                mode="I;16",
            )

        # Merge channels into RGB
        rgb_img = PILImage.merge("RGB", (r_img, g_img, b_img))

        # Save as PNG
        output_file = dst_dir / "test_16bit_rgb.png"
        rgb_img.save(output_file)
        print(f"Saved 16-bit RGB PNG: {output_file}")
        print(f" PIL mode after merge: {rgb_img.mode}")

        # Verify the output - load with OpenCV (as YOLO does)
        import cv2

        loaded = cv2.imread(str(output_file), cv2.IMREAD_UNCHANGED)
        print("\nVerifying output (loaded with OpenCV):")
        print(f" Shape: {loaded.shape}")
        print(f" Dtype: {loaded.dtype}")
        print(f" Channels: {loaded.shape[2] if len(loaded.shape) == 3 else 1}")
        print(f" Range: [{loaded.min()}, {loaded.max()}]")
        print(f" Unique values: {len(np.unique(loaded[:, :, 0]))}")

        # Assertions
        assert loaded.dtype == np.uint16, f"Expected uint16, got {loaded.dtype}"
        assert loaded.shape[2] == 3, f"Expected 3 channels, got {loaded.shape[2]}"
        assert loaded.min() >= 0 and loaded.max() <= 65535, (
            f"Expected [0,65535] range, got [{loaded.min()}, {loaded.max()}]"
        )

        # Verify all channels are identical (replicated grayscale)
        assert np.array_equal(loaded[:, :, 0], loaded[:, :, 1]), "Channel 0 and 1 should be identical"
        assert np.array_equal(loaded[:, :, 0], loaded[:, :, 2]), "Channel 0 and 2 should be identical"

        # Verify no data loss
        unique_vals = len(np.unique(loaded[:, :, 0]))
        print("\n Precision check:")
        print(f" Unique values in channel: {unique_vals}")
        print(f" Source unique values: {len(np.unique(test_data))}")

        assert unique_vals == len(np.unique(test_data)), (
            f"Expected {len(np.unique(test_data))} unique values, got {unique_vals}"
        )

        print("\n✓ All conversion tests passed!")
        print(" - uint16 dtype preserved")
        print(" - 3 channels created")
        print(" - Range [0-65535] maintained")
        print(" - No precision loss from conversion")
        print(" - Channels properly replicated")

        return True


if __name__ == "__main__":
    try:
        success = test_float32_3ch_conversion()
        sys.exit(0 if success else 1)
    except Exception as e:
        print(f"\n✗ Test failed with error: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)
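The PIL `I;16` merge route above is one way to produce a 16-bit RGB PNG; OpenCV can write the same thing directly from a uint16 array, which is a useful cross-check that the test's assertions do not depend on PIL specifics. A sketch under the same assumption of a 2-D uint16 input (the function name is illustrative):

```python
import cv2
import numpy as np

def write_16bit_rgb_png(gray_u16: np.ndarray, out_path: str) -> None:
    """Replicate a 2-D uint16 array to 3 channels and save as a 16-bit PNG."""
    assert gray_u16.dtype == np.uint16 and gray_u16.ndim == 2
    rgb = np.stack([gray_u16] * 3, axis=-1)  # HWC, still uint16
    cv2.imwrite(out_path, rgb)               # OpenCV keeps uint16 depth for .png
```

Channel order (BGR vs RGB) is irrelevant here because all three channels are identical copies of the grayscale plane.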
@@ -1,150 +0,0 @@
#!/usr/bin/env python3
"""
Test script for YOLO preprocessing of 16-bit TIFF images with float32 passthrough.
Verifies that no uint8 conversion occurs and data is preserved.
"""

import numpy as np
import tifffile
from pathlib import Path
import tempfile
import sys
import os

# Add parent directory to path to import modules
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.model.yolo_wrapper import YOLOWrapper


def create_test_16bit_tiff(output_path: str) -> str:
    """Create a test 16-bit grayscale TIFF file.

    Args:
        output_path: Path where to save the test TIFF

    Returns:
        Path to the created TIFF file
    """
    # Create a 16-bit grayscale test image (200x200)
    # with specific values to test precision preservation
    height, width = 200, 200

    # Create a gradient pattern with the full 16-bit range
    test_data = np.zeros((height, width), dtype=np.uint16)
    for i in range(height):
        for j in range(width):
            # Create a diagonal gradient using full 16-bit range
            test_data[i, j] = int((i + j) / (height + width - 2) * 65535)

    # Save as TIFF
    tifffile.imwrite(output_path, test_data)
    print(f"Created test 16-bit TIFF: {output_path}")
    print(f" Shape: {test_data.shape}")
    print(f" Dtype: {test_data.dtype}")
    print(f" Min value: {test_data.min()}")
    print(f" Max value: {test_data.max()}")
    print(f" Sample values: {test_data[50, 50]}, {test_data[100, 100]}, {test_data[150, 150]}")

    return output_path


def test_float32_passthrough():
    """Test that 16-bit TIFF preprocessing passes float32 directly without uint8 conversion."""
    print("\n=== Testing Float32 Passthrough (NO uint8) ===")

    # Create temporary test file
    with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as tmp:
        test_path = tmp.name

    try:
        # Create test image
        create_test_16bit_tiff(test_path)

        # Create YOLOWrapper instance
        print("\nTesting YOLOWrapper._prepare_source() for float32 passthrough...")
        wrapper = YOLOWrapper()

        # Call _prepare_source to preprocess the image
        prepared_source, cleanup_path = wrapper._prepare_source(test_path)

        print("\nPreprocessing result:")
        print(f" Original path: {test_path}")
        print(f" Prepared source type: {type(prepared_source)}")

        # Verify it returns a numpy array (not a file path)
        if isinstance(prepared_source, np.ndarray):
            print("\n✓ SUCCESS: Prepared source is a numpy array (float32 passthrough)")
            print(f" Shape: {prepared_source.shape}")
            print(f" Dtype: {prepared_source.dtype}")
            print(f" Min value: {prepared_source.min():.6f}")
            print(f" Max value: {prepared_source.max():.6f}")
            print(f" Mean value: {prepared_source.mean():.6f}")

            # Verify it's float32 in [0, 1] range
            assert prepared_source.dtype == np.float32, f"Expected float32, got {prepared_source.dtype}"
            assert 0.0 <= prepared_source.min() <= prepared_source.max() <= 1.0, (
                f"Expected values in [0, 1], got [{prepared_source.min()}, {prepared_source.max()}]"
            )

            # Verify it has 3 channels (RGB)
            assert prepared_source.shape[2] == 3, f"Expected 3 channels (RGB), got {prepared_source.shape[2]}"

            # Verify no quantization to 256 levels (would happen with uint8 conversion)
            unique_values = len(np.unique(prepared_source))
            print(f" Unique values: {unique_values}")

            # With float32, we should have many more than 256 unique values
            if unique_values > 256:
                print(f"\n✓ SUCCESS: Data has {unique_values} unique values (> 256)")
                print(" This confirms NO uint8 quantization occurred!")
            else:
                print(f"\n✗ WARNING: Data has only {unique_values} unique values")
                print(" This might indicate uint8 quantization happened")

            # Sample some values to show precision
            print("\n Sample normalized values:")
            print(f" [50, 50]: {prepared_source[50, 50, 0]:.8f}")
            print(f" [100, 100]: {prepared_source[100, 100, 0]:.8f}")
            print(f" [150, 150]: {prepared_source[150, 150, 0]:.8f}")

            # No cleanup needed since we returned the array directly
            assert cleanup_path is None, "Cleanup path should be None for float32 passthrough"

            print("\n✓ All float32 passthrough tests passed!")
            return True

        else:
            print(f"\n✗ FAILED: Prepared source is a file path: {prepared_source}")
            print(" This means data was saved to disk, not passed as a float32 array")
            if cleanup_path and os.path.exists(cleanup_path):
                os.remove(cleanup_path)
            return False

    except Exception as e:
        print(f"\n✗ Test failed with error: {e}")
        import traceback

        traceback.print_exc()
        return False

    finally:
        # Cleanup
        if os.path.exists(test_path):
            os.remove(test_path)
            print(f"\nCleaned up test file: {test_path}")


if __name__ == "__main__":
    success = test_float32_passthrough()
    sys.exit(0 if success else 1)
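Taken together, the assertions above pin down the float32 contract for `_prepare_source`: a 16-bit TIFF yields an in-memory float32 HWC RGB array in [0, 1] plus a `None` cleanup path, and that array can be fed straight to the model. A hedged usage sketch ("sample_16bit.tif" is a hypothetical path; whether the float range survives prediction unchanged depends on the wrapper changes this document describes):

```python
import numpy as np
import tifffile
from ultralytics import YOLO

# Build the same kind of float32 [0, 1] HWC RGB array the passthrough test asserts on.
gray = tifffile.imread("sample_16bit.tif").astype(np.float32) / 65535.0
rgb = np.stack([gray] * 3, axis=-1)

# Ultralytics accepts in-memory numpy arrays as a prediction source, so no temp file
# is written; conf/iou mirror the defaults in this project's config.
model = YOLO("yolov8s-seg.pt")
results = model.predict(source=rgb, conf=0.25, iou=0.45)
```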
@@ -1,126 +0,0 @@
#!/usr/bin/env python3
"""
Test script for YOLO preprocessing of 16-bit TIFF images.
"""

import numpy as np
import tifffile
from pathlib import Path
import tempfile
import sys
import os

# Add parent directory to path to import modules
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.model.yolo_wrapper import YOLOWrapper
from src.utils.image import Image
from PIL import Image as PILImage


def create_test_16bit_tiff(output_path: str) -> str:
    """Create a test 16-bit grayscale TIFF file.

    Args:
        output_path: Path where to save the test TIFF

    Returns:
        Path to the created TIFF file
    """
    # Create a 16-bit grayscale test image (200x200)
    # with values ranging from 0 to 65535 (full 16-bit range)
    height, width = 200, 200

    # Create a gradient pattern
    test_data = np.zeros((height, width), dtype=np.uint16)
    for i in range(height):
        for j in range(width):
            # Create a diagonal gradient
            test_data[i, j] = int((i + j) / (height + width - 2) * 65535)

    # Save as TIFF
    tifffile.imwrite(output_path, test_data)
    print(f"Created test 16-bit TIFF: {output_path}")
    print(f" Shape: {test_data.shape}")
    print(f" Dtype: {test_data.dtype}")
    print(f" Min value: {test_data.min()}")
    print(f" Max value: {test_data.max()}")

    return output_path


def test_yolo_preprocessing():
    """Test YOLO preprocessing of 16-bit TIFF images."""
    print("\n=== Testing YOLO Preprocessing of 16-bit TIFF ===")

    # Create temporary test file
    with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as tmp:
        test_path = tmp.name

    try:
        # Create test image
        create_test_16bit_tiff(test_path)

        # Create YOLOWrapper instance (no actual model loading needed for this test)
        print("\nTesting YOLOWrapper._prepare_source()...")
        wrapper = YOLOWrapper()

        # Call _prepare_source to preprocess the image
        prepared_path, cleanup_path = wrapper._prepare_source(test_path)

        print("\nPreprocessing complete:")
        print(f" Original path: {test_path}")
        print(f" Prepared path: {prepared_path}")
        print(f" Cleanup path: {cleanup_path}")

        # Verify the prepared image exists
        assert os.path.exists(prepared_path), "Prepared image should exist"

        # Load the prepared image and verify it's uint8 RGB
        prepared_img = PILImage.open(prepared_path)
        print("\nPrepared image properties:")
        print(f" Mode: {prepared_img.mode}")
        print(f" Size: {prepared_img.size}")
        print(f" Format: {prepared_img.format}")

        # Convert to numpy to check values
        img_array = np.array(prepared_img)
        print(f" Shape: {img_array.shape}")
        print(f" Dtype: {img_array.dtype}")
        print(f" Min value: {img_array.min()}")
        print(f" Max value: {img_array.max()}")
        print(f" Mean value: {img_array.mean():.2f}")

        # Verify it's RGB uint8
        assert prepared_img.mode == "RGB", "Prepared image should be RGB"
        assert img_array.dtype == np.uint8, "Prepared image should be uint8"
        assert img_array.shape[2] == 3, "Prepared image should have 3 channels"
        assert 0 <= img_array.min() <= img_array.max() <= 255, "Values should be in [0, 255]"

        # Cleanup prepared file if needed
        if cleanup_path and os.path.exists(cleanup_path):
            os.remove(cleanup_path)
            print(f"\nCleaned up prepared image: {cleanup_path}")

        print("\n✓ All YOLO preprocessing tests passed!")
        return True

    except Exception as e:
        print(f"\n✗ Test failed with error: {e}")
        import traceback

        traceback.print_exc()
        return False

    finally:
        # Cleanup
        if os.path.exists(test_path):
            os.remove(test_path)
            print(f"Cleaned up test file: {test_path}")


if __name__ == "__main__":
    success = test_yolo_preprocessing()
    sys.exit(0 if success else 1)