Compare commits
23 Commits
training
...
float32int
| Author | SHA1 | Date | |
|---|---|---|---|
| 87095ec3f0 | |||
| 2dbfa54256 | |||
| c7e1271193 | |||
| aec0fbf83c | |||
| 908e9a5b82 | |||
| edcd448a61 | |||
| 2411223a14 | |||
| b3b1e3acff | |||
| 9c4c39fb39 | |||
| 20a87c9040 | |||
| 9f7d2be1ac | |||
| dbde07c0e8 | |||
| b3c5a51dbb | |||
| 9a221acb63 | |||
| 32a6a122bd | |||
| 9ba44043ef | |||
| 8eb1cc8c86 | |||
| e4ce882a18 | |||
| 6b6d6fad03 | |||
| c0684a9c14 | |||
| 221c80aa8c | |||
| 833b222fad | |||
| 5370d31dce |
@@ -12,12 +12,26 @@ image_repository:
|
||||
models:
|
||||
default_base_model: yolov8s-seg.pt
|
||||
models_directory: data/models
|
||||
base_model_choices:
|
||||
- yolov8s-seg.pt
|
||||
- yolo11s-seg.pt
|
||||
training:
|
||||
default_epochs: 100
|
||||
default_batch_size: 16
|
||||
default_imgsz: 640
|
||||
default_imgsz: 1024
|
||||
default_patience: 50
|
||||
default_lr0: 0.01
|
||||
two_stage:
|
||||
enabled: false
|
||||
stage1:
|
||||
epochs: 20
|
||||
lr0: 0.0005
|
||||
patience: 10
|
||||
freeze: 10
|
||||
stage2:
|
||||
epochs: 150
|
||||
lr0: 0.0003
|
||||
patience: 30
|
||||
last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml
|
||||
last_dataset_dir: /home/martin/code/object_detection/data/datasets
|
||||
detection:
|
||||
|
||||
300
docs/16BIT_TIFF_SUPPORT.md
Normal file
300
docs/16BIT_TIFF_SUPPORT.md
Normal file
@@ -0,0 +1,300 @@
|
||||
# 16-bit TIFF Support for YOLO Object Detection
|
||||
|
||||
## Overview
|
||||
|
||||
This document describes the implementation of 16-bit grayscale TIFF support for YOLO object detection. The system properly loads 16-bit TIFF images, normalizes them to float32 [0-1], and handles them appropriately for both **inference** and **training** **without uint8 conversion** to preserve the full dynamic range and avoid data loss.
|
||||
|
||||
## Key Features
|
||||
|
||||
✅ Reads 16-bit or float32 images using tifffile
|
||||
✅ Converts to float32 [0-1] (NO uint8 conversion)
|
||||
✅ Replicates grayscale → RGB (3 channels)
|
||||
✅ **Inference**: Passes numpy arrays directly to YOLO (no file I/O)
|
||||
✅ **Training**: On-the-fly float32 conversion (NO disk caching)
|
||||
✅ Uses Ultralytics YOLOv8/v11 models
|
||||
✅ Works with segmentation models
|
||||
✅ No data loss, no double normalization, no silent clipping
|
||||
|
||||
## Changes Made
|
||||
|
||||
### 1. Dependencies ([`requirements.txt`](../requirements.txt:14))
|
||||
- Added `tifffile>=2023.0.0` for reliable 16-bit TIFF loading
|
||||
|
||||
### 2. Image Loading ([`src/utils/image.py`](../src/utils/image.py))
|
||||
|
||||
#### Enhanced TIFF Loading
|
||||
- Modified [`Image._load()`](../src/utils/image.py:87) to use `tifffile` for `.tif` and `.tiff` files
|
||||
- Preserves original 16-bit data type during loading
|
||||
- Properly handles both grayscale and multi-channel TIFF files
|
||||
|
||||
#### New Normalization Method
|
||||
Added [`Image.to_normalized_float32()`](../src/utils/image.py:280) method that:
|
||||
- Converts image data to `float32`
|
||||
- Properly scales values to [0, 1] range:
|
||||
- **16-bit images**: divides by 65535 (full dynamic range)
|
||||
- 8-bit images: divides by 255
|
||||
- Float images: clips to [0, 1]
|
||||
- Handles various data types automatically
|
||||
|
||||
### 3. YOLO Preprocessing ([`src/model/yolo_wrapper.py`](../src/model/yolo_wrapper.py))
|
||||
|
||||
Enhanced [`YOLOWrapper._prepare_source()`](../src/model/yolo_wrapper.py:231) to:
|
||||
1. Detect 16-bit TIFF files automatically
|
||||
2. Load and normalize to float32 [0-1] using the new method
|
||||
3. Replicate grayscale to RGB (3 channels)
|
||||
4. **Return numpy array directly** (NO file saving, NO uint8 conversion)
|
||||
5. Pass float32 array directly to YOLO for inference
|
||||
|
||||
## Processing Pipeline
|
||||
|
||||
### For Inference (predict)
|
||||
|
||||
For 16-bit TIFF files during inference:
|
||||
|
||||
1. **Load**: File loaded using `tifffile` → preserves 16-bit uint16 data
|
||||
2. **Normalize**: Convert to float32 and scale to [0, 1]
|
||||
```python
|
||||
float_data = uint16_data.astype(np.float32) / 65535.0
|
||||
```
|
||||
3. **RGB Conversion**: Replicate grayscale to 3 channels
|
||||
```python
|
||||
rgb_float = np.stack([float_data] * 3, axis=-1)
|
||||
```
|
||||
4. **Pass to YOLO**: Return float32 array directly (no uint8, no file I/O)
|
||||
5. **Inference**: YOLO processes the float32 [0-1] RGB array
|
||||
|
||||
### For Training (train)
|
||||
|
||||
Training now uses a custom dataset loader with on-the-fly conversion (NO disk caching):
|
||||
|
||||
1. **Custom Dataset**: Uses `Float32Dataset` class that extends Ultralytics' `YOLODataset`
|
||||
2. **Load On-The-Fly**: Each image is loaded and converted during training:
|
||||
- Detect 16-bit TIFF files automatically
|
||||
- Load with `tifffile` (preserves uint16)
|
||||
- Convert to float32 [0-1] in memory
|
||||
- Replicate to 3 channels (RGB)
|
||||
3. **No Disk Cache**: Conversion happens in memory, no files written
|
||||
4. **Train**: YOLO trains on float32 [0-1] RGB arrays directly
|
||||
|
||||
See [`src/utils/train_ultralytics_float.py`](../src/utils/train_ultralytics_float.py) for implementation.
|
||||
|
||||
### No Data Loss!
|
||||
|
||||
Unlike approaches that convert to uint8 (256 levels), this implementation:
|
||||
- Preserves full 16-bit dynamic range (65536 levels)
|
||||
- Maintains precision with float32 representation
|
||||
- For inference: passes data directly without file conversions
|
||||
- For training: uses float32 TIFFs (not uint8 PNGs)
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Image Loading
|
||||
|
||||
```python
|
||||
from src.utils.image import Image
|
||||
|
||||
# Load a 16-bit TIFF file
|
||||
img = Image("path/to/16bit_image.tif")
|
||||
|
||||
# Get normalized float32 data [0-1]
|
||||
normalized = img.to_normalized_float32() # Shape: (H, W), dtype: float32
|
||||
|
||||
# Original data is preserved
|
||||
original = img.data # Still uint16
|
||||
```
|
||||
|
||||
### YOLO Inference
|
||||
|
||||
The preprocessing is automatic - just use YOLO as normal:
|
||||
|
||||
```python
|
||||
from src.model.yolo_wrapper import YOLOWrapper
|
||||
|
||||
# Initialize model
|
||||
yolo = YOLOWrapper("yolov8s-seg.pt")
|
||||
yolo.load_model()
|
||||
|
||||
# Perform inference on 16-bit TIFF
|
||||
# The image will be automatically normalized and passed as float32 [0-1]
|
||||
detections = yolo.predict("path/to/16bit_image.tif", conf=0.25)
|
||||
```
|
||||
|
||||
### With InferenceEngine
|
||||
|
||||
```python
|
||||
from src.model.inference import InferenceEngine
|
||||
from src.database.db_manager import DatabaseManager
|
||||
|
||||
# Setup
|
||||
db = DatabaseManager("database.db")
|
||||
engine = InferenceEngine("model.pt", db, model_id=1)
|
||||
|
||||
# Detect objects in 16-bit TIFF
|
||||
result = engine.detect_single(
|
||||
image_path="path/to/16bit_image.tif",
|
||||
relative_path="images/16bit_image.tif",
|
||||
conf=0.25
|
||||
)
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
Three test scripts are provided:
|
||||
|
||||
### 1. Image Loading Test
|
||||
```bash
|
||||
./venv/bin/python tests/test_16bit_tiff_loading.py
|
||||
```
|
||||
|
||||
Tests:
|
||||
- Loading 16-bit TIFF files with tifffile
|
||||
- Normalization to float32 [0-1]
|
||||
- Data type and value range verification
|
||||
|
||||
### 2. Float32 Passthrough Test (Most Important!)
|
||||
```bash
|
||||
./venv/bin/python tests/test_yolo_16bit_float32.py
|
||||
```
|
||||
|
||||
Tests:
|
||||
- YOLO preprocessing returns numpy array (not file path)
|
||||
- Data is float32 [0-1] (not uint8)
|
||||
- No quantization to 256 levels (proves no uint8 conversion)
|
||||
- Sample output:
|
||||
```
|
||||
✓ SUCCESS: Prepared source is a numpy array (float32 passthrough)
|
||||
Shape: (200, 200, 3)
|
||||
Dtype: float32
|
||||
Min value: 0.000000
|
||||
Max value: 1.000000
|
||||
Unique values: 399
|
||||
|
||||
✓ SUCCESS: Data has 399 unique values (> 256)
|
||||
This confirms NO uint8 quantization occurred!
|
||||
```
|
||||
|
||||
### 3. Legacy Test (Shows Old Behavior)
|
||||
```bash
|
||||
./venv/bin/python tests/test_yolo_16bit_preprocessing.py
|
||||
```
|
||||
|
||||
This test shows the old behavior (uint8 conversion) - kept for comparison.
|
||||
|
||||
## Benefits
|
||||
|
||||
1. **No Data Loss**: Preserves full 16-bit dynamic range (65536 levels vs 256)
|
||||
2. **High Precision**: Float32 maintains fine-grained intensity differences
|
||||
3. **Automatic Processing**: No manual preprocessing needed
|
||||
4. **YOLO Compatible**: Ultralytics YOLO accepts float32 [0-1] arrays
|
||||
5. **Performance**: No intermediate file I/O for 16-bit TIFFs
|
||||
6. **Backwards Compatible**: Regular images (8-bit PNG, JPEG, etc.) still work as before
|
||||
|
||||
## Technical Notes
|
||||
|
||||
### Float32 vs uint8
|
||||
|
||||
**With uint8 conversion (OLD - BAD):**
|
||||
- 16-bit (65536 levels) → uint8 (256 levels) = **99.6% data loss!**
|
||||
- Fine intensity differences are lost
|
||||
- Quantization artifacts
|
||||
|
||||
**With float32 [0-1] (NEW - GOOD):**
|
||||
- 16-bit (65536 levels) → float32 (continuous) = **No data loss**
|
||||
- Full dynamic range preserved
|
||||
- Smooth gradients maintained
|
||||
|
||||
### Memory Considerations
|
||||
|
||||
For a 2048×2048 single-channel image:
|
||||
|
||||
| Format | Memory | Disk Space | Notes |
|
||||
|--------|--------|------------|-------|
|
||||
| Original 16-bit | 8 MB | ~8 MB | uint16 grayscale TIFF |
|
||||
| Float32 grayscale | 16 MB | - | Intermediate |
|
||||
| Float32 3-channel | 48 MB | ~48 MB | Training cache |
|
||||
| uint8 RGB (old) | 12 MB | ~12 MB | OLD approach with data loss |
|
||||
|
||||
The float32 approach uses ~3× more memory than uint8 during training but preserves **all information**.
|
||||
|
||||
**No Disk Cache**: The new on-the-fly approach eliminates the need for cached datasets on disk.
|
||||
|
||||
### Why Direct Numpy Array?
|
||||
|
||||
Passing numpy arrays directly to YOLO (instead of saving to file):
|
||||
|
||||
1. **Faster**: No disk I/O overhead
|
||||
2. **No Quantization**: Avoids PNG/JPEG quantization
|
||||
3. **Memory Efficient**: Single copy in memory
|
||||
4. **Cleaner**: No temp file management
|
||||
|
||||
Ultralytics YOLO supports various input types:
|
||||
- File paths (str): `"image.jpg"`
|
||||
- Numpy arrays: `np.ndarray` ← **we use this**
|
||||
- PIL Images: `PIL.Image`
|
||||
- Torch tensors: `torch.Tensor`
|
||||
|
||||
## Training with Float32 Dataset Loader
|
||||
|
||||
The system now includes a custom dataset loader for 16-bit TIFF training:
|
||||
|
||||
```python
|
||||
from src.utils.train_ultralytics_float import train_with_float32_loader
|
||||
|
||||
# Train with on-the-fly float32 conversion
|
||||
results = train_with_float32_loader(
|
||||
model_path="yolov8s-seg.pt",
|
||||
data_yaml="data/my_dataset/data.yaml",
|
||||
epochs=100,
|
||||
batch=16,
|
||||
imgsz=640,
|
||||
)
|
||||
```
|
||||
|
||||
The `Float32Dataset` class automatically:
|
||||
- Detects 16-bit TIFF files
|
||||
- Loads with `tifffile` (not PIL/cv2)
|
||||
- Converts to float32 [0-1] on-the-fly
|
||||
- Replicates to 3 channels
|
||||
- Integrates seamlessly with Ultralytics training pipeline
|
||||
|
||||
This is used automatically by the training tab in the GUI.
|
||||
|
||||
## Installation
|
||||
|
||||
Install the updated dependencies:
|
||||
|
||||
```bash
|
||||
./venv/bin/pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Or install tifffile directly:
|
||||
|
||||
```bash
|
||||
./venv/bin/pip install tifffile>=2023.0.0
|
||||
```
|
||||
|
||||
## Example Test Output
|
||||
|
||||
```
|
||||
=== Testing Float32 Passthrough (NO uint8) ===
|
||||
Created test 16-bit TIFF: /tmp/tmpdt5hm0ab.tif
|
||||
Shape: (200, 200)
|
||||
Dtype: uint16
|
||||
Min value: 0
|
||||
Max value: 65535
|
||||
|
||||
Preprocessing result:
|
||||
Prepared source type: <class 'numpy.ndarray'>
|
||||
|
||||
✓ SUCCESS: Prepared source is a numpy array (float32 passthrough)
|
||||
Shape: (200, 200, 3)
|
||||
Dtype: float32
|
||||
Min value: 0.000000
|
||||
Max value: 1.000000
|
||||
Mean value: 0.499992
|
||||
Unique values: 399
|
||||
|
||||
✓ SUCCESS: Data has 399 unique values (> 256)
|
||||
This confirms NO uint8 quantization occurred!
|
||||
|
||||
✓ All float32 passthrough tests passed!
|
||||
269
docs/TRAINING_16BIT_TIFF.md
Normal file
269
docs/TRAINING_16BIT_TIFF.md
Normal file
@@ -0,0 +1,269 @@
|
||||
# Training YOLO with 16-bit TIFF Datasets
|
||||
|
||||
## Quick Start
|
||||
|
||||
If your dataset contains 16-bit grayscale TIFF files, the training tab will automatically:
|
||||
|
||||
1. Detect 16-bit TIFF images in your dataset
|
||||
2. Convert them to float32 [0-1] RGB **on-the-fly** during training
|
||||
3. Train without any disk caching (memory-efficient)
|
||||
|
||||
**No manual intervention or disk space needed!**
|
||||
|
||||
## Why Float32 On-The-Fly Conversion?
|
||||
|
||||
### The Problem
|
||||
|
||||
YOLO's training expects:
|
||||
- 3-channel images (RGB)
|
||||
- Images loaded from disk by the dataloader
|
||||
|
||||
16-bit grayscale TIFFs are:
|
||||
- 1-channel (grayscale)
|
||||
- Need to be converted to RGB format
|
||||
|
||||
### The Solution
|
||||
|
||||
**NEW APPROACH (Current)**: On-the-fly float32 conversion
|
||||
- Load 16-bit TIFF with `tifffile` (not PIL/cv2)
|
||||
- Convert uint16 [0-65535] → float32 [0-1] in memory
|
||||
- Replicate grayscale to 3 channels
|
||||
- Pass directly to YOLO training pipeline
|
||||
- **No disk caching required!**
|
||||
|
||||
**OLD APPROACH (Deprecated)**: Disk caching
|
||||
- Created 16-bit RGB PNG cache files on disk
|
||||
- Required ~2x dataset size in disk space
|
||||
- Slower first training run
|
||||
|
||||
## How It Works
|
||||
|
||||
### Custom Dataset Loader
|
||||
|
||||
The system uses a custom `Float32Dataset` class that extends Ultralytics' `YOLODataset`:
|
||||
|
||||
```python
|
||||
from src.utils.train_ultralytics_float import Float32Dataset
|
||||
|
||||
# This dataset loader:
|
||||
# 1. Intercepts image loading
|
||||
# 2. Detects 16-bit TIFFs
|
||||
# 3. Converts to float32 [0-1] RGB on-the-fly
|
||||
# 4. Passes to training pipeline
|
||||
```
|
||||
|
||||
### Conversion Process
|
||||
|
||||
For each 16-bit grayscale TIFF during training:
|
||||
|
||||
```
|
||||
1. Load with tifffile → uint16 [0, 65535]
|
||||
2. Convert to float32 → img.astype(float32) / 65535.0
|
||||
3. Replicate to RGB → np.stack([img] * 3, axis=-1)
|
||||
4. Result: float32 [0, 1] RGB array, shape (H, W, 3)
|
||||
```
|
||||
|
||||
### Memory vs Disk
|
||||
|
||||
| Aspect | On-the-fly (NEW) | Disk Cache (OLD) |
|
||||
|--------|------------------|------------------|
|
||||
| Disk Space | Dataset size only | ~2× dataset size |
|
||||
| First Training | Fast | Slow (creates cache) |
|
||||
| Subsequent Training | Fast | Fast |
|
||||
| Data Loss | None | None |
|
||||
| Setup Required | None | Cache creation |
|
||||
|
||||
## Data Preservation
|
||||
|
||||
### Float32 Precision
|
||||
|
||||
16-bit TIFF: 65,536 levels (0-65535)
|
||||
Float32: ~7 decimal digits precision
|
||||
|
||||
**Conversion accuracy:**
|
||||
```python
|
||||
Original: 32768 (uint16, middle intensity)
|
||||
Float32: 32768 / 65535 = 0.50000763 (exact)
|
||||
```
|
||||
|
||||
Full 16-bit precision is preserved in float32 representation.
|
||||
|
||||
### Comparison to uint8
|
||||
|
||||
| Approach | Precision Loss | Recommended |
|
||||
|----------|----------------|-------------|
|
||||
| **float32 [0-1]** | None | ✓ YES |
|
||||
| uint16 RGB | None | ✓ YES (but disk-heavy) |
|
||||
| uint8 | 99.6% data loss | ✗ NO |
|
||||
|
||||
**Why NO uint8:**
|
||||
```
|
||||
Original values: 32768, 32769, 32770 (distinct)
|
||||
Converted to uint8: 128, 128, 128 (collapsed!)
|
||||
```
|
||||
|
||||
Multiple 16-bit values collapse to the same uint8 value.
|
||||
|
||||
## Training Tab Behavior
|
||||
|
||||
When you click "Start Training" with a 16-bit TIFF dataset:
|
||||
|
||||
```
|
||||
[01:23:45] Exported 150 annotations across 50 image(s).
|
||||
[01:23:45] Using Float32 on-the-fly loader for 16-bit TIFF support (no disk caching)
|
||||
[01:23:45] Starting training run 'my_model_v1' using yolov8s-seg.pt
|
||||
[01:23:46] Using Float32Dataset loader for 16-bit TIFF support
|
||||
```
|
||||
|
||||
Every training run uses the same approach - fast and efficient!
|
||||
|
||||
## Inference vs Training
|
||||
|
||||
| Operation | Input | Processing | Output to YOLO |
|
||||
|-----------|-------|------------|----------------|
|
||||
| **Inference** | 16-bit TIFF file | Load → float32 [0-1] → 3ch | numpy array (float32) |
|
||||
| **Training** | 16-bit TIFF dataset | Load on-the-fly → float32 [0-1] → 3ch | numpy array (float32) |
|
||||
|
||||
Both preserve full 16-bit precision using float32 representation.
|
||||
|
||||
## Technical Details
|
||||
|
||||
### Custom Dataset Class
|
||||
|
||||
Located in `src/utils/train_ultralytics_float.py`:
|
||||
|
||||
```python
|
||||
class Float32Dataset(YOLODataset):
|
||||
"""
|
||||
Extends Ultralytics YOLODataset to handle 16-bit TIFFs.
|
||||
|
||||
Key methods:
|
||||
- load_image(): Intercepts image loading
|
||||
- Detects .tif/.tiff with dtype == uint16
|
||||
- Converts: uint16 → float32 [0-1] → RGB (3-channel)
|
||||
"""
|
||||
```
|
||||
|
||||
### Integration with YOLO
|
||||
|
||||
The `YOLOWrapper.train()` method automatically uses the custom loader:
|
||||
|
||||
```python
|
||||
# In src/model/yolo_wrapper.py
|
||||
def train(self, data_yaml, use_float32_loader=True, **kwargs):
|
||||
if use_float32_loader:
|
||||
# Use custom Float32Dataset
|
||||
return train_with_float32_loader(...)
|
||||
else:
|
||||
# Standard YOLO training
|
||||
return self.model.train(...)
|
||||
```
|
||||
|
||||
### No PIL or cv2 for 16-bit
|
||||
|
||||
16-bit TIFF loading uses `tifffile` directly:
|
||||
- PIL: Can load 16-bit but converts during processing
|
||||
- cv2: Limited 16-bit TIFF support
|
||||
- tifffile: Native 16-bit support, numpy output
|
||||
|
||||
## Advantages Over Disk Caching
|
||||
|
||||
### 1. No Disk Space Required
|
||||
```
|
||||
Dataset: 1000 images × 12 MB = 12 GB
|
||||
Old cache: Additional 24 GB (16-bit RGB PNGs)
|
||||
New approach: 0 GB additional (on-the-fly)
|
||||
```
|
||||
|
||||
### 2. Faster Setup
|
||||
```
|
||||
Old: First training requires cache creation (minutes)
|
||||
New: Start training immediately (seconds)
|
||||
```
|
||||
|
||||
### 3. Always In Sync
|
||||
```
|
||||
Old: Cache could become stale if images change
|
||||
New: Always loads current version from disk
|
||||
```
|
||||
|
||||
### 4. Simpler Workflow
|
||||
```
|
||||
Old: Manage cache directory, cleanup, etc.
|
||||
New: Just point to dataset and train
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Error: "expected input to have 3 channels, but got 1"
|
||||
|
||||
This shouldn't happen with the new Float32Dataset, but if it does:
|
||||
|
||||
1. Check that `use_float32_loader=True` in training call
|
||||
2. Verify `Float32Dataset` is being used (check logs)
|
||||
3. Ensure `tifffile` is installed: `pip install tifffile`
|
||||
|
||||
### Memory Usage
|
||||
|
||||
On-the-fly conversion uses memory during training:
|
||||
- Image loaded: ~24 MB (2048×2048 uint16)
|
||||
- Converted float32 RGB: ~48 MB (temporary)
|
||||
- Released after augmentation pipeline
|
||||
|
||||
**Mitigation:**
|
||||
- Reduce batch size if OOM errors occur
|
||||
- Images are processed one at a time during loading
|
||||
- Only active batch kept in memory
|
||||
|
||||
### Slow Training
|
||||
|
||||
If training seems slow:
|
||||
- Check disk I/O (slow disk can bottleneck loading)
|
||||
- Verify images aren't being re-converted each epoch (should cache after first load)
|
||||
- Monitor CPU usage during loading
|
||||
|
||||
## Migration from Old Approach
|
||||
|
||||
If you have existing cached datasets:
|
||||
|
||||
```bash
|
||||
# Old cache location (safe to delete)
|
||||
rm -rf data/datasets/_float32_cache/
|
||||
|
||||
# The new approach doesn't use this directory
|
||||
```
|
||||
|
||||
Your original dataset structure remains unchanged:
|
||||
```
|
||||
data/my_dataset/
|
||||
├── train/
|
||||
│ ├── images/ (original 16-bit TIFFs)
|
||||
│ └── labels/
|
||||
├── val/
|
||||
│ ├── images/
|
||||
│ └── labels/
|
||||
└── data.yaml
|
||||
```
|
||||
|
||||
Just point to the same `data.yaml` and train!
|
||||
|
||||
## Performance Comparison
|
||||
|
||||
| Metric | Old (Disk Cache) | New (On-the-fly) |
|
||||
|--------|------------------|------------------|
|
||||
| First training setup | 5-10 min | 0 sec |
|
||||
| Disk space overhead | 100% | 0% |
|
||||
| Training speed | Fast | Fast |
|
||||
| Subsequent runs | Fast | Fast |
|
||||
| Data accuracy | 16-bit preserved | 16-bit preserved |
|
||||
|
||||
## Summary
|
||||
|
||||
✓ **On-the-fly conversion**: Load and convert during training
|
||||
✓ **No disk caching**: Zero additional disk space
|
||||
✓ **Full precision**: Float32 preserves 16-bit dynamic range
|
||||
✓ **No PIL/cv2**: Direct tifffile loading
|
||||
✓ **Automatic**: Works transparently with training tab
|
||||
✓ **Fast**: Efficient memory-based conversion
|
||||
|
||||
The new approach is simpler, faster to set up, and requires no disk space overhead!
|
||||
@@ -11,6 +11,7 @@ pyqtgraph>=0.13.0
|
||||
opencv-python>=4.8.0
|
||||
Pillow>=10.0.0
|
||||
numpy>=1.24.0
|
||||
tifffile>=2023.0.0
|
||||
|
||||
# Database
|
||||
sqlalchemy>=2.0.0
|
||||
|
||||
179
scripts/README_FLOAT32_TRAINING.md
Normal file
179
scripts/README_FLOAT32_TRAINING.md
Normal file
@@ -0,0 +1,179 @@
|
||||
# Standalone Float32 Training Script for 16-bit TIFFs
|
||||
|
||||
## Overview
|
||||
|
||||
This standalone script (`train_float32_standalone.py`) trains YOLO models on 16-bit grayscale TIFF datasets with **no data loss**.
|
||||
|
||||
- Loads 16-bit TIFFs with `tifffile` (not PIL/cv2)
|
||||
- Converts to float32 [0-1] on-the-fly (preserves full 16-bit precision)
|
||||
- Replicates grayscale → 3-channel RGB in memory
|
||||
- **No disk caching required**
|
||||
- Uses custom PyTorch Dataset + training loop
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Activate virtual environment
|
||||
source venv/bin/activate
|
||||
|
||||
# Train on your 16-bit TIFF dataset
|
||||
python scripts/train_float32_standalone.py \
|
||||
--data data/my_dataset/data.yaml \
|
||||
--weights yolov8s-seg.pt \
|
||||
--epochs 100 \
|
||||
--batch 16 \
|
||||
--imgsz 640 \
|
||||
--lr 0.0001 \
|
||||
--save-dir runs/my_training \
|
||||
--device cuda
|
||||
```
|
||||
|
||||
## Arguments
|
||||
|
||||
| Argument | Required | Default | Description |
|
||||
|----------|----------|---------|-------------|
|
||||
| `--data` | Yes | - | Path to YOLO data.yaml file |
|
||||
| `--weights` | No | yolov8s-seg.pt | Pretrained model weights |
|
||||
| `--epochs` | No | 100 | Number of training epochs |
|
||||
| `--batch` | No | 16 | Batch size |
|
||||
| `--imgsz` | No | 640 | Input image size |
|
||||
| `--lr` | No | 0.0001 | Learning rate |
|
||||
| `--save-dir` | No | runs/train | Directory to save checkpoints |
|
||||
| `--device` | No | cuda/cpu | Training device (auto-detected) |
|
||||
|
||||
## Dataset Format
|
||||
|
||||
Your data.yaml should follow standard YOLO format:
|
||||
|
||||
```yaml
|
||||
path: /path/to/dataset
|
||||
train: train/images
|
||||
val: val/images
|
||||
test: test/images # optional
|
||||
|
||||
names:
|
||||
0: class1
|
||||
1: class2
|
||||
|
||||
nc: 2
|
||||
```
|
||||
|
||||
Directory structure:
|
||||
```
|
||||
dataset/
|
||||
├── train/
|
||||
│ ├── images/
|
||||
│ │ ├── img1.tif (16-bit grayscale TIFF)
|
||||
│ │ └── img2.tif
|
||||
│ └── labels/
|
||||
│ ├── img1.txt (YOLO format)
|
||||
│ └── img2.txt
|
||||
├── val/
|
||||
│ ├── images/
|
||||
│ └── labels/
|
||||
└── data.yaml
|
||||
```
|
||||
|
||||
## Output
|
||||
|
||||
The script saves:
|
||||
- `epoch{N}.pt`: Checkpoint after each epoch
|
||||
- `best.pt`: Best model weights (lowest loss)
|
||||
- Training logs to console
|
||||
|
||||
## Features
|
||||
|
||||
✅ **16-bit precision preserved**: Float32 [0-1] maintains full dynamic range
|
||||
✅ **No disk caching**: Conversion happens in memory
|
||||
✅ **No PIL/cv2**: Direct tifffile loading
|
||||
✅ **Variable-length labels**: Handles segmentation polygons
|
||||
✅ **Checkpoint saving**: Resume training if interrupted
|
||||
✅ **Best model tracking**: Automatically saves best weights
|
||||
|
||||
## Example
|
||||
|
||||
Train a segmentation model on microscopy data:
|
||||
|
||||
```bash
|
||||
python scripts/train_float32_standalone.py \
|
||||
--data data/microscopy/data.yaml \
|
||||
--weights yolov11s-seg.pt \
|
||||
--epochs 150 \
|
||||
--batch 8 \
|
||||
--imgsz 1024 \
|
||||
--lr 0.0003 \
|
||||
--save-dir data/models/microscopy_v1
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Out of Memory (OOM)
|
||||
Reduce batch size:
|
||||
```bash
|
||||
--batch 4
|
||||
```
|
||||
|
||||
### Slow Loading
|
||||
Reduce num_workers (edit script line 208):
|
||||
```python
|
||||
num_workers=2 # instead of 4
|
||||
```
|
||||
|
||||
### Different Image Sizes
|
||||
The script expects all images to have the same dimensions. For variable sizes:
|
||||
1. Implement letterbox/resize in dataset's `_read_image()`
|
||||
2. Or preprocess images to same size
|
||||
|
||||
### Loss Computation Errors
|
||||
If you see "Cannot determine loss", the script may need adjustment for your Ultralytics version. Check:
|
||||
```python
|
||||
# In train() function, the preds format may vary
|
||||
# Current script assumes: preds is tuple with loss OR dict with 'loss' key
|
||||
```
|
||||
|
||||
## vs GUI Training
|
||||
|
||||
| Feature | Standalone Script | GUI Training Tab |
|
||||
|---------|------------------|------------------|
|
||||
| Float32 conversion | ✓ Yes | ✓ Yes (automatic) |
|
||||
| Disk caching | ✗ None | ✗ None |
|
||||
| Progress UI | ✗ Console only | ✓ Visual progress bar |
|
||||
| Dataset selection | Manual CLI args | ✓ GUI browsing |
|
||||
| Multi-stage training | Manual runs | ✓ Built-in |
|
||||
| Use case | Advanced users | General users |
|
||||
|
||||
## Technical Details
|
||||
|
||||
### Data Loading Pipeline
|
||||
|
||||
```
|
||||
16-bit TIFF file
|
||||
↓ (tifffile.imread)
|
||||
uint16 [0-65535]
|
||||
↓ (/ 65535.0)
|
||||
float32 [0-1]
|
||||
↓ (replicate channels)
|
||||
float32 RGB (H,W,3) [0-1]
|
||||
↓ (permute to C,H,W)
|
||||
torch.Tensor (3,H,W) float32
|
||||
↓ (DataLoader stack)
|
||||
Batch (B,3,H,W) float32
|
||||
↓
|
||||
YOLO Model
|
||||
```
|
||||
|
||||
### Precision Comparison
|
||||
|
||||
| Method | Unique Values | Data Loss |
|
||||
|--------|---------------|-----------|
|
||||
| **float32 [0-1]** | ~65,536 | None ✓ |
|
||||
| uint16 RGB | 65,536 | None ✓ |
|
||||
| uint8 | 256 | 99.6% ✗ |
|
||||
|
||||
Example: Pixel value 32,768 (middle intensity)
|
||||
- Float32: 32768 / 65535.0 = 0.50000763 (exact)
|
||||
- uint8: 32768 → 128 → many values collapse!
|
||||
|
||||
## License
|
||||
|
||||
Same as main project.
|
||||
351
scripts/train_float32_standalone.py
Executable file
351
scripts/train_float32_standalone.py
Executable file
@@ -0,0 +1,351 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Standalone training script for YOLO with 16-bit TIFF float32 support.
|
||||
|
||||
This script trains YOLO models on 16-bit grayscale TIFF datasets without data loss.
|
||||
Converts images to float32 [0-1] on-the-fly using tifffile (no PIL/cv2).
|
||||
|
||||
Usage:
|
||||
python scripts/train_float32_standalone.py \\
|
||||
--data path/to/data.yaml \\
|
||||
--weights yolov8s-seg.pt \\
|
||||
--epochs 100 \\
|
||||
--batch 16 \\
|
||||
--imgsz 640
|
||||
|
||||
Based on the custom dataset approach to avoid Ultralytics' channel conversion issues.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import tifffile
|
||||
import yaml
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from ultralytics import YOLO
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# ===================== Dataset =====================
|
||||
|
||||
|
||||
class Float32YOLODataset(Dataset):
|
||||
"""PyTorch dataset for 16-bit TIFF images with float32 conversion."""
|
||||
|
||||
def __init__(self, images_dir, labels_dir, img_size=640):
|
||||
self.images_dir = Path(images_dir)
|
||||
self.labels_dir = Path(labels_dir)
|
||||
self.img_size = img_size
|
||||
|
||||
# Find images
|
||||
extensions = {".tif", ".tiff", ".png", ".jpg", ".jpeg", ".bmp"}
|
||||
self.paths = sorted(
|
||||
[
|
||||
p
|
||||
for p in self.images_dir.rglob("*")
|
||||
if p.is_file() and p.suffix.lower() in extensions
|
||||
]
|
||||
)
|
||||
|
||||
if not self.paths:
|
||||
raise ValueError(f"No images found in {images_dir}")
|
||||
|
||||
logger.info(f"Dataset: {len(self.paths)} images from {images_dir}")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths)
|
||||
|
||||
def _read_image(self, path: Path) -> np.ndarray:
|
||||
"""Load image as float32 [0-1] RGB."""
|
||||
# Load with tifffile
|
||||
img = tifffile.imread(str(path))
|
||||
|
||||
# Convert to float32
|
||||
img = img.astype(np.float32)
|
||||
|
||||
# Normalize 16-bit→[0,1]
|
||||
if img.max() > 1.5:
|
||||
img = img / 65535.0
|
||||
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
# Grayscale→RGB
|
||||
if img.ndim == 2:
|
||||
img = np.repeat(img[..., None], 3, axis=2)
|
||||
elif img.ndim == 3 and img.shape[2] == 1:
|
||||
img = np.repeat(img, 3, axis=2)
|
||||
|
||||
# Resize to model input size
|
||||
img = cv2.resize(img, (self.img_size, self.img_size))
|
||||
|
||||
return img # float32 (img_size, img_size, 3) [0,1] BGR
|
||||
|
||||
def _parse_label(self, path: Path) -> list:
|
||||
"""Parse YOLO label with variable-length rows."""
|
||||
if not path.exists():
|
||||
return []
|
||||
|
||||
labels = []
|
||||
with open(path, "r") as f:
|
||||
for line in f:
|
||||
vals = line.strip().split()
|
||||
if len(vals) >= 5:
|
||||
labels.append([float(v) for v in vals])
|
||||
|
||||
return labels
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img_path = self.paths[idx]
|
||||
label_path = self.labels_dir / f"{img_path.stem}.txt"
|
||||
|
||||
# Load & convert to tensor (C,H,W)
|
||||
img = self._read_image(img_path)
|
||||
img_t = torch.from_numpy(img).permute(2, 0, 1).contiguous()
|
||||
|
||||
# Load labels
|
||||
labels = self._parse_label(label_path)
|
||||
|
||||
return img_t, labels, str(img_path.name)
|
||||
|
||||
|
||||
# ===================== Collate =====================
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
"""Stack images, keep labels as list."""
|
||||
imgs = torch.stack([b[0] for b in batch], dim=0)
|
||||
labels = [b[1] for b in batch]
|
||||
names = [b[2] for b in batch]
|
||||
return imgs, labels, names
|
||||
|
||||
|
||||
# ===================== Training =====================
|
||||
|
||||
|
||||
def get_pytorch_model(ul_model):
|
||||
"""Extract PyTorch model and loss from Ultralytics wrapper."""
|
||||
pt_model = None
|
||||
loss_fn = None
|
||||
|
||||
# Try common patterns
|
||||
if hasattr(ul_model, "model"):
|
||||
pt_model = ul_model.model
|
||||
|
||||
# Find loss
|
||||
if pt_model and hasattr(pt_model, "loss"):
|
||||
loss_fn = pt_model.loss
|
||||
elif pt_model and hasattr(pt_model, "compute_loss"):
|
||||
loss_fn = pt_model.compute_loss
|
||||
|
||||
if pt_model is None:
|
||||
raise RuntimeError("Could not extract PyTorch model")
|
||||
|
||||
return pt_model, loss_fn
|
||||
|
||||
|
||||
def train(args):
|
||||
"""Main training function."""
|
||||
device = args.device
|
||||
logger.info(f"Device: {device}")
|
||||
|
||||
# Parse data.yaml
|
||||
with open(args.data, "r") as f:
|
||||
data_config = yaml.safe_load(f)
|
||||
|
||||
dataset_root = Path(data_config.get("path", Path(args.data).parent))
|
||||
train_img = dataset_root / data_config.get("train", "train/images")
|
||||
val_img = dataset_root / data_config.get("val", "val/images")
|
||||
train_lbl = train_img.parent / "labels"
|
||||
val_lbl = val_img.parent / "labels"
|
||||
|
||||
# Load model
|
||||
logger.info(f"Loading {args.weights}")
|
||||
ul_model = YOLO(args.weights)
|
||||
pt_model, loss_fn = get_pytorch_model(ul_model)
|
||||
|
||||
# Configure model args
|
||||
from types import SimpleNamespace
|
||||
|
||||
if not hasattr(pt_model, "args"):
|
||||
pt_model.args = SimpleNamespace()
|
||||
if isinstance(pt_model.args, dict):
|
||||
pt_model.args = SimpleNamespace(**pt_model.args)
|
||||
|
||||
# Set segmentation loss args
|
||||
pt_model.args.overlap_mask = getattr(pt_model.args, "overlap_mask", True)
|
||||
pt_model.args.mask_ratio = getattr(pt_model.args, "mask_ratio", 4)
|
||||
pt_model.args.task = "segment"
|
||||
|
||||
pt_model.to(device)
|
||||
pt_model.train()
|
||||
for param in pt_model.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
# Create datasets
|
||||
train_ds = Float32YOLODataset(str(train_img), str(train_lbl), args.imgsz)
|
||||
val_ds = Float32YOLODataset(str(val_img), str(val_lbl), args.imgsz)
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=args.batch,
|
||||
shuffle=True,
|
||||
num_workers=4,
|
||||
pin_memory=(device == "cuda"),
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_ds,
|
||||
batch_size=args.batch,
|
||||
shuffle=False,
|
||||
num_workers=2,
|
||||
pin_memory=(device == "cuda"),
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
# Optimizer
|
||||
optimizer = torch.optim.AdamW(pt_model.parameters(), lr=args.lr)
|
||||
|
||||
# Training loop
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
best_loss = float("inf")
|
||||
|
||||
for epoch in range(args.epochs):
|
||||
t0 = time.time()
|
||||
running_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
for imgs, labels_list, names in train_loader:
|
||||
imgs = imgs.to(device)
|
||||
optimizer.zero_grad()
|
||||
num_batches += 1
|
||||
|
||||
# Forward (simple approach - just use preds)
|
||||
preds = pt_model(imgs)
|
||||
|
||||
# Try to compute loss
|
||||
# Simplest fallback: if preds is tuple/list, assume last element is loss
|
||||
if isinstance(preds, (tuple, list)):
|
||||
# Often YOLO forward returns (preds, loss) in training mode
|
||||
if (
|
||||
len(preds) >= 2
|
||||
and isinstance(preds[-1], dict)
|
||||
and "loss" in preds[-1]
|
||||
):
|
||||
loss = preds[-1]["loss"]
|
||||
elif len(preds) >= 2 and isinstance(preds[-1], torch.Tensor):
|
||||
loss = preds[-1]
|
||||
else:
|
||||
# Manually compute using loss_fn if available
|
||||
if loss_fn:
|
||||
# This may fail - see logs
|
||||
try:
|
||||
loss_out = loss_fn(preds, labels_list)
|
||||
if isinstance(loss_out, dict):
|
||||
loss = loss_out["loss"]
|
||||
elif isinstance(loss_out, (tuple, list)):
|
||||
loss = loss_out[0]
|
||||
else:
|
||||
loss = loss_out
|
||||
except Exception as e:
|
||||
logger.error(f"Loss computation failed: {e}")
|
||||
logger.error(
|
||||
"Consider using Ultralytics .train() or check model/loss compatibility"
|
||||
)
|
||||
raise
|
||||
else:
|
||||
raise RuntimeError("Cannot determine loss from model output")
|
||||
elif isinstance(preds, dict) and "loss" in preds:
|
||||
loss = preds["loss"]
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected preds format: {type(preds)}")
|
||||
|
||||
# Backward
|
||||
loss = loss.mean()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += loss.item()
|
||||
|
||||
if (num_batches % 10) == 0:
|
||||
logger.info(
|
||||
f"Epoch {epoch+1} Batch {num_batches} Loss: {loss.item():.4f}"
|
||||
)
|
||||
|
||||
epoch_loss = running_loss / max(1, num_batches)
|
||||
epoch_time = time.time() - t0
|
||||
logger.info(
|
||||
f"Epoch {epoch+1}/{args.epochs} - Loss: {epoch_loss:.4f}, Time: {epoch_time:.1f}s"
|
||||
)
|
||||
|
||||
# Save checkpoint
|
||||
ckpt = Path(args.save_dir) / f"epoch{epoch+1}.pt"
|
||||
torch.save(
|
||||
{
|
||||
"epoch": epoch + 1,
|
||||
"model_state_dict": pt_model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"loss": epoch_loss,
|
||||
},
|
||||
ckpt,
|
||||
)
|
||||
|
||||
# Save best
|
||||
if epoch_loss < best_loss:
|
||||
best_loss = epoch_loss
|
||||
best_ckpt = Path(args.save_dir) / "best.pt"
|
||||
torch.save(pt_model.state_dict(), best_ckpt)
|
||||
logger.info(f"New best: {best_ckpt}")
|
||||
|
||||
logger.info("Training complete")
|
||||
|
||||
|
||||
# ===================== Main =====================
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Train YOLO on 16-bit TIFF with float32"
|
||||
)
|
||||
parser.add_argument("--data", type=str, required=True, help="Path to data.yaml")
|
||||
parser.add_argument(
|
||||
"--weights", type=str, default="yolov8s-seg.pt", help="Pretrained weights"
|
||||
)
|
||||
parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
|
||||
parser.add_argument("--batch", type=int, default=16, help="Batch size")
|
||||
parser.add_argument("--imgsz", type=int, default=640, help="Image size")
|
||||
parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
|
||||
parser.add_argument(
|
||||
"--save-dir", type=str, default="runs/train", help="Save directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
logger.info("=" * 70)
|
||||
logger.info("Float32 16-bit TIFF Training - Standalone Script")
|
||||
logger.info("=" * 70)
|
||||
logger.info(f"Data: {args.data}")
|
||||
logger.info(f"Weights: {args.weights}")
|
||||
logger.info(f"Epochs: {args.epochs}, Batch: {args.batch}, ImgSz: {args.imgsz}")
|
||||
logger.info(f"LR: {args.lr}, Device: {args.device}")
|
||||
logger.info("=" * 70)
|
||||
|
||||
train(args)
|
||||
@@ -13,8 +13,9 @@ import hashlib
|
||||
import yaml
|
||||
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.image import Image
|
||||
|
||||
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp")
|
||||
IMAGE_EXTENSIONS = tuple(Image.SUPPORTED_EXTENSIONS)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -450,6 +451,25 @@ class DatabaseManager:
|
||||
filters["model_id"] = model_id
|
||||
return self.get_detections(filters)
|
||||
|
||||
def delete_detections_for_image(
|
||||
self, image_id: int, model_id: Optional[int] = None
|
||||
) -> int:
|
||||
"""Delete detections tied to a specific image and optional model."""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
if model_id is not None:
|
||||
cursor.execute(
|
||||
"DELETE FROM detections WHERE image_id = ? AND model_id = ?",
|
||||
(image_id, model_id),
|
||||
)
|
||||
else:
|
||||
cursor.execute("DELETE FROM detections WHERE image_id = ?", (image_id,))
|
||||
conn.commit()
|
||||
return cursor.rowcount
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def delete_detections_for_model(self, model_id: int) -> int:
|
||||
"""Delete all detections for a specific model."""
|
||||
conn = self.get_connection()
|
||||
|
||||
@@ -168,7 +168,7 @@ class AnnotationTab(QWidget):
|
||||
self,
|
||||
"Select Image",
|
||||
start_dir,
|
||||
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
|
||||
"Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
|
||||
)
|
||||
|
||||
if not file_path:
|
||||
|
||||
@@ -20,12 +20,14 @@ from PySide6.QtWidgets import (
|
||||
)
|
||||
from PySide6.QtCore import Qt, QThread, Signal
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.config_manager import ConfigManager
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.file_utils import get_image_files
|
||||
from src.model.inference import InferenceEngine
|
||||
from src.utils.image import Image
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -147,30 +149,66 @@ class DetectionTab(QWidget):
|
||||
self.model_combo.currentIndexChanged.connect(self._on_model_changed)
|
||||
|
||||
def _load_models(self):
|
||||
"""Load available models from database."""
|
||||
"""Load available models from database and local storage."""
|
||||
try:
|
||||
models = self.db_manager.get_models()
|
||||
self.model_combo.clear()
|
||||
models = self.db_manager.get_models()
|
||||
has_models = False
|
||||
|
||||
if not models:
|
||||
self.model_combo.addItem("No models available", None)
|
||||
self._set_buttons_enabled(False)
|
||||
return
|
||||
known_paths = set()
|
||||
|
||||
# Add base model option
|
||||
# Add base model option first (always available)
|
||||
base_model = self.config_manager.get(
|
||||
"models.default_base_model", "yolov8s-seg.pt"
|
||||
)
|
||||
self.model_combo.addItem(
|
||||
f"Base Model ({base_model})", {"id": 0, "path": base_model}
|
||||
)
|
||||
if base_model:
|
||||
base_data = {
|
||||
"id": 0,
|
||||
"path": base_model,
|
||||
"model_name": Path(base_model).stem or "Base Model",
|
||||
"model_version": "pretrained",
|
||||
"base_model": base_model,
|
||||
"source": "base",
|
||||
}
|
||||
self.model_combo.addItem(f"Base Model ({base_model})", base_data)
|
||||
known_paths.add(self._normalize_model_path(base_model))
|
||||
has_models = True
|
||||
|
||||
# Add trained models
|
||||
# Add trained models from database
|
||||
for model in models:
|
||||
display_name = f"{model['model_name']} v{model['model_version']}"
|
||||
self.model_combo.addItem(display_name, model)
|
||||
model_data = {**model, "path": model.get("model_path")}
|
||||
normalized = self._normalize_model_path(model_data.get("path"))
|
||||
if normalized:
|
||||
known_paths.add(normalized)
|
||||
self.model_combo.addItem(display_name, model_data)
|
||||
has_models = True
|
||||
|
||||
self._set_buttons_enabled(True)
|
||||
# Discover local model files not yet in the database
|
||||
local_models = self._discover_local_models()
|
||||
for model_path in local_models:
|
||||
normalized = self._normalize_model_path(model_path)
|
||||
if normalized in known_paths:
|
||||
continue
|
||||
|
||||
display_name = f"Local Model ({Path(model_path).stem})"
|
||||
model_data = {
|
||||
"id": None,
|
||||
"path": str(model_path),
|
||||
"model_name": Path(model_path).stem,
|
||||
"model_version": "local",
|
||||
"base_model": Path(model_path).stem,
|
||||
"source": "local",
|
||||
}
|
||||
self.model_combo.addItem(display_name, model_data)
|
||||
known_paths.add(normalized)
|
||||
has_models = True
|
||||
|
||||
if not has_models:
|
||||
self.model_combo.addItem("No models available", None)
|
||||
self._set_buttons_enabled(False)
|
||||
else:
|
||||
self._set_buttons_enabled(True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading models: {e}")
|
||||
@@ -199,7 +237,7 @@ class DetectionTab(QWidget):
|
||||
self,
|
||||
"Select Image",
|
||||
start_dir,
|
||||
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
|
||||
"Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
|
||||
)
|
||||
|
||||
if not file_path:
|
||||
@@ -249,25 +287,39 @@ class DetectionTab(QWidget):
|
||||
QMessageBox.warning(self, "No Model", "Please select a model first.")
|
||||
return
|
||||
|
||||
model_path = model_data["path"]
|
||||
model_id = model_data["id"]
|
||||
model_path = model_data.get("path")
|
||||
if not model_path:
|
||||
QMessageBox.warning(
|
||||
self, "Invalid Model", "Selected model is missing a file path."
|
||||
)
|
||||
return
|
||||
|
||||
# Ensure we have a valid model ID (create entry for base model if needed)
|
||||
if model_id == 0:
|
||||
# Create database entry for base model
|
||||
base_model = self.config_manager.get(
|
||||
"models.default_base_model", "yolov8s-seg.pt"
|
||||
)
|
||||
model_id = self.db_manager.add_model(
|
||||
model_name="Base Model",
|
||||
model_version="pretrained",
|
||||
model_path=base_model,
|
||||
base_model=base_model,
|
||||
if not Path(model_path).exists():
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"Model Not Found",
|
||||
f"The selected model file could not be found:\n{model_path}",
|
||||
)
|
||||
return
|
||||
|
||||
model_id = model_data.get("id")
|
||||
|
||||
# Ensure we have a database entry for the selected model
|
||||
if model_id in (None, 0):
|
||||
model_id = self._ensure_model_record(model_data)
|
||||
if not model_id:
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"Model Registration Failed",
|
||||
"Unable to register the selected model in the database.",
|
||||
)
|
||||
return
|
||||
|
||||
normalized_model_path = self._normalize_model_path(model_path) or model_path
|
||||
|
||||
# Create inference engine
|
||||
self.inference_engine = InferenceEngine(
|
||||
model_path, self.db_manager, model_id
|
||||
normalized_model_path, self.db_manager, model_id
|
||||
)
|
||||
|
||||
# Get confidence threshold
|
||||
@@ -338,6 +390,76 @@ class DetectionTab(QWidget):
|
||||
self.batch_btn.setEnabled(enabled)
|
||||
self.model_combo.setEnabled(enabled)
|
||||
|
||||
def _discover_local_models(self) -> list:
|
||||
"""Scan the models directory for standalone .pt files."""
|
||||
models_dir = self.config_manager.get_models_directory()
|
||||
if not models_dir:
|
||||
return []
|
||||
|
||||
models_path = Path(models_dir)
|
||||
if not models_path.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
return sorted(
|
||||
[p for p in models_path.rglob("*.pt") if p.is_file()],
|
||||
key=lambda p: str(p).lower(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error discovering local models: {e}")
|
||||
return []
|
||||
|
||||
def _normalize_model_path(self, path_value) -> str:
|
||||
"""Return a normalized absolute path string for comparison."""
|
||||
if not path_value:
|
||||
return ""
|
||||
try:
|
||||
return str(Path(path_value).resolve())
|
||||
except Exception:
|
||||
return str(path_value)
|
||||
|
||||
def _ensure_model_record(self, model_data: dict) -> Optional[int]:
|
||||
"""Ensure a database record exists for the selected model."""
|
||||
model_path = model_data.get("path")
|
||||
if not model_path:
|
||||
return None
|
||||
|
||||
normalized_target = self._normalize_model_path(model_path)
|
||||
|
||||
try:
|
||||
existing_models = self.db_manager.get_models()
|
||||
for model in existing_models:
|
||||
existing_path = model.get("model_path")
|
||||
if not existing_path:
|
||||
continue
|
||||
normalized_existing = self._normalize_model_path(existing_path)
|
||||
if (
|
||||
normalized_existing == normalized_target
|
||||
or existing_path == model_path
|
||||
):
|
||||
return model["id"]
|
||||
|
||||
model_name = (
|
||||
model_data.get("model_name") or Path(model_path).stem or "Custom Model"
|
||||
)
|
||||
model_version = (
|
||||
model_data.get("model_version") or model_data.get("source") or "local"
|
||||
)
|
||||
base_model = model_data.get(
|
||||
"base_model",
|
||||
self.config_manager.get("models.default_base_model", "yolov8s-seg.pt"),
|
||||
)
|
||||
|
||||
return self.db_manager.add_model(
|
||||
model_name=model_name,
|
||||
model_version=model_version,
|
||||
model_path=normalized_target,
|
||||
base_model=base_model,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to ensure model record for {model_path}: {e}")
|
||||
return None
|
||||
|
||||
def refresh(self):
|
||||
"""Refresh the tab."""
|
||||
self._load_models()
|
||||
|
||||
@@ -1,15 +1,39 @@
|
||||
"""
|
||||
Results tab for the microscopy object detection application.
|
||||
Results tab for browsing stored detections and visualizing overlays.
|
||||
"""
|
||||
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from PySide6.QtWidgets import (
|
||||
QWidget,
|
||||
QVBoxLayout,
|
||||
QHBoxLayout,
|
||||
QLabel,
|
||||
QGroupBox,
|
||||
QPushButton,
|
||||
QSplitter,
|
||||
QTableWidget,
|
||||
QTableWidgetItem,
|
||||
QHeaderView,
|
||||
QAbstractItemView,
|
||||
QMessageBox,
|
||||
QCheckBox,
|
||||
)
|
||||
from PySide6.QtCore import Qt
|
||||
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.config_manager import ConfigManager
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.image import Image, ImageLoadError
|
||||
from src.gui.widgets import AnnotationCanvasWidget
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ResultsTab(QWidget):
|
||||
"""Results tab placeholder."""
|
||||
"""Results tab showing detection history and preview overlays."""
|
||||
|
||||
def __init__(
|
||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
||||
@@ -18,29 +42,398 @@ class ResultsTab(QWidget):
|
||||
self.db_manager = db_manager
|
||||
self.config_manager = config_manager
|
||||
|
||||
self.detection_summary: List[Dict] = []
|
||||
self.current_selection: Optional[Dict] = None
|
||||
self.current_image: Optional[Image] = None
|
||||
self.current_detections: List[Dict] = []
|
||||
self._image_path_cache: Dict[str, str] = {}
|
||||
|
||||
self._setup_ui()
|
||||
self.refresh()
|
||||
|
||||
def _setup_ui(self):
|
||||
"""Setup user interface."""
|
||||
layout = QVBoxLayout()
|
||||
|
||||
group = QGroupBox("Results")
|
||||
group_layout = QVBoxLayout()
|
||||
label = QLabel(
|
||||
"Results viewer will be implemented here.\n\n"
|
||||
"Features:\n"
|
||||
"- Detection history browser\n"
|
||||
"- Advanced filtering\n"
|
||||
"- Statistics dashboard\n"
|
||||
"- Export functionality"
|
||||
)
|
||||
group_layout.addWidget(label)
|
||||
group.setLayout(group_layout)
|
||||
# Splitter for list + preview
|
||||
splitter = QSplitter(Qt.Horizontal)
|
||||
|
||||
layout.addWidget(group)
|
||||
layout.addStretch()
|
||||
# Left pane: detection list
|
||||
left_container = QWidget()
|
||||
left_layout = QVBoxLayout()
|
||||
left_layout.setContentsMargins(0, 0, 0, 0)
|
||||
|
||||
controls_layout = QHBoxLayout()
|
||||
self.refresh_btn = QPushButton("Refresh")
|
||||
self.refresh_btn.clicked.connect(self.refresh)
|
||||
controls_layout.addWidget(self.refresh_btn)
|
||||
controls_layout.addStretch()
|
||||
left_layout.addLayout(controls_layout)
|
||||
|
||||
self.results_table = QTableWidget(0, 5)
|
||||
self.results_table.setHorizontalHeaderLabels(
|
||||
["Image", "Model", "Detections", "Classes", "Last Updated"]
|
||||
)
|
||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||
0, QHeaderView.Stretch
|
||||
)
|
||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||
1, QHeaderView.Stretch
|
||||
)
|
||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||
2, QHeaderView.ResizeToContents
|
||||
)
|
||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||
3, QHeaderView.Stretch
|
||||
)
|
||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||
4, QHeaderView.ResizeToContents
|
||||
)
|
||||
self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows)
|
||||
self.results_table.setSelectionMode(QAbstractItemView.SingleSelection)
|
||||
self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
|
||||
self.results_table.itemSelectionChanged.connect(self._on_result_selected)
|
||||
|
||||
left_layout.addWidget(self.results_table)
|
||||
left_container.setLayout(left_layout)
|
||||
|
||||
# Right pane: preview canvas and controls
|
||||
right_container = QWidget()
|
||||
right_layout = QVBoxLayout()
|
||||
right_layout.setContentsMargins(0, 0, 0, 0)
|
||||
|
||||
preview_group = QGroupBox("Detection Preview")
|
||||
preview_layout = QVBoxLayout()
|
||||
|
||||
self.preview_canvas = AnnotationCanvasWidget()
|
||||
self.preview_canvas.set_polyline_enabled(False)
|
||||
self.preview_canvas.set_show_bboxes(True)
|
||||
preview_layout.addWidget(self.preview_canvas)
|
||||
|
||||
toggles_layout = QHBoxLayout()
|
||||
self.show_masks_checkbox = QCheckBox("Show Masks")
|
||||
self.show_masks_checkbox.setChecked(False)
|
||||
self.show_masks_checkbox.stateChanged.connect(self._apply_detection_overlays)
|
||||
self.show_bboxes_checkbox = QCheckBox("Show Bounding Boxes")
|
||||
self.show_bboxes_checkbox.setChecked(True)
|
||||
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
|
||||
self.show_confidence_checkbox = QCheckBox("Show Confidence")
|
||||
self.show_confidence_checkbox.setChecked(False)
|
||||
self.show_confidence_checkbox.stateChanged.connect(
|
||||
self._apply_detection_overlays
|
||||
)
|
||||
toggles_layout.addWidget(self.show_masks_checkbox)
|
||||
toggles_layout.addWidget(self.show_bboxes_checkbox)
|
||||
toggles_layout.addWidget(self.show_confidence_checkbox)
|
||||
toggles_layout.addStretch()
|
||||
preview_layout.addLayout(toggles_layout)
|
||||
|
||||
self.summary_label = QLabel("Select a detection result to preview.")
|
||||
self.summary_label.setWordWrap(True)
|
||||
preview_layout.addWidget(self.summary_label)
|
||||
|
||||
preview_group.setLayout(preview_layout)
|
||||
right_layout.addWidget(preview_group)
|
||||
right_container.setLayout(right_layout)
|
||||
|
||||
splitter.addWidget(left_container)
|
||||
splitter.addWidget(right_container)
|
||||
splitter.setStretchFactor(0, 1)
|
||||
splitter.setStretchFactor(1, 2)
|
||||
|
||||
layout.addWidget(splitter)
|
||||
self.setLayout(layout)
|
||||
|
||||
def refresh(self):
|
||||
"""Refresh the tab."""
|
||||
pass
|
||||
"""Refresh the detection list and preview."""
|
||||
self._load_detection_summary()
|
||||
self._populate_results_table()
|
||||
self.current_selection = None
|
||||
self.current_image = None
|
||||
self.current_detections = []
|
||||
self.preview_canvas.clear()
|
||||
self.summary_label.setText("Select a detection result to preview.")
|
||||
|
||||
def _load_detection_summary(self):
|
||||
"""Load latest detection summaries grouped by image + model."""
|
||||
try:
|
||||
detections = self.db_manager.get_detections(limit=500)
|
||||
summary_map: Dict[tuple, Dict] = {}
|
||||
|
||||
for det in detections:
|
||||
key = (det["image_id"], det["model_id"])
|
||||
metadata = det.get("metadata") or {}
|
||||
entry = summary_map.setdefault(
|
||||
key,
|
||||
{
|
||||
"image_id": det["image_id"],
|
||||
"model_id": det["model_id"],
|
||||
"image_path": det.get("image_path"),
|
||||
"image_filename": det.get("image_filename")
|
||||
or det.get("image_path"),
|
||||
"model_name": det.get("model_name", ""),
|
||||
"model_version": det.get("model_version", ""),
|
||||
"last_detected": det.get("detected_at"),
|
||||
"count": 0,
|
||||
"classes": set(),
|
||||
"source_path": metadata.get("source_path"),
|
||||
"repository_root": metadata.get("repository_root"),
|
||||
},
|
||||
)
|
||||
|
||||
entry["count"] += 1
|
||||
if det.get("detected_at") and (
|
||||
not entry.get("last_detected")
|
||||
or str(det.get("detected_at")) > str(entry.get("last_detected"))
|
||||
):
|
||||
entry["last_detected"] = det.get("detected_at")
|
||||
if det.get("class_name"):
|
||||
entry["classes"].add(det["class_name"])
|
||||
if metadata.get("source_path") and not entry.get("source_path"):
|
||||
entry["source_path"] = metadata.get("source_path")
|
||||
if metadata.get("repository_root") and not entry.get("repository_root"):
|
||||
entry["repository_root"] = metadata.get("repository_root")
|
||||
|
||||
self.detection_summary = sorted(
|
||||
summary_map.values(),
|
||||
key=lambda x: str(x.get("last_detected") or ""),
|
||||
reverse=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load detection summary: {e}")
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"Error",
|
||||
f"Failed to load detection results:\n{str(e)}",
|
||||
)
|
||||
self.detection_summary = []
|
||||
|
||||
def _populate_results_table(self):
|
||||
"""Populate the table widget with detection summaries."""
|
||||
self.results_table.setRowCount(len(self.detection_summary))
|
||||
|
||||
for row, entry in enumerate(self.detection_summary):
|
||||
model_label = f"{entry['model_name']} {entry['model_version']}".strip()
|
||||
class_list = (
|
||||
", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
|
||||
)
|
||||
|
||||
items = [
|
||||
QTableWidgetItem(entry.get("image_filename", "")),
|
||||
QTableWidgetItem(model_label),
|
||||
QTableWidgetItem(str(entry.get("count", 0))),
|
||||
QTableWidgetItem(class_list),
|
||||
QTableWidgetItem(str(entry.get("last_detected") or "")),
|
||||
]
|
||||
|
||||
for col, item in enumerate(items):
|
||||
item.setData(Qt.UserRole, row)
|
||||
self.results_table.setItem(row, col, item)
|
||||
|
||||
self.results_table.clearSelection()
|
||||
|
||||
def _on_result_selected(self):
|
||||
"""Handle selection changes in the detection table."""
|
||||
selected_items = self.results_table.selectedItems()
|
||||
if not selected_items:
|
||||
return
|
||||
|
||||
row = selected_items[0].data(Qt.UserRole)
|
||||
if row is None or row >= len(self.detection_summary):
|
||||
return
|
||||
|
||||
entry = self.detection_summary[row]
|
||||
if (
|
||||
self.current_selection
|
||||
and self.current_selection.get("image_id") == entry["image_id"]
|
||||
and self.current_selection.get("model_id") == entry["model_id"]
|
||||
):
|
||||
return
|
||||
|
||||
self.current_selection = entry
|
||||
|
||||
image_path = self._resolve_image_path(entry)
|
||||
if not image_path:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"Image Not Found",
|
||||
"Unable to locate the image file for this detection.",
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
self.current_image = Image(image_path)
|
||||
self.preview_canvas.load_image(self.current_image)
|
||||
except ImageLoadError as e:
|
||||
logger.error(f"Failed to load image '{image_path}': {e}")
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"Image Error",
|
||||
f"Failed to load image for preview:\n{str(e)}",
|
||||
)
|
||||
return
|
||||
|
||||
self._load_detections_for_selection(entry)
|
||||
self._apply_detection_overlays()
|
||||
self._update_summary_label(entry)
|
||||
|
||||
def _load_detections_for_selection(self, entry: Dict):
|
||||
"""Load detection records for the selected image/model pair."""
|
||||
self.current_detections = []
|
||||
if not entry:
|
||||
return
|
||||
|
||||
try:
|
||||
filters = {"image_id": entry["image_id"], "model_id": entry["model_id"]}
|
||||
self.current_detections = self.db_manager.get_detections(filters)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load detections for preview: {e}")
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"Error",
|
||||
f"Failed to load detections for this image:\n{str(e)}",
|
||||
)
|
||||
self.current_detections = []
|
||||
|
||||
def _apply_detection_overlays(self):
|
||||
"""Draw detections onto the preview canvas based on current toggles."""
|
||||
self.preview_canvas.clear_annotations()
|
||||
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
|
||||
|
||||
if not self.current_detections or not self.current_image:
|
||||
return
|
||||
|
||||
for det in self.current_detections:
|
||||
color = self._get_class_color(det.get("class_name"))
|
||||
|
||||
if self.show_masks_checkbox.isChecked() and det.get("segmentation_mask"):
|
||||
mask_points = self._convert_mask(det["segmentation_mask"])
|
||||
if mask_points:
|
||||
self.preview_canvas.draw_saved_polyline(mask_points, color)
|
||||
|
||||
bbox = [
|
||||
det.get("x_min"),
|
||||
det.get("y_min"),
|
||||
det.get("x_max"),
|
||||
det.get("y_max"),
|
||||
]
|
||||
if all(v is not None for v in bbox):
|
||||
label = None
|
||||
if self.show_confidence_checkbox.isChecked():
|
||||
confidence = det.get("confidence")
|
||||
if confidence is not None:
|
||||
label = f"{confidence:.2f}"
|
||||
self.preview_canvas.draw_saved_bbox(bbox, color, label=label)
|
||||
|
||||
def _convert_mask(self, mask_points: List[List[float]]) -> List[List[float]]:
|
||||
"""Convert stored [x, y] masks to [y, x] format for the canvas."""
|
||||
converted = []
|
||||
for point in mask_points:
|
||||
if len(point) >= 2:
|
||||
x, y = point[0], point[1]
|
||||
converted.append([y, x])
|
||||
return converted
|
||||
|
||||
def _toggle_bboxes(self):
|
||||
"""Update bounding box visibility on the canvas."""
|
||||
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
|
||||
# Re-render to respect show/hide when toggled
|
||||
self._apply_detection_overlays()
|
||||
|
||||
def _update_summary_label(self, entry: Dict):
|
||||
"""Display textual summary for the selected detection run."""
|
||||
classes = ", ".join(sorted(entry.get("classes", []))) or "-"
|
||||
summary_text = (
|
||||
f"Image: {entry.get('image_filename', 'unknown')}\n"
|
||||
f"Model: {entry.get('model_name', '')} {entry.get('model_version', '')}\n"
|
||||
f"Detections: {entry.get('count', 0)}\n"
|
||||
f"Classes: {classes}\n"
|
||||
f"Last Updated: {entry.get('last_detected', 'n/a')}"
|
||||
)
|
||||
self.summary_label.setText(summary_text)
|
||||
|
||||
def _resolve_image_path(self, entry: Dict) -> Optional[str]:
|
||||
"""Resolve an image path using metadata, cache, and repository hints."""
|
||||
relative_path = entry.get("image_path") if entry else None
|
||||
cache_key = relative_path or entry.get("source_path")
|
||||
if cache_key and cache_key in self._image_path_cache:
|
||||
cached = Path(self._image_path_cache[cache_key])
|
||||
if cached.exists():
|
||||
return self._image_path_cache[cache_key]
|
||||
del self._image_path_cache[cache_key]
|
||||
|
||||
candidates = []
|
||||
source_path = entry.get("source_path") if entry else None
|
||||
if source_path:
|
||||
candidates.append(Path(source_path))
|
||||
|
||||
repo_roots = []
|
||||
if entry.get("repository_root"):
|
||||
repo_roots.append(entry["repository_root"])
|
||||
config_repo = self.config_manager.get_image_repository_path()
|
||||
if config_repo:
|
||||
repo_roots.append(config_repo)
|
||||
|
||||
for root in repo_roots:
|
||||
if relative_path:
|
||||
candidates.append(Path(root) / relative_path)
|
||||
|
||||
if relative_path:
|
||||
candidates.append(Path(relative_path))
|
||||
|
||||
for candidate in candidates:
|
||||
try:
|
||||
if candidate and candidate.exists():
|
||||
resolved = str(candidate.resolve())
|
||||
if cache_key:
|
||||
self._image_path_cache[cache_key] = resolved
|
||||
return resolved
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Fallback: search by filename in known roots
|
||||
filename = Path(relative_path).name if relative_path else None
|
||||
if filename:
|
||||
search_roots = [Path(root) for root in repo_roots if root]
|
||||
if not search_roots:
|
||||
search_roots = [Path("data")]
|
||||
match = self._search_in_roots(filename, search_roots)
|
||||
if match:
|
||||
resolved = str(match.resolve())
|
||||
if cache_key:
|
||||
self._image_path_cache[cache_key] = resolved
|
||||
return resolved
|
||||
|
||||
return None
|
||||
|
||||
def _search_in_roots(self, filename: str, roots: List[Path]) -> Optional[Path]:
|
||||
"""Search for a file name within a list of root directories."""
|
||||
for root in roots:
|
||||
try:
|
||||
if not root.exists():
|
||||
continue
|
||||
for candidate in root.rglob(filename):
|
||||
return candidate
|
||||
except Exception as e:
|
||||
logger.debug(f"Error searching for {filename} in {root}: {e}")
|
||||
return None
|
||||
|
||||
def _get_class_color(self, class_name: Optional[str]) -> str:
|
||||
"""Return consistent color hex for a class name."""
|
||||
if not class_name:
|
||||
return "#FF6B6B"
|
||||
|
||||
color_map = self.config_manager.get_bbox_colors()
|
||||
if class_name in color_map:
|
||||
return color_map[class_name]
|
||||
|
||||
# Deterministic fallback color based on hash
|
||||
palette = [
|
||||
"#FF6B6B",
|
||||
"#4ECDC4",
|
||||
"#FFD166",
|
||||
"#1D3557",
|
||||
"#F4A261",
|
||||
"#E76F51",
|
||||
]
|
||||
return palette[hash(class_name) % len(palette)]
|
||||
|
||||
@@ -3,14 +3,12 @@ Training tab for the microscopy object detection application.
|
||||
Handles model training with YOLO.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
from PIL import Image as PILImage
|
||||
from PySide6.QtCore import Qt, QThread, Signal
|
||||
from PySide6.QtWidgets import (
|
||||
QWidget,
|
||||
@@ -28,24 +26,20 @@ from PySide6.QtWidgets import (
|
||||
QProgressBar,
|
||||
QSpinBox,
|
||||
QDoubleSpinBox,
|
||||
QCheckBox,
|
||||
QScrollArea,
|
||||
)
|
||||
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.model.yolo_wrapper import YOLOWrapper
|
||||
from src.utils.config_manager import ConfigManager
|
||||
from src.utils.image import Image
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
DEFAULT_IMAGE_EXTENSIONS = {
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".png",
|
||||
".tif",
|
||||
".tiff",
|
||||
".bmp",
|
||||
}
|
||||
DEFAULT_IMAGE_EXTENSIONS = set(Image.SUPPORTED_EXTENSIONS)
|
||||
|
||||
|
||||
class TrainingWorker(QThread):
|
||||
@@ -67,6 +61,8 @@ class TrainingWorker(QThread):
|
||||
save_dir: str,
|
||||
run_name: str,
|
||||
parent: Optional[QThread] = None,
|
||||
stage_plan: Optional[List[Dict[str, Any]]] = None,
|
||||
total_epochs: Optional[int] = None,
|
||||
):
|
||||
super().__init__(parent)
|
||||
self.data_yaml = data_yaml
|
||||
@@ -78,6 +74,27 @@ class TrainingWorker(QThread):
|
||||
self.lr0 = lr0
|
||||
self.save_dir = save_dir
|
||||
self.run_name = run_name
|
||||
self.stage_plan = stage_plan or [
|
||||
{
|
||||
"label": "Single Stage",
|
||||
"model_path": base_model,
|
||||
"use_previous_best": False,
|
||||
"params": {
|
||||
"epochs": epochs,
|
||||
"batch": batch,
|
||||
"imgsz": imgsz,
|
||||
"patience": patience,
|
||||
"lr0": lr0,
|
||||
"freeze": 0,
|
||||
"name": run_name,
|
||||
},
|
||||
}
|
||||
]
|
||||
computed_total = sum(
|
||||
max(0, int((stage.get("params") or {}).get("epochs", 0)))
|
||||
for stage in self.stage_plan
|
||||
)
|
||||
self.total_epochs = total_epochs if total_epochs else computed_total or epochs
|
||||
self._stop_requested = False
|
||||
|
||||
def stop(self):
|
||||
@@ -86,36 +103,98 @@ class TrainingWorker(QThread):
|
||||
self.requestInterruption()
|
||||
|
||||
def run(self):
|
||||
"""Execute YOLO training and emit progress/finished signals."""
|
||||
wrapper = YOLOWrapper(self.base_model)
|
||||
"""Execute YOLO training over one or more stages and emit progress/finished signals."""
|
||||
|
||||
def on_epoch_end(trainer):
|
||||
current_epoch = getattr(trainer, "epoch", 0) + 1
|
||||
metrics: Dict[str, float] = {}
|
||||
loss_items = getattr(trainer, "loss_items", None)
|
||||
if loss_items:
|
||||
metrics["loss"] = float(loss_items[-1])
|
||||
self.progress.emit(current_epoch, self.epochs, metrics)
|
||||
if self.isInterruptionRequested() or self._stop_requested:
|
||||
setattr(trainer, "stop_training", True)
|
||||
completed_epochs = 0
|
||||
stage_history: List[Dict[str, Any]] = []
|
||||
last_stage_results: Optional[Dict[str, Any]] = None
|
||||
|
||||
callbacks = {"on_fit_epoch_end": on_epoch_end}
|
||||
for stage_index, stage in enumerate(self.stage_plan, start=1):
|
||||
if self._stop_requested or self.isInterruptionRequested():
|
||||
break
|
||||
|
||||
try:
|
||||
results = wrapper.train(
|
||||
data_yaml=self.data_yaml,
|
||||
epochs=self.epochs,
|
||||
imgsz=self.imgsz,
|
||||
batch=self.batch,
|
||||
patience=self.patience,
|
||||
save_dir=self.save_dir,
|
||||
name=self.run_name,
|
||||
lr0=self.lr0,
|
||||
callbacks=callbacks,
|
||||
stage_label = stage.get("label") or f"Stage {stage_index}"
|
||||
stage_params = dict(stage.get("params") or {})
|
||||
stage_epochs = int(stage_params.get("epochs", self.epochs))
|
||||
if stage_epochs <= 0:
|
||||
stage_epochs = 1
|
||||
batch = int(stage_params.get("batch", self.batch))
|
||||
imgsz = int(stage_params.get("imgsz", self.imgsz))
|
||||
patience = int(stage_params.get("patience", self.patience))
|
||||
lr0 = float(stage_params.get("lr0", self.lr0))
|
||||
freeze = int(stage_params.get("freeze", 0))
|
||||
run_name = stage_params.get("name") or f"{self.run_name}_stage{stage_index}"
|
||||
|
||||
weights_path = stage.get("model_path") or self.base_model
|
||||
if stage.get("use_previous_best") and last_stage_results:
|
||||
weights_path = (
|
||||
last_stage_results.get("best_model_path")
|
||||
or last_stage_results.get("last_model_path")
|
||||
or weights_path
|
||||
)
|
||||
|
||||
wrapper = YOLOWrapper(weights_path)
|
||||
stage_offset = completed_epochs
|
||||
|
||||
def on_epoch_end(trainer, offset=stage_offset):
|
||||
current_epoch = getattr(trainer, "epoch", 0) + 1
|
||||
metrics: Dict[str, float] = {}
|
||||
loss_items = getattr(trainer, "loss_items", None)
|
||||
if loss_items:
|
||||
metrics["loss"] = float(loss_items[-1])
|
||||
absolute_epoch = min(
|
||||
max(1, offset + current_epoch),
|
||||
max(1, self.total_epochs),
|
||||
)
|
||||
self.progress.emit(absolute_epoch, self.total_epochs, metrics)
|
||||
if self.isInterruptionRequested() or self._stop_requested:
|
||||
setattr(trainer, "stop_training", True)
|
||||
|
||||
callbacks = {"on_fit_epoch_end": on_epoch_end}
|
||||
|
||||
try:
|
||||
stage_result = wrapper.train(
|
||||
data_yaml=self.data_yaml,
|
||||
epochs=stage_epochs,
|
||||
imgsz=imgsz,
|
||||
batch=batch,
|
||||
patience=patience,
|
||||
save_dir=self.save_dir,
|
||||
name=run_name,
|
||||
lr0=lr0,
|
||||
callbacks=callbacks,
|
||||
freeze=freeze,
|
||||
)
|
||||
except Exception as exc:
|
||||
self.error.emit(str(exc))
|
||||
return
|
||||
|
||||
stage_history.append(
|
||||
{
|
||||
"label": stage_label,
|
||||
"params": stage_params,
|
||||
"weights_used": weights_path,
|
||||
"results": stage_result,
|
||||
}
|
||||
)
|
||||
self.finished.emit(results)
|
||||
except Exception as exc:
|
||||
self.error.emit(str(exc))
|
||||
last_stage_results = stage_result
|
||||
completed_epochs += stage_epochs
|
||||
|
||||
final_payload: Dict[str, Any]
|
||||
if last_stage_results:
|
||||
final_payload = dict(last_stage_results)
|
||||
else:
|
||||
final_payload = {
|
||||
"success": False,
|
||||
"message": "Training stopped before any stage completed.",
|
||||
}
|
||||
|
||||
final_payload["stage_results"] = stage_history
|
||||
final_payload["total_epochs_completed"] = completed_epochs
|
||||
final_payload["total_epochs_planned"] = self.total_epochs
|
||||
final_payload["stages_completed"] = len(stage_history)
|
||||
|
||||
self.finished.emit(final_payload)
|
||||
|
||||
|
||||
class TrainingTab(QWidget):
|
||||
@@ -146,12 +225,23 @@ class TrainingTab(QWidget):
|
||||
|
||||
def _setup_ui(self):
|
||||
"""Setup user interface."""
|
||||
layout = QVBoxLayout()
|
||||
# Create a container widget for all content
|
||||
container = QWidget()
|
||||
container_layout = QVBoxLayout(container)
|
||||
|
||||
layout.addWidget(self._create_dataset_group())
|
||||
layout.addWidget(self._create_training_controls_group())
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
container_layout.addWidget(self._create_dataset_group())
|
||||
container_layout.addWidget(self._create_training_controls_group())
|
||||
container_layout.addStretch()
|
||||
|
||||
# Create scroll area and set the container as its widget
|
||||
scroll_area = QScrollArea()
|
||||
scroll_area.setWidget(container)
|
||||
scroll_area.setWidgetResizable(True)
|
||||
|
||||
# Set main layout with scroll area
|
||||
main_layout = QVBoxLayout(self)
|
||||
main_layout.setContentsMargins(0, 0, 0, 0)
|
||||
main_layout.addWidget(scroll_area)
|
||||
|
||||
self._discover_datasets()
|
||||
self._load_saved_dataset()
|
||||
@@ -249,13 +339,26 @@ class TrainingTab(QWidget):
|
||||
default_base_model = self.config_manager.get(
|
||||
"models.default_base_model", "yolov8s-seg.pt"
|
||||
)
|
||||
base_model_choices = self.config_manager.get("models.base_model_choices", [])
|
||||
|
||||
self.base_model_combo = QComboBox()
|
||||
self.base_model_combo.addItem("Custom path…", "")
|
||||
for choice in base_model_choices:
|
||||
self.base_model_combo.addItem(choice, choice)
|
||||
self.base_model_combo.currentIndexChanged.connect(
|
||||
self._on_base_model_preset_changed
|
||||
)
|
||||
form_layout.addRow("Base Model Preset:", self.base_model_combo)
|
||||
|
||||
base_model_layout = QHBoxLayout()
|
||||
self.base_model_edit = QLineEdit(default_base_model)
|
||||
self.base_model_edit.editingFinished.connect(self._on_base_model_path_edited)
|
||||
base_model_layout.addWidget(self.base_model_edit)
|
||||
self.base_model_browse_button = QPushButton("Browse…")
|
||||
self.base_model_browse_button.clicked.connect(self._browse_base_model)
|
||||
base_model_layout.addWidget(self.base_model_browse_button)
|
||||
form_layout.addRow("Base Model (.pt):", base_model_layout)
|
||||
self._sync_base_model_preset_selection(default_base_model)
|
||||
|
||||
models_dir = self.config_manager.get("models.models_directory", "data/models")
|
||||
save_dir_layout = QHBoxLayout()
|
||||
@@ -298,6 +401,9 @@ class TrainingTab(QWidget):
|
||||
|
||||
group_layout.addLayout(form_layout)
|
||||
|
||||
self.two_stage_group = self._create_two_stage_group(training_defaults)
|
||||
group_layout.addWidget(self.two_stage_group)
|
||||
|
||||
button_layout = QHBoxLayout()
|
||||
self.start_training_button = QPushButton("Start Training")
|
||||
self.start_training_button.clicked.connect(self._start_training)
|
||||
@@ -322,6 +428,134 @@ class TrainingTab(QWidget):
|
||||
group.setLayout(group_layout)
|
||||
return group
|
||||
|
||||
def _create_two_stage_group(self, training_defaults: Dict[str, Any]) -> QGroupBox:
|
||||
group = QGroupBox("Two-Stage Fine-Tuning")
|
||||
group_layout = QVBoxLayout()
|
||||
|
||||
self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
|
||||
two_stage_defaults = (
|
||||
training_defaults.get("two_stage", {}) if training_defaults else {}
|
||||
)
|
||||
self.two_stage_checkbox.setChecked(
|
||||
bool(two_stage_defaults.get("enabled", False))
|
||||
)
|
||||
self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
|
||||
group_layout.addWidget(self.two_stage_checkbox)
|
||||
|
||||
self.two_stage_controls_widget = QWidget()
|
||||
controls_layout = QVBoxLayout()
|
||||
controls_layout.setContentsMargins(0, 0, 0, 0)
|
||||
controls_layout.setSpacing(8)
|
||||
|
||||
stage1_group = QGroupBox("Stage 1 — Head-only stabilization")
|
||||
stage1_form = QFormLayout()
|
||||
stage1_defaults = two_stage_defaults.get("stage1", {})
|
||||
|
||||
self.stage1_epochs_spin = QSpinBox()
|
||||
self.stage1_epochs_spin.setRange(1, 500)
|
||||
self.stage1_epochs_spin.setValue(int(stage1_defaults.get("epochs", 20)))
|
||||
stage1_form.addRow("Epochs:", self.stage1_epochs_spin)
|
||||
|
||||
self.stage1_lr_spin = QDoubleSpinBox()
|
||||
self.stage1_lr_spin.setDecimals(5)
|
||||
self.stage1_lr_spin.setRange(0.00001, 0.1)
|
||||
self.stage1_lr_spin.setSingleStep(0.0005)
|
||||
self.stage1_lr_spin.setValue(float(stage1_defaults.get("lr0", 0.0005)))
|
||||
stage1_form.addRow("Learning Rate:", self.stage1_lr_spin)
|
||||
|
||||
self.stage1_patience_spin = QSpinBox()
|
||||
self.stage1_patience_spin.setRange(1, 200)
|
||||
self.stage1_patience_spin.setValue(int(stage1_defaults.get("patience", 10)))
|
||||
stage1_form.addRow("Patience:", self.stage1_patience_spin)
|
||||
|
||||
self.stage1_freeze_spin = QSpinBox()
|
||||
self.stage1_freeze_spin.setRange(0, 24)
|
||||
self.stage1_freeze_spin.setValue(int(stage1_defaults.get("freeze", 10)))
|
||||
stage1_form.addRow("Freeze layers:", self.stage1_freeze_spin)
|
||||
|
||||
stage1_group.setLayout(stage1_form)
|
||||
controls_layout.addWidget(stage1_group)
|
||||
|
||||
stage2_group = QGroupBox("Stage 2 — Full fine-tuning")
|
||||
stage2_form = QFormLayout()
|
||||
stage2_defaults = two_stage_defaults.get("stage2", {})
|
||||
|
||||
self.stage2_epochs_spin = QSpinBox()
|
||||
self.stage2_epochs_spin.setRange(1, 2000)
|
||||
self.stage2_epochs_spin.setValue(int(stage2_defaults.get("epochs", 150)))
|
||||
stage2_form.addRow("Epochs:", self.stage2_epochs_spin)
|
||||
|
||||
self.stage2_lr_spin = QDoubleSpinBox()
|
||||
self.stage2_lr_spin.setDecimals(5)
|
||||
self.stage2_lr_spin.setRange(0.00001, 0.1)
|
||||
self.stage2_lr_spin.setSingleStep(0.0005)
|
||||
self.stage2_lr_spin.setValue(float(stage2_defaults.get("lr0", 0.0003)))
|
||||
stage2_form.addRow("Learning Rate:", self.stage2_lr_spin)
|
||||
|
||||
self.stage2_patience_spin = QSpinBox()
|
||||
self.stage2_patience_spin.setRange(1, 200)
|
||||
self.stage2_patience_spin.setValue(int(stage2_defaults.get("patience", 30)))
|
||||
stage2_form.addRow("Patience:", self.stage2_patience_spin)
|
||||
|
||||
stage2_group.setLayout(stage2_form)
|
||||
controls_layout.addWidget(stage2_group)
|
||||
|
||||
helper_label = QLabel(
|
||||
"When enabled, staged hyperparameters override the global epochs/patience/lr."
|
||||
)
|
||||
helper_label.setWordWrap(True)
|
||||
controls_layout.addWidget(helper_label)
|
||||
|
||||
self.two_stage_controls_widget.setLayout(controls_layout)
|
||||
group_layout.addWidget(self.two_stage_controls_widget)
|
||||
|
||||
group.setLayout(group_layout)
|
||||
self._on_two_stage_toggled(self.two_stage_checkbox.isChecked())
|
||||
return group
|
||||
|
||||
def _on_two_stage_toggled(self, checked: bool):
|
||||
self._refresh_two_stage_controls_enabled(checked)
|
||||
|
||||
def _refresh_two_stage_controls_enabled(self, checked: Optional[bool] = None):
|
||||
if not hasattr(self, "two_stage_controls_widget"):
|
||||
return
|
||||
desired_state = checked
|
||||
if desired_state is None:
|
||||
desired_state = self.two_stage_checkbox.isChecked()
|
||||
can_edit = self.two_stage_checkbox.isEnabled()
|
||||
self.two_stage_controls_widget.setEnabled(bool(desired_state and can_edit))
|
||||
|
||||
def _on_base_model_preset_changed(self, index: int):
|
||||
preset_value = self.base_model_combo.itemData(index)
|
||||
if preset_value:
|
||||
self.base_model_edit.setText(str(preset_value))
|
||||
elif index == 0:
|
||||
self.base_model_edit.setFocus()
|
||||
|
||||
def _on_base_model_path_edited(self):
|
||||
self._sync_base_model_preset_selection(self.base_model_edit.text().strip())
|
||||
|
||||
def _sync_base_model_preset_selection(self, model_path: str):
|
||||
if not hasattr(self, "base_model_combo"):
|
||||
return
|
||||
normalized = (model_path or "").strip()
|
||||
target_index = 0
|
||||
for idx in range(1, self.base_model_combo.count()):
|
||||
preset_value = self.base_model_combo.itemData(idx)
|
||||
if not preset_value:
|
||||
continue
|
||||
if normalized == preset_value:
|
||||
target_index = idx
|
||||
break
|
||||
if normalized.endswith(f"/{preset_value}") or normalized.endswith(
|
||||
f"\\{preset_value}"
|
||||
):
|
||||
target_index = idx
|
||||
break
|
||||
self.base_model_combo.blockSignals(True)
|
||||
self.base_model_combo.setCurrentIndex(target_index)
|
||||
self.base_model_combo.blockSignals(False)
|
||||
|
||||
def _get_dataset_search_roots(self) -> List[Path]:
|
||||
roots: List[Path] = []
|
||||
default_root = Path("data/datasets").expanduser()
|
||||
@@ -346,6 +580,7 @@ class TrainingTab(QWidget):
|
||||
for yaml_path in root.rglob("*.yaml"):
|
||||
if yaml_path.name.lower() not in {"data.yaml", "dataset.yaml"}:
|
||||
continue
|
||||
|
||||
discovered.append(yaml_path.resolve())
|
||||
except Exception as exc:
|
||||
logger.warning(f"Unable to scan {root}: {exc}")
|
||||
@@ -711,9 +946,6 @@ class TrainingTab(QWidget):
|
||||
for msg in split_messages:
|
||||
self._append_training_log(msg)
|
||||
|
||||
if dataset_yaml:
|
||||
self._clear_rgb_cache_for_dataset(dataset_yaml)
|
||||
|
||||
def _export_labels_for_split(
|
||||
self,
|
||||
split_name: str,
|
||||
@@ -928,136 +1160,89 @@ class TrainingTab(QWidget):
|
||||
return 1.0
|
||||
return value
|
||||
|
||||
def _prepare_dataset_for_training(
|
||||
self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
|
||||
) -> Path:
|
||||
dataset_info = dataset_info or (
|
||||
self.selected_dataset
|
||||
if self.selected_dataset
|
||||
and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
|
||||
else self._parse_dataset_yaml(dataset_yaml)
|
||||
)
|
||||
|
||||
train_split = dataset_info.get("splits", {}).get("train") or {}
|
||||
images_path_str = train_split.get("path")
|
||||
if not images_path_str:
|
||||
return dataset_yaml
|
||||
|
||||
images_path = Path(images_path_str)
|
||||
if not images_path.exists():
|
||||
return dataset_yaml
|
||||
|
||||
if not self._dataset_requires_rgb_conversion(images_path):
|
||||
return dataset_yaml
|
||||
|
||||
cache_root = self._get_rgb_cache_root(dataset_yaml)
|
||||
rgb_yaml = cache_root / "data.yaml"
|
||||
if rgb_yaml.exists():
|
||||
self._append_training_log(
|
||||
f"Detected grayscale dataset; reusing RGB cache at {cache_root}"
|
||||
)
|
||||
return rgb_yaml
|
||||
|
||||
self._append_training_log(
|
||||
f"Detected grayscale dataset; creating RGB cache at {cache_root}"
|
||||
)
|
||||
self._build_rgb_dataset(cache_root, dataset_info)
|
||||
return rgb_yaml
|
||||
|
||||
def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
|
||||
cache_base = Path("data/datasets/_rgb_cache")
|
||||
cache_base.mkdir(parents=True, exist_ok=True)
|
||||
key = hashlib.md5(str(dataset_yaml.parent.resolve()).encode()).hexdigest()[:8]
|
||||
return cache_base / f"{dataset_yaml.parent.name}_{key}"
|
||||
|
||||
def _clear_rgb_cache_for_dataset(self, dataset_yaml: Path):
|
||||
cache_root = self._get_rgb_cache_root(dataset_yaml)
|
||||
if cache_root.exists():
|
||||
try:
|
||||
shutil.rmtree(cache_root)
|
||||
logger.debug(f"Removed RGB cache at {cache_root}")
|
||||
except OSError as exc:
|
||||
logger.warning(f"Failed to remove RGB cache {cache_root}: {exc}")
|
||||
|
||||
def _dataset_requires_rgb_conversion(self, images_dir: Path) -> bool:
|
||||
sample_image = self._find_first_image(images_dir)
|
||||
if not sample_image:
|
||||
return False
|
||||
try:
|
||||
with PILImage.open(sample_image) as img:
|
||||
return img.mode.upper() != "RGB"
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to inspect image {sample_image}: {exc}")
|
||||
return False
|
||||
|
||||
def _find_first_image(self, directory: Path) -> Optional[Path]:
|
||||
if not directory.exists():
|
||||
return None
|
||||
for path in directory.rglob("*"):
|
||||
if path.is_file() and path.suffix.lower() in self.allowed_extensions:
|
||||
return path
|
||||
return None
|
||||
|
||||
def _build_rgb_dataset(self, cache_root: Path, dataset_info: Dict[str, Any]):
|
||||
if cache_root.exists():
|
||||
shutil.rmtree(cache_root)
|
||||
cache_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
splits = dataset_info.get("splits", {})
|
||||
for split_name in ("train", "val", "test"):
|
||||
split_entry = splits.get(split_name)
|
||||
if not split_entry:
|
||||
continue
|
||||
images_src = Path(split_entry.get("path", ""))
|
||||
if not images_src.exists():
|
||||
continue
|
||||
images_dst = cache_root / split_name / "images"
|
||||
self._convert_images_to_rgb(images_src, images_dst)
|
||||
|
||||
labels_src = self._infer_labels_dir(images_src)
|
||||
if labels_src.exists():
|
||||
labels_dst = cache_root / split_name / "labels"
|
||||
self._copy_labels(labels_src, labels_dst)
|
||||
|
||||
class_names = dataset_info.get("class_names") or []
|
||||
names_map = {idx: name for idx, name in enumerate(class_names)}
|
||||
num_classes = dataset_info.get("num_classes") or len(class_names)
|
||||
|
||||
yaml_payload: Dict[str, Any] = {
|
||||
"path": cache_root.as_posix(),
|
||||
"names": names_map,
|
||||
"nc": num_classes,
|
||||
def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
two_stage = params.get("two_stage") or {}
|
||||
base_stage = {
|
||||
"label": "Single Stage",
|
||||
"model_path": params["base_model"],
|
||||
"use_previous_best": False,
|
||||
"params": {
|
||||
"epochs": params["epochs"],
|
||||
"batch": params["batch"],
|
||||
"imgsz": params["imgsz"],
|
||||
"patience": params["patience"],
|
||||
"lr0": params["lr0"],
|
||||
"freeze": 0,
|
||||
"name": params["run_name"],
|
||||
},
|
||||
}
|
||||
|
||||
for split_name in ("train", "val", "test"):
|
||||
images_dir = cache_root / split_name / "images"
|
||||
if images_dir.exists():
|
||||
yaml_payload[split_name] = f"{split_name}/images"
|
||||
if not two_stage.get("enabled"):
|
||||
return [base_stage]
|
||||
|
||||
with open(cache_root / "data.yaml", "w", encoding="utf-8") as handle:
|
||||
yaml.safe_dump(yaml_payload, handle, sort_keys=False)
|
||||
stage_plan: List[Dict[str, Any]] = []
|
||||
stage1 = two_stage.get("stage1", {})
|
||||
stage2 = two_stage.get("stage2", {})
|
||||
|
||||
def _convert_images_to_rgb(self, src_dir: Path, dst_dir: Path):
|
||||
for src in src_dir.rglob("*"):
|
||||
if not src.is_file() or src.suffix.lower() not in self.allowed_extensions:
|
||||
continue
|
||||
relative = src.relative_to(src_dir)
|
||||
dst = dst_dir / relative
|
||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||
stage_plan.append(
|
||||
{
|
||||
"label": "Stage 1 — Head-only",
|
||||
"model_path": params["base_model"],
|
||||
"use_previous_best": False,
|
||||
"params": {
|
||||
"epochs": stage1.get("epochs", params["epochs"]),
|
||||
"batch": params["batch"],
|
||||
"imgsz": params["imgsz"],
|
||||
"patience": stage1.get("patience", params["patience"]),
|
||||
"lr0": stage1.get("lr0", params["lr0"]),
|
||||
"freeze": stage1.get("freeze", 0),
|
||||
"name": f"{params['run_name']}_head_ft",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
stage_plan.append(
|
||||
{
|
||||
"label": "Stage 2 — Full",
|
||||
"model_path": params["base_model"],
|
||||
"use_previous_best": True,
|
||||
"params": {
|
||||
"epochs": stage2.get("epochs", params["epochs"]),
|
||||
"batch": params["batch"],
|
||||
"imgsz": params["imgsz"],
|
||||
"patience": stage2.get("patience", params["patience"]),
|
||||
"lr0": stage2.get("lr0", params["lr0"]),
|
||||
"freeze": stage2.get("freeze", 0),
|
||||
"name": f"{params['run_name']}_full_ft",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return stage_plan
|
||||
|
||||
def _calculate_total_stage_epochs(self, stage_plan: List[Dict[str, Any]]) -> int:
|
||||
total = 0
|
||||
for stage in stage_plan:
|
||||
params = stage.get("params") or {}
|
||||
try:
|
||||
with PILImage.open(src) as img:
|
||||
rgb_img = img.convert("RGB")
|
||||
rgb_img.save(dst)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to convert {src} to RGB: {exc}")
|
||||
stage_epochs = int(params.get("epochs", 0))
|
||||
except (TypeError, ValueError):
|
||||
stage_epochs = 0
|
||||
if stage_epochs > 0:
|
||||
total += stage_epochs
|
||||
return total
|
||||
|
||||
def _copy_labels(self, labels_src: Path, labels_dst: Path):
|
||||
label_files = list(labels_src.rglob("*.txt"))
|
||||
for label_file in label_files:
|
||||
relative = label_file.relative_to(labels_src)
|
||||
dst = labels_dst / relative
|
||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(label_file, dst)
|
||||
def _log_stage_plan(self, stage_plan: List[Dict[str, Any]]):
|
||||
for index, stage in enumerate(stage_plan, start=1):
|
||||
stage_label = stage.get("label") or f"Stage {index}"
|
||||
params = stage.get("params") or {}
|
||||
epochs = params.get("epochs", "?")
|
||||
lr0 = params.get("lr0", "?")
|
||||
patience = params.get("patience", "?")
|
||||
freeze = params.get("freeze", 0)
|
||||
self._append_training_log(
|
||||
f" • {stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
|
||||
)
|
||||
|
||||
def _infer_labels_dir(self, images_dir: Path) -> Path:
|
||||
return images_dir.parent / "labels"
|
||||
@@ -1085,6 +1270,21 @@ class TrainingTab(QWidget):
|
||||
save_dir_path.mkdir(parents=True, exist_ok=True)
|
||||
run_name = f"{model_name}_{model_version}".replace(" ", "_")
|
||||
|
||||
two_stage_config = {
|
||||
"enabled": self.two_stage_checkbox.isChecked(),
|
||||
"stage1": {
|
||||
"epochs": self.stage1_epochs_spin.value(),
|
||||
"lr0": self.stage1_lr_spin.value(),
|
||||
"patience": self.stage1_patience_spin.value(),
|
||||
"freeze": self.stage1_freeze_spin.value(),
|
||||
},
|
||||
"stage2": {
|
||||
"epochs": self.stage2_epochs_spin.value(),
|
||||
"lr0": self.stage2_lr_spin.value(),
|
||||
"patience": self.stage2_patience_spin.value(),
|
||||
},
|
||||
}
|
||||
|
||||
return {
|
||||
"model_name": model_name,
|
||||
"model_version": model_version,
|
||||
@@ -1096,6 +1296,7 @@ class TrainingTab(QWidget):
|
||||
"imgsz": self.imgsz_spin.value(),
|
||||
"patience": self.patience_spin.value(),
|
||||
"lr0": self.lr_spin.value(),
|
||||
"two_stage": two_stage_config,
|
||||
}
|
||||
|
||||
def _start_training(self):
|
||||
@@ -1130,27 +1331,35 @@ class TrainingTab(QWidget):
|
||||
self.training_log.clear()
|
||||
self._export_labels_from_database(dataset_info)
|
||||
|
||||
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
|
||||
if dataset_to_use != dataset_path:
|
||||
self._append_training_log(
|
||||
f"Using RGB-converted dataset at {dataset_to_use.parent}"
|
||||
)
|
||||
self._append_training_log(
|
||||
"Using Float32 on-the-fly loader for 16-bit TIFF support (no disk caching)"
|
||||
)
|
||||
|
||||
params = self._collect_training_params()
|
||||
stage_plan = self._compose_stage_plan(params)
|
||||
params["stage_plan"] = stage_plan
|
||||
total_planned_epochs = (
|
||||
self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
|
||||
)
|
||||
params["total_planned_epochs"] = total_planned_epochs
|
||||
self._active_training_params = params
|
||||
self._training_cancelled = False
|
||||
|
||||
if len(stage_plan) > 1:
|
||||
self._append_training_log("Two-stage fine-tuning schedule:")
|
||||
self._log_stage_plan(stage_plan)
|
||||
|
||||
self._append_training_log(
|
||||
f"Starting training run '{params['run_name']}' using {params['base_model']}"
|
||||
)
|
||||
|
||||
self.training_progress_bar.setVisible(True)
|
||||
self.training_progress_bar.setMaximum(params["epochs"])
|
||||
self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
|
||||
self.training_progress_bar.setValue(0)
|
||||
self._set_training_state(True)
|
||||
|
||||
self.training_worker = TrainingWorker(
|
||||
data_yaml=dataset_to_use.as_posix(),
|
||||
data_yaml=dataset_path.as_posix(),
|
||||
base_model=params["base_model"],
|
||||
epochs=params["epochs"],
|
||||
batch=params["batch"],
|
||||
@@ -1159,6 +1368,8 @@ class TrainingTab(QWidget):
|
||||
lr0=params["lr0"],
|
||||
save_dir=params["save_dir"],
|
||||
run_name=params["run_name"],
|
||||
stage_plan=stage_plan,
|
||||
total_epochs=total_planned_epochs,
|
||||
)
|
||||
self.training_worker.progress.connect(self._on_training_progress)
|
||||
self.training_worker.finished.connect(self._on_training_finished)
|
||||
@@ -1283,14 +1494,22 @@ class TrainingTab(QWidget):
|
||||
if not model_path:
|
||||
raise ValueError("Training results did not include a model path.")
|
||||
|
||||
effective_epochs = params.get("total_planned_epochs", params["epochs"])
|
||||
training_params = {
|
||||
"epochs": params["epochs"],
|
||||
"epochs": effective_epochs,
|
||||
"batch": params["batch"],
|
||||
"imgsz": params["imgsz"],
|
||||
"patience": params["patience"],
|
||||
"lr0": params["lr0"],
|
||||
"run_name": params["run_name"],
|
||||
"two_stage": params.get("two_stage"),
|
||||
}
|
||||
if params.get("stage_plan"):
|
||||
training_params["stage_plan"] = params["stage_plan"]
|
||||
if results.get("stage_results"):
|
||||
training_params["stage_results"] = results["stage_results"]
|
||||
if results.get("total_epochs_completed") is not None:
|
||||
training_params["epochs_completed"] = results["total_epochs_completed"]
|
||||
|
||||
model_id = self.db_manager.add_model(
|
||||
model_name=params["model_name"],
|
||||
@@ -1315,6 +1534,7 @@ class TrainingTab(QWidget):
|
||||
self.rescan_button.setEnabled(not is_training)
|
||||
self.model_name_edit.setEnabled(not is_training)
|
||||
self.model_version_edit.setEnabled(not is_training)
|
||||
self.base_model_combo.setEnabled(not is_training)
|
||||
self.base_model_edit.setEnabled(not is_training)
|
||||
self.base_model_browse_button.setEnabled(not is_training)
|
||||
self.save_dir_edit.setEnabled(not is_training)
|
||||
@@ -1324,6 +1544,8 @@ class TrainingTab(QWidget):
|
||||
self.imgsz_spin.setEnabled(not is_training)
|
||||
self.patience_spin.setEnabled(not is_training)
|
||||
self.lr_spin.setEnabled(not is_training)
|
||||
self.two_stage_checkbox.setEnabled(not is_training)
|
||||
self._refresh_two_stage_controls_enabled()
|
||||
|
||||
def _append_training_log(self, message: str):
|
||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||
@@ -1339,6 +1561,7 @@ class TrainingTab(QWidget):
|
||||
)
|
||||
if file_path:
|
||||
self.base_model_edit.setText(file_path)
|
||||
self._sync_base_model_preset_selection(file_path)
|
||||
|
||||
def _browse_save_dir(self):
|
||||
start_path = self.save_dir_edit.text().strip() or "data/models"
|
||||
|
||||
@@ -16,8 +16,9 @@ from PySide6.QtGui import (
|
||||
QKeyEvent,
|
||||
QMouseEvent,
|
||||
QPaintEvent,
|
||||
QPolygonF,
|
||||
)
|
||||
from PySide6.QtCore import Qt, QEvent, Signal, QPoint
|
||||
from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.utils.image import Image, ImageLoadError
|
||||
@@ -246,10 +247,10 @@ class AnnotationCanvasWidget(QWidget):
|
||||
return
|
||||
|
||||
try:
|
||||
# Get RGB image data
|
||||
if self.current_image.channels == 3:
|
||||
# Get image data in a format compatible with Qt
|
||||
if self.current_image.channels in (3, 4):
|
||||
image_data = self.current_image.get_rgb()
|
||||
height, width, channels = image_data.shape
|
||||
height, width = image_data.shape[:2]
|
||||
else:
|
||||
image_data = self.current_image.get_grayscale()
|
||||
height, width = image_data.shape
|
||||
@@ -263,7 +264,7 @@ class AnnotationCanvasWidget(QWidget):
|
||||
height,
|
||||
bytes_per_line,
|
||||
self.current_image.qtimage_format,
|
||||
)
|
||||
).copy() # Copy so Qt owns the buffer even after numpy array goes out of scope
|
||||
|
||||
self.original_pixmap = QPixmap.fromImage(qimage)
|
||||
|
||||
@@ -496,8 +497,10 @@ class AnnotationCanvasWidget(QWidget):
|
||||
)
|
||||
|
||||
painter.setPen(pen)
|
||||
for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]):
|
||||
painter.drawLine(int(x1), int(y1), int(x2), int(y2))
|
||||
# Use QPolygonF for efficient polygon rendering (single call vs N-1 calls)
|
||||
# drawPolygon() automatically closes the shape, ensuring proper visual closure
|
||||
polygon = QPolygonF([QPointF(x, y) for x, y in polyline])
|
||||
painter.drawPolygon(polygon)
|
||||
|
||||
# Draw bounding boxes (dashed) if enabled
|
||||
if self.show_bboxes and self.original_pixmap is not None and self.bboxes:
|
||||
@@ -529,6 +532,40 @@ class AnnotationCanvasWidget(QWidget):
|
||||
painter.setPen(pen)
|
||||
painter.drawRect(x_min, y_min, rect_width, rect_height)
|
||||
|
||||
label_text = meta.get("label")
|
||||
if label_text:
|
||||
painter.save()
|
||||
font = painter.font()
|
||||
font.setPointSizeF(max(10.0, width + 4))
|
||||
painter.setFont(font)
|
||||
metrics = painter.fontMetrics()
|
||||
text_width = metrics.horizontalAdvance(label_text)
|
||||
text_height = metrics.height()
|
||||
padding = 4
|
||||
bg_width = text_width + padding * 2
|
||||
bg_height = text_height + padding * 2
|
||||
canvas_width = self.original_pixmap.width()
|
||||
canvas_height = self.original_pixmap.height()
|
||||
bg_x = max(0, min(x_min, canvas_width - bg_width))
|
||||
bg_y = y_min - bg_height
|
||||
if bg_y < 0:
|
||||
bg_y = min(y_min, canvas_height - bg_height)
|
||||
bg_y = max(0, bg_y)
|
||||
background_rect = QRect(bg_x, bg_y, bg_width, bg_height)
|
||||
background_color = QColor(pen_color)
|
||||
background_color.setAlpha(220)
|
||||
painter.fillRect(background_rect, background_color)
|
||||
text_color = QColor(0, 0, 0)
|
||||
if background_color.lightness() < 128:
|
||||
text_color = QColor(255, 255, 255)
|
||||
painter.setPen(text_color)
|
||||
painter.drawText(
|
||||
background_rect.adjusted(padding, padding, -padding, -padding),
|
||||
Qt.AlignLeft | Qt.AlignVCenter,
|
||||
label_text,
|
||||
)
|
||||
painter.restore()
|
||||
|
||||
painter.end()
|
||||
|
||||
self._update_display()
|
||||
@@ -787,7 +824,13 @@ class AnnotationCanvasWidget(QWidget):
|
||||
f"Drew saved polyline with {len(polyline)} points in color {color}"
|
||||
)
|
||||
|
||||
def draw_saved_bbox(self, bbox: List[float], color: str, width: int = 3):
|
||||
def draw_saved_bbox(
|
||||
self,
|
||||
bbox: List[float],
|
||||
color: str,
|
||||
width: int = 3,
|
||||
label: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Draw a bounding box from database coordinates onto the annotation canvas.
|
||||
|
||||
@@ -796,6 +839,7 @@ class AnnotationCanvasWidget(QWidget):
|
||||
in normalized coordinates (0-1)
|
||||
color: Color hex string (e.g., '#FF0000')
|
||||
width: Line width in pixels
|
||||
label: Optional text label to render near the bounding box
|
||||
"""
|
||||
if not self.annotation_pixmap or not self.original_pixmap:
|
||||
logger.warning("Cannot draw bounding box: no image loaded")
|
||||
@@ -828,11 +872,11 @@ class AnnotationCanvasWidget(QWidget):
|
||||
self.bboxes.append(
|
||||
[float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)]
|
||||
)
|
||||
self.bbox_meta.append({"color": pen_color, "width": int(width)})
|
||||
self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label})
|
||||
|
||||
# Store in all_strokes for consistency
|
||||
self.all_strokes.append(
|
||||
{"bbox": bbox, "color": color, "alpha": 128, "width": width}
|
||||
{"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label}
|
||||
)
|
||||
|
||||
# Redraw overlay (polylines + all bounding boxes)
|
||||
|
||||
@@ -137,7 +137,7 @@ class ImageDisplayWidget(QWidget):
|
||||
height,
|
||||
bytes_per_line,
|
||||
self.current_image.qtimage_format,
|
||||
)
|
||||
).copy() # Copy to ensure Qt owns its memory after this scope
|
||||
|
||||
# Convert to pixmap
|
||||
pixmap = QPixmap.fromImage(qimage)
|
||||
|
||||
@@ -5,12 +5,12 @@ Handles detection inference and result storage.
|
||||
|
||||
from typing import List, Dict, Optional, Callable
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from src.model.yolo_wrapper import YOLOWrapper
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.image import Image
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.file_utils import get_relative_path
|
||||
|
||||
@@ -42,6 +42,7 @@ class InferenceEngine:
|
||||
relative_path: str,
|
||||
conf: float = 0.25,
|
||||
save_to_db: bool = True,
|
||||
repository_root: Optional[str] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
Detect objects in a single image.
|
||||
@@ -51,49 +52,79 @@ class InferenceEngine:
|
||||
relative_path: Relative path from repository root
|
||||
conf: Confidence threshold
|
||||
save_to_db: Whether to save results to database
|
||||
repository_root: Base directory used to compute relative_path (if known)
|
||||
|
||||
Returns:
|
||||
Dictionary with detection results
|
||||
"""
|
||||
try:
|
||||
# Normalize storage path (fall back to absolute path when repo root is unknown)
|
||||
stored_relative_path = relative_path
|
||||
if not repository_root:
|
||||
stored_relative_path = str(Path(image_path).resolve())
|
||||
|
||||
# Get image dimensions
|
||||
img = Image.open(image_path)
|
||||
width, height = img.size
|
||||
img.close()
|
||||
img = Image(image_path)
|
||||
width = img.width
|
||||
height = img.height
|
||||
|
||||
# Perform detection
|
||||
detections = self.yolo.predict(image_path, conf=conf)
|
||||
|
||||
# Add/get image in database
|
||||
image_id = self.db_manager.get_or_create_image(
|
||||
relative_path=relative_path,
|
||||
relative_path=stored_relative_path,
|
||||
filename=Path(image_path).name,
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
|
||||
# Save detections to database
|
||||
if save_to_db and detections:
|
||||
detection_records = []
|
||||
for det in detections:
|
||||
# Use normalized bbox from detection
|
||||
bbox_normalized = det[
|
||||
"bbox_normalized"
|
||||
] # [x_min, y_min, x_max, y_max]
|
||||
inserted_count = 0
|
||||
deleted_count = 0
|
||||
|
||||
record = {
|
||||
"image_id": image_id,
|
||||
"model_id": self.model_id,
|
||||
"class_name": det["class_name"],
|
||||
"bbox": tuple(bbox_normalized),
|
||||
"confidence": det["confidence"],
|
||||
"segmentation_mask": det.get("segmentation_mask"),
|
||||
"metadata": {"class_id": det["class_id"]},
|
||||
}
|
||||
detection_records.append(record)
|
||||
# Save detections to database, replacing any previous results for this image/model
|
||||
if save_to_db:
|
||||
deleted_count = self.db_manager.delete_detections_for_image(
|
||||
image_id, self.model_id
|
||||
)
|
||||
if detections:
|
||||
detection_records = []
|
||||
for det in detections:
|
||||
# Use normalized bbox from detection
|
||||
bbox_normalized = det[
|
||||
"bbox_normalized"
|
||||
] # [x_min, y_min, x_max, y_max]
|
||||
|
||||
self.db_manager.add_detections_batch(detection_records)
|
||||
logger.info(f"Saved {len(detection_records)} detections to database")
|
||||
metadata = {
|
||||
"class_id": det["class_id"],
|
||||
"source_path": str(Path(image_path).resolve()),
|
||||
}
|
||||
if repository_root:
|
||||
metadata["repository_root"] = str(
|
||||
Path(repository_root).resolve()
|
||||
)
|
||||
|
||||
record = {
|
||||
"image_id": image_id,
|
||||
"model_id": self.model_id,
|
||||
"class_name": det["class_name"],
|
||||
"bbox": tuple(bbox_normalized),
|
||||
"confidence": det["confidence"],
|
||||
"segmentation_mask": det.get("segmentation_mask"),
|
||||
"metadata": metadata,
|
||||
}
|
||||
detection_records.append(record)
|
||||
|
||||
inserted_count = self.db_manager.add_detections_batch(
|
||||
detection_records
|
||||
)
|
||||
logger.info(
|
||||
f"Saved {inserted_count} detections to database (replaced {deleted_count})"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Detection run removed {deleted_count} stale entries but produced no new detections"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
@@ -142,7 +173,12 @@ class InferenceEngine:
|
||||
rel_path = get_relative_path(image_path, repository_root)
|
||||
|
||||
# Perform detection
|
||||
result = self.detect_single(image_path, rel_path, conf)
|
||||
result = self.detect_single(
|
||||
image_path,
|
||||
rel_path,
|
||||
conf=conf,
|
||||
repository_root=repository_root,
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
# Update progress
|
||||
|
||||
@@ -7,7 +7,12 @@ from ultralytics import YOLO
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Callable, Any
|
||||
import torch
|
||||
import tempfile
|
||||
import os
|
||||
import numpy as np
|
||||
from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.train_ultralytics_float import train_with_float32_loader
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -56,10 +61,11 @@ class YOLOWrapper:
|
||||
name: str = "custom_model",
|
||||
resume: bool = False,
|
||||
callbacks: Optional[Dict[str, Callable]] = None,
|
||||
use_float32_loader: bool = True,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Train the YOLO model.
|
||||
Train the YOLO model with optional float32 loader for 16-bit TIFFs.
|
||||
|
||||
Args:
|
||||
data_yaml: Path to data.yaml configuration file
|
||||
@@ -71,40 +77,62 @@ class YOLOWrapper:
|
||||
name: Name for the training run
|
||||
resume: Resume training from last checkpoint
|
||||
callbacks: Optional Ultralytics callback dictionary
|
||||
use_float32_loader: Use custom Float32Dataset for 16-bit TIFFs (default: True)
|
||||
**kwargs: Additional training arguments
|
||||
|
||||
Returns:
|
||||
Dictionary with training results
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
|
||||
try:
|
||||
if 1:
|
||||
logger.info(f"Starting training: {name}")
|
||||
logger.info(
|
||||
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
|
||||
)
|
||||
|
||||
# Train the model
|
||||
results = self.model.train(
|
||||
data=data_yaml,
|
||||
epochs=epochs,
|
||||
imgsz=imgsz,
|
||||
batch=batch,
|
||||
patience=patience,
|
||||
project=save_dir,
|
||||
name=name,
|
||||
device=self.device,
|
||||
resume=resume,
|
||||
**kwargs,
|
||||
)
|
||||
# Check if dataset has 16-bit TIFFs and use float32 loader
|
||||
if use_float32_loader:
|
||||
logger.info("Using Float32Dataset loader for 16-bit TIFF support")
|
||||
return train_with_float32_loader(
|
||||
model_path=self.model_path,
|
||||
data_yaml=data_yaml,
|
||||
epochs=epochs,
|
||||
imgsz=imgsz,
|
||||
batch=batch,
|
||||
patience=patience,
|
||||
save_dir=save_dir,
|
||||
name=name,
|
||||
callbacks=callbacks,
|
||||
device=self.device,
|
||||
resume=resume,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
# Standard training (old behavior)
|
||||
if self.model is None:
|
||||
if not self.load_model():
|
||||
raise RuntimeError(
|
||||
f"Failed to load model from {self.model_path}"
|
||||
)
|
||||
|
||||
logger.info("Training completed successfully")
|
||||
return self._format_training_results(results)
|
||||
results = self.model.train(
|
||||
data=data_yaml,
|
||||
epochs=epochs,
|
||||
imgsz=imgsz,
|
||||
batch=batch,
|
||||
patience=patience,
|
||||
project=save_dir,
|
||||
name=name,
|
||||
device=self.device,
|
||||
resume=resume,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during training: {e}")
|
||||
raise
|
||||
logger.info("Training completed successfully")
|
||||
return self._format_training_results(results)
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error during training: {e}")
|
||||
# raise
|
||||
|
||||
def validate(self, data_yaml: str, split: str = "val", **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -119,7 +147,8 @@ class YOLOWrapper:
|
||||
Dictionary with validation metrics
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
if not self.load_model():
|
||||
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
||||
|
||||
try:
|
||||
logger.info(f"Starting validation on {split} split")
|
||||
@@ -160,12 +189,15 @@ class YOLOWrapper:
|
||||
List of detection dictionaries
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
if not self.load_model():
|
||||
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
||||
|
||||
prepared_source, cleanup_path = self._prepare_source(source)
|
||||
|
||||
try:
|
||||
logger.info(f"Running inference on {source}")
|
||||
results = self.model.predict(
|
||||
source=source,
|
||||
source=prepared_source,
|
||||
conf=conf,
|
||||
iou=iou,
|
||||
save=save,
|
||||
@@ -182,6 +214,17 @@ class YOLOWrapper:
|
||||
except Exception as e:
|
||||
logger.error(f"Error during inference: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Clean up temporary files (only for non-16-bit images)
|
||||
# 16-bit TIFFs return numpy arrays directly, so cleanup_path is None
|
||||
if cleanup_path:
|
||||
try:
|
||||
os.remove(cleanup_path)
|
||||
logger.debug(f"Cleaned up temporary file: {cleanup_path}")
|
||||
except OSError as cleanup_error:
|
||||
logger.warning(
|
||||
f"Failed to delete temporary file {cleanup_path}: {cleanup_error}"
|
||||
)
|
||||
|
||||
def export(
|
||||
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
|
||||
@@ -198,7 +241,8 @@ class YOLOWrapper:
|
||||
Path to exported model
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
if not self.load_model():
|
||||
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
||||
|
||||
try:
|
||||
logger.info(f"Exporting model to {format} format")
|
||||
@@ -210,6 +254,84 @@ class YOLOWrapper:
|
||||
logger.error(f"Error exporting model: {e}")
|
||||
raise
|
||||
|
||||
def _prepare_source(self, source):
|
||||
"""Convert single-channel images to RGB for inference.
|
||||
|
||||
For 16-bit TIFF files, this will:
|
||||
1. Load using tifffile
|
||||
2. Normalize to float32 [0-1] (NO uint8 conversion to avoid data loss)
|
||||
3. Replicate grayscale → RGB (3 channels)
|
||||
4. Pass directly as numpy array to YOLO
|
||||
"""
|
||||
cleanup_path = None
|
||||
|
||||
if isinstance(source, (str, Path)):
|
||||
source_path = Path(source)
|
||||
if source_path.is_file():
|
||||
try:
|
||||
img_obj = Image(source_path)
|
||||
|
||||
# Check if it's a 16-bit TIFF file
|
||||
is_16bit_tiff = (
|
||||
source_path.suffix.lower() in [".tif", ".tiff"]
|
||||
and img_obj.dtype == np.uint16
|
||||
)
|
||||
|
||||
if is_16bit_tiff:
|
||||
# Process 16-bit TIFF: normalize to float32 [0-1]
|
||||
# NO uint8 conversion - pass float32 directly to avoid data loss
|
||||
normalized_float = img_obj.to_normalized_float32()
|
||||
|
||||
# Convert grayscale to RGB by replicating channels
|
||||
if len(normalized_float.shape) == 2:
|
||||
# Grayscale: H,W → H,W,3
|
||||
rgb_float = np.stack([normalized_float] * 3, axis=-1)
|
||||
elif (
|
||||
len(normalized_float.shape) == 3
|
||||
and normalized_float.shape[2] == 1
|
||||
):
|
||||
# Grayscale with channel dim: H,W,1 → H,W,3
|
||||
rgb_float = np.repeat(normalized_float, 3, axis=2)
|
||||
else:
|
||||
# Already multi-channel
|
||||
rgb_float = normalized_float
|
||||
|
||||
# Ensure contiguous array and float32
|
||||
rgb_float = np.ascontiguousarray(rgb_float, dtype=np.float32)
|
||||
|
||||
logger.info(
|
||||
f"Loaded 16-bit TIFF {source_path} as float32 [0-1] RGB "
|
||||
f"(shape: {rgb_float.shape}, dtype: {rgb_float.dtype}, "
|
||||
f"range: [{rgb_float.min():.4f}, {rgb_float.max():.4f}])"
|
||||
)
|
||||
|
||||
# Return numpy array directly - YOLO can handle it
|
||||
return rgb_float, cleanup_path
|
||||
else:
|
||||
# Standard processing for other images
|
||||
pil_img = img_obj.pil_image
|
||||
if len(pil_img.getbands()) == 1:
|
||||
rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img)
|
||||
else:
|
||||
rgb_img = pil_img.convert("RGB")
|
||||
|
||||
suffix = source_path.suffix or ".png"
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
rgb_img.save(tmp_path)
|
||||
cleanup_path = tmp_path
|
||||
logger.info(
|
||||
f"Converted image {source_path} to RGB for inference at {tmp_path}"
|
||||
)
|
||||
return tmp_path, cleanup_path
|
||||
except Exception as convert_error:
|
||||
logger.warning(
|
||||
f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"
|
||||
)
|
||||
|
||||
return source, cleanup_path
|
||||
|
||||
def _format_training_results(self, results) -> Dict[str, Any]:
|
||||
"""Format training results into dictionary."""
|
||||
try:
|
||||
|
||||
@@ -7,6 +7,7 @@ import yaml
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.image import Image
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -46,18 +47,15 @@ class ConfigManager:
|
||||
"database": {"path": "data/detections.db"},
|
||||
"image_repository": {
|
||||
"base_path": "",
|
||||
"allowed_extensions": [
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".png",
|
||||
".tif",
|
||||
".tiff",
|
||||
".bmp",
|
||||
],
|
||||
"allowed_extensions": Image.SUPPORTED_EXTENSIONS,
|
||||
},
|
||||
"models": {
|
||||
"default_base_model": "yolov8s-seg.pt",
|
||||
"models_directory": "data/models",
|
||||
"base_model_choices": [
|
||||
"yolov8s-seg.pt",
|
||||
"yolov11s-seg.pt",
|
||||
],
|
||||
},
|
||||
"training": {
|
||||
"default_epochs": 100,
|
||||
@@ -65,6 +63,20 @@ class ConfigManager:
|
||||
"default_imgsz": 640,
|
||||
"default_patience": 50,
|
||||
"default_lr0": 0.01,
|
||||
"two_stage": {
|
||||
"enabled": False,
|
||||
"stage1": {
|
||||
"epochs": 20,
|
||||
"lr0": 0.0005,
|
||||
"patience": 10,
|
||||
"freeze": 10,
|
||||
},
|
||||
"stage2": {
|
||||
"epochs": 150,
|
||||
"lr0": 0.0003,
|
||||
"patience": 30,
|
||||
},
|
||||
},
|
||||
},
|
||||
"detection": {
|
||||
"default_confidence": 0.25,
|
||||
@@ -214,5 +226,5 @@ class ConfigManager:
|
||||
def get_allowed_extensions(self) -> list:
|
||||
"""Get list of allowed image file extensions."""
|
||||
return self.get(
|
||||
"image_repository.allowed_extensions", [".jpg", ".jpeg", ".png"]
|
||||
"image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS
|
||||
)
|
||||
|
||||
@@ -28,7 +28,9 @@ def get_image_files(
|
||||
List of absolute paths to image files
|
||||
"""
|
||||
if allowed_extensions is None:
|
||||
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
|
||||
from src.utils.image import Image
|
||||
|
||||
allowed_extensions = Image.SUPPORTED_EXTENSIONS
|
||||
|
||||
# Normalize extensions to lowercase
|
||||
allowed_extensions = [ext.lower() for ext in allowed_extensions]
|
||||
@@ -204,7 +206,9 @@ def is_image_file(
|
||||
True if file is an image
|
||||
"""
|
||||
if allowed_extensions is None:
|
||||
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
|
||||
from src.utils.image import Image
|
||||
|
||||
allowed_extensions = Image.SUPPORTED_EXTENSIONS
|
||||
|
||||
extension = Path(file_path).suffix.lower()
|
||||
return extension in [ext.lower() for ext in allowed_extensions]
|
||||
|
||||
@@ -7,6 +7,7 @@ import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
from PIL import Image as PILImage
|
||||
import tifffile
|
||||
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.file_utils import validate_file_path, is_image_file
|
||||
@@ -85,35 +86,75 @@ class Image:
|
||||
)
|
||||
|
||||
try:
|
||||
# Load with OpenCV (returns BGR format)
|
||||
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
|
||||
# Check if it's a TIFF file - use tifffile for better support
|
||||
if self.path.suffix.lower() in [".tif", ".tiff"]:
|
||||
self._data = tifffile.imread(str(self.path))
|
||||
|
||||
if self._data is None:
|
||||
raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")
|
||||
if self._data is None:
|
||||
raise ImageLoadError(
|
||||
f"Failed to load TIFF with tifffile: {self.path}"
|
||||
)
|
||||
|
||||
# Extract metadata
|
||||
self._height, self._width = self._data.shape[:2]
|
||||
self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
|
||||
self._format = self.path.suffix.lower().lstrip(".")
|
||||
self._size_bytes = self.path.stat().st_size
|
||||
self._dtype = self._data.dtype
|
||||
# Extract metadata
|
||||
self._height, self._width = (
|
||||
self._data.shape[:2]
|
||||
if len(self._data.shape) >= 2
|
||||
else (self._data.shape[0], 1)
|
||||
)
|
||||
self._channels = (
|
||||
self._data.shape[2] if len(self._data.shape) == 3 else 1
|
||||
)
|
||||
self._format = self.path.suffix.lower().lstrip(".")
|
||||
self._size_bytes = self.path.stat().st_size
|
||||
self._dtype = self._data.dtype
|
||||
|
||||
# Load PIL version for compatibility (convert BGR to RGB)
|
||||
if self._channels == 3:
|
||||
rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
|
||||
self._pil_image = PILImage.fromarray(rgb_data)
|
||||
elif self._channels == 4:
|
||||
rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
|
||||
self._pil_image = PILImage.fromarray(rgba_data)
|
||||
# Load PIL version for compatibility
|
||||
if self._channels == 1:
|
||||
# Grayscale
|
||||
self._pil_image = PILImage.fromarray(self._data)
|
||||
else:
|
||||
# Multi-channel (RGB or RGBA)
|
||||
self._pil_image = PILImage.fromarray(self._data)
|
||||
|
||||
logger.info(
|
||||
f"Successfully loaded TIFF image: {self.path.name} "
|
||||
f"({self._width}x{self._height}, {self._channels} channels, "
|
||||
f"dtype={self._dtype}, {self._format.upper()})"
|
||||
)
|
||||
else:
|
||||
# Grayscale
|
||||
self._pil_image = PILImage.fromarray(self._data)
|
||||
# Load with OpenCV (returns BGR format) for non-TIFF images
|
||||
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
|
||||
|
||||
logger.info(
|
||||
f"Successfully loaded image: {self.path.name} "
|
||||
f"({self._width}x{self._height}, {self._channels} channels, "
|
||||
f"{self._format.upper()})"
|
||||
)
|
||||
if self._data is None:
|
||||
raise ImageLoadError(
|
||||
f"Failed to load image with OpenCV: {self.path}"
|
||||
)
|
||||
|
||||
# Extract metadata
|
||||
self._height, self._width = self._data.shape[:2]
|
||||
self._channels = (
|
||||
self._data.shape[2] if len(self._data.shape) == 3 else 1
|
||||
)
|
||||
self._format = self.path.suffix.lower().lstrip(".")
|
||||
self._size_bytes = self.path.stat().st_size
|
||||
self._dtype = self._data.dtype
|
||||
|
||||
# Load PIL version for compatibility (convert BGR to RGB)
|
||||
if self._channels == 3:
|
||||
rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
|
||||
self._pil_image = PILImage.fromarray(rgb_data)
|
||||
elif self._channels == 4:
|
||||
rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
|
||||
self._pil_image = PILImage.fromarray(rgba_data)
|
||||
else:
|
||||
# Grayscale
|
||||
self._pil_image = PILImage.fromarray(self._data)
|
||||
|
||||
logger.info(
|
||||
f"Successfully loaded image: {self.path.name} "
|
||||
f"({self._width}x{self._height}, {self._channels} channels, "
|
||||
f"{self._format.upper()})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading image {self.path}: {e}")
|
||||
@@ -277,6 +318,44 @@ class Image:
|
||||
"""
|
||||
return self._channels >= 3
|
||||
|
||||
def to_normalized_float32(self) -> np.ndarray:
|
||||
"""
|
||||
Convert image data to normalized float32 in range [0, 1].
|
||||
|
||||
For 16-bit images, this properly scales the full dynamic range.
|
||||
For 8-bit images, divides by 255.
|
||||
Already float images are clipped to [0, 1].
|
||||
|
||||
Returns:
|
||||
Normalized image data as float32 numpy array [0, 1]
|
||||
"""
|
||||
data = self._data.astype(np.float32)
|
||||
|
||||
if self._dtype == np.uint16:
|
||||
# 16-bit: normalize by max value (65535)
|
||||
data = data / 65535.0
|
||||
elif self._dtype == np.uint8:
|
||||
# 8-bit: normalize by 255
|
||||
data = data / 255.0
|
||||
elif np.issubdtype(self._dtype, np.floating):
|
||||
# Already float, just clip to [0, 1]
|
||||
data = np.clip(data, 0.0, 1.0)
|
||||
else:
|
||||
# Other integer types: use dtype info
|
||||
if np.issubdtype(self._dtype, np.integer):
|
||||
max_val = np.iinfo(self._dtype).max
|
||||
data = data / float(max_val)
|
||||
else:
|
||||
# Unknown type: attempt min-max normalization
|
||||
min_val = data.min()
|
||||
max_val = data.max()
|
||||
if max_val > min_val:
|
||||
data = (data - min_val) / (max_val - min_val)
|
||||
else:
|
||||
data = np.zeros_like(data)
|
||||
|
||||
return np.clip(data, 0.0, 1.0)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation of the Image object."""
|
||||
return (
|
||||
@@ -289,3 +368,40 @@ class Image:
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the Image object."""
|
||||
return self.__repr__()
|
||||
|
||||
|
||||
def convert_grayscale_to_rgb_preserve_range(
|
||||
pil_image: PILImage.Image,
|
||||
) -> PILImage.Image:
|
||||
"""Convert a single-channel PIL image to RGB while preserving dynamic range.
|
||||
|
||||
Args:
|
||||
pil_image: Single-channel PIL image (e.g., 16-bit grayscale).
|
||||
|
||||
Returns:
|
||||
PIL Image in RGB mode with intensities normalized to 0-255.
|
||||
"""
|
||||
|
||||
if pil_image.mode == "RGB":
|
||||
return pil_image
|
||||
|
||||
grayscale = np.array(pil_image)
|
||||
if grayscale.ndim == 3:
|
||||
grayscale = grayscale[:, :, 0]
|
||||
|
||||
original_dtype = grayscale.dtype
|
||||
grayscale = grayscale.astype(np.float32)
|
||||
|
||||
if grayscale.size == 0:
|
||||
return PILImage.new("RGB", pil_image.size, color=(0, 0, 0))
|
||||
|
||||
if np.issubdtype(original_dtype, np.integer):
|
||||
denom = float(max(np.iinfo(original_dtype).max, 1))
|
||||
else:
|
||||
max_val = float(grayscale.max())
|
||||
denom = max(max_val, 1.0)
|
||||
|
||||
grayscale = np.clip(grayscale / denom, 0.0, 1.0)
|
||||
grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8)
|
||||
rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
|
||||
return PILImage.fromarray(rgb_arr, mode="RGB")
|
||||
|
||||
122
src/utils/image_converters.py
Normal file
122
src/utils/image_converters.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import numpy as np
|
||||
|
||||
from roifile import ImagejRoi
|
||||
from tifffile import TiffFile, TiffWriter
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class UT:
|
||||
"""
|
||||
Docstring for UT
|
||||
|
||||
Operetta files along with rois drawn in ImageJ
|
||||
"""
|
||||
|
||||
def __init__(self, roifile_fn: Path):
|
||||
self.roifile_fn = roifile_fn
|
||||
self.rois = ImagejRoi.fromfile(self.roifile_fn)
|
||||
self.stem = self.roifile_fn.stem.strip("-RoiSet")
|
||||
self.image, self.image_props = self._load_images()
|
||||
|
||||
def _load_images(self):
|
||||
"""Loading sequence of tif files
|
||||
array sequence is CZYX
|
||||
"""
|
||||
print(self.roifile_fn.parent, self.stem)
|
||||
fns = list(self.roifile_fn.parent.glob(f"{self.stem}*.tif*"))
|
||||
stems = [fn.stem.split(self.stem)[-1] for fn in fns]
|
||||
n_ch = len(set([stem.split("-ch")[-1].split("t")[0] for stem in stems]))
|
||||
n_p = len(set([stem.split("-")[0] for stem in stems]))
|
||||
n_t = len(set([stem.split("t")[1] for stem in stems]))
|
||||
print(n_ch, n_p, n_t)
|
||||
|
||||
with TiffFile(fns[0]) as tif:
|
||||
img = tif.asarray()
|
||||
w, h = img.shape
|
||||
dtype = img.dtype
|
||||
self.image_props = {
|
||||
"channels": n_ch,
|
||||
"planes": n_p,
|
||||
"tiles": n_t,
|
||||
"width": w,
|
||||
"height": h,
|
||||
"dtype": dtype,
|
||||
}
|
||||
|
||||
image_stack = np.zeros((n_ch, n_p, w, h), dtype=dtype)
|
||||
for fn in fns:
|
||||
with TiffFile(fn) as tif:
|
||||
img = tif.asarray()
|
||||
stem = fn.stem.split(self.stem)[-1]
|
||||
ch = int(stem.split("-ch")[-1].split("t")[0])
|
||||
p = int(stem.split("-")[0].lstrip("p"))
|
||||
t = int(stem.split("t")[1])
|
||||
print(fn.stem, "ch", ch, "p", p, "t", t)
|
||||
image_stack[ch - 1, p - 1] = img
|
||||
|
||||
print(image_stack.shape)
|
||||
|
||||
return image_stack, self.image_props
|
||||
|
||||
@property
|
||||
def width(self):
|
||||
return self.image_props["width"]
|
||||
|
||||
@property
|
||||
def height(self):
|
||||
return self.image_props["height"]
|
||||
|
||||
@property
|
||||
def nchannels(self):
|
||||
return self.image_props["channels"]
|
||||
|
||||
@property
|
||||
def nplanes(self):
|
||||
return self.image_props["planes"]
|
||||
|
||||
def export_rois(
|
||||
self,
|
||||
path: Path,
|
||||
subfolder: str = "labels",
|
||||
class_index: int = 0,
|
||||
):
|
||||
"""Export rois to a file"""
|
||||
with open(path / subfolder / f"{self.stem}.txt", "w") as f:
|
||||
for roi in self.rois:
|
||||
# TODO add image coordinates normalization
|
||||
coords = ""
|
||||
for x, y in roi.subpixel_coordinates:
|
||||
coords += f"{x/self.width} {y/self.height} "
|
||||
f.write(f"{class_index} {coords}\n")
|
||||
|
||||
return
|
||||
|
||||
def export_image(
|
||||
self,
|
||||
path: Path,
|
||||
subfolder: str = "images",
|
||||
plane_mode: str = "max projection",
|
||||
channel: int = 0,
|
||||
):
|
||||
"""Export image to a file"""
|
||||
|
||||
if plane_mode == "max projection":
|
||||
self.image = np.max(self.image[channel], axis=0)
|
||||
print(self.image.shape)
|
||||
|
||||
with TiffWriter(path / subfolder / f"{self.stem}.tif") as tif:
|
||||
tif.write(self.image)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("input", type=Path)
|
||||
parser.add_argument("output", type=Path)
|
||||
args = parser.parse_args()
|
||||
|
||||
for rfn in args.input.glob("*.zip"):
|
||||
ut = UT(rfn)
|
||||
ut.export_rois(args.output, class_index=0)
|
||||
ut.export_image(args.output, plane_mode="max projection", channel=0)
|
||||
561
src/utils/train_ultralytics_float.py
Normal file
561
src/utils/train_ultralytics_float.py
Normal file
@@ -0,0 +1,561 @@
|
||||
"""
|
||||
Custom YOLO training with on-the-fly float32 conversion for 16-bit grayscale images.
|
||||
|
||||
This module provides a custom dataset class and training function that:
|
||||
1. Load 16-bit TIFF images directly with tifffile (no PIL/cv2)
|
||||
2. Convert to float32 [0-1] on-the-fly (no data loss)
|
||||
3. Replicate grayscale to 3-channel RGB in memory
|
||||
4. Use custom training loop to bypass Ultralytics' dataset infrastructure
|
||||
5. No disk caching required
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import tifffile
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List, Tuple
|
||||
from ultralytics import YOLO
|
||||
import yaml
|
||||
import time
|
||||
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Float32YOLODataset(Dataset):
|
||||
"""
|
||||
Custom PyTorch dataset for YOLO that loads 16-bit grayscale TIFFs as float32 RGB.
|
||||
|
||||
This dataset:
|
||||
- Loads with tifffile (not PIL/cv2)
|
||||
- Converts uint16 → float32 [0-1] (preserves full dynamic range)
|
||||
- Replicates grayscale to 3 channels
|
||||
- Returns torch tensors in (C, H, W) format
|
||||
"""
|
||||
|
||||
def __init__(self, images_dir: str, labels_dir: str, img_size: int = 640):
|
||||
"""
|
||||
Initialize dataset.
|
||||
|
||||
Args:
|
||||
images_dir: Directory containing images
|
||||
labels_dir: Directory containing YOLO label files (.txt)
|
||||
img_size: Target image size (for reference, actual resizing done by model)
|
||||
"""
|
||||
self.images_dir = Path(images_dir)
|
||||
self.labels_dir = Path(labels_dir)
|
||||
self.img_size = img_size
|
||||
|
||||
# Find all image files
|
||||
extensions = {".tif", ".tiff", ".png", ".jpg", ".jpeg", ".bmp"}
|
||||
self.image_paths = sorted(
|
||||
[
|
||||
p
|
||||
for p in self.images_dir.rglob("*")
|
||||
if p.is_file() and p.suffix.lower() in extensions
|
||||
]
|
||||
)
|
||||
|
||||
if not self.image_paths:
|
||||
raise ValueError(f"No images found in {images_dir}")
|
||||
|
||||
logger.info(
|
||||
f"Float32YOLODataset initialized with {len(self.image_paths)} images from {images_dir}"
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_paths)
|
||||
|
||||
def _read_image(self, img_path: Path) -> np.ndarray:
|
||||
"""
|
||||
Read image and convert to float32 [0-1] RGB.
|
||||
|
||||
Returns:
|
||||
numpy array, shape (H, W, 3), dtype float32, range [0, 1]
|
||||
"""
|
||||
# Load image with tifffile
|
||||
img = tifffile.imread(str(img_path))
|
||||
|
||||
# Convert to float32
|
||||
img = img.astype(np.float32)
|
||||
|
||||
# Normalize if 16-bit (values > 1.5 indicates uint16)
|
||||
if img.max() > 1.5:
|
||||
img = img / 65535.0
|
||||
|
||||
# Ensure [0, 1] range
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
# Convert grayscale to RGB
|
||||
if img.ndim == 2:
|
||||
# H,W → H,W,3
|
||||
img = np.repeat(img[..., None], 3, axis=2)
|
||||
elif img.ndim == 3 and img.shape[2] == 1:
|
||||
# H,W,1 → H,W,3
|
||||
img = np.repeat(img, 3, axis=2)
|
||||
|
||||
return img # float32 (H, W, 3) in [0, 1]
|
||||
|
||||
def _parse_label(self, label_path: Path) -> List[np.ndarray]:
|
||||
"""
|
||||
Parse YOLO label file with variable-length rows (segmentation polygons).
|
||||
|
||||
Returns:
|
||||
List of numpy arrays, one per annotation
|
||||
"""
|
||||
if not label_path.exists():
|
||||
return []
|
||||
|
||||
labels = []
|
||||
try:
|
||||
with open(label_path, "r") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# Parse space-separated values
|
||||
values = line.split()
|
||||
if len(values) >= 5: # At minimum: class_id x y w h
|
||||
labels.append(
|
||||
np.array([float(v) for v in values], dtype=np.float32)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error parsing label {label_path}: {e}")
|
||||
return []
|
||||
|
||||
return labels
|
||||
|
||||
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, List[np.ndarray], str]:
|
||||
"""
|
||||
Get a single training sample.
|
||||
|
||||
Returns:
|
||||
Tuple of (image_tensor, labels, filename)
|
||||
- image_tensor: shape (3, H, W), dtype float32, range [0, 1]
|
||||
- labels: list of numpy arrays with YOLO format labels (variable length for segmentation)
|
||||
- filename: image filename
|
||||
"""
|
||||
img_path = self.image_paths[idx]
|
||||
label_path = self.labels_dir / f"{img_path.stem}.txt"
|
||||
|
||||
# Load image as float32 RGB
|
||||
img = self._read_image(img_path)
|
||||
|
||||
# Convert to tensor: (H, W, 3) → (3, H, W)
|
||||
img_tensor = torch.from_numpy(img).permute(2, 0, 1).contiguous()
|
||||
|
||||
# Load labels (list of variable-length arrays for segmentation)
|
||||
labels = self._parse_label(label_path)
|
||||
|
||||
return img_tensor, labels, img_path.name
|
||||
|
||||
|
||||
def collate_fn(
|
||||
batch: List[Tuple[torch.Tensor, List[np.ndarray], str]],
|
||||
) -> Tuple[torch.Tensor, List[List[np.ndarray]], List[str]]:
|
||||
"""
|
||||
Collate function for DataLoader.
|
||||
|
||||
Args:
|
||||
batch: List of (img_tensor, labels_list, filename) tuples
|
||||
where labels_list is a list of variable-length numpy arrays
|
||||
|
||||
Returns:
|
||||
Tuple of (stacked_images, list_of_labels_lists, list_of_filenames)
|
||||
"""
|
||||
imgs = [b[0] for b in batch]
|
||||
labels = [b[1] for b in batch] # Each element is a list of arrays
|
||||
names = [b[2] for b in batch]
|
||||
|
||||
# Stack images - requires same H,W
|
||||
# For different sizes, implement letterbox/resize in dataset
|
||||
imgs_batch = torch.stack(imgs, dim=0)
|
||||
|
||||
return imgs_batch, labels, names
|
||||
|
||||
|
||||
def train_with_float32_loader(
|
||||
model_path: str,
|
||||
data_yaml: str,
|
||||
epochs: int = 100,
|
||||
imgsz: int = 640,
|
||||
batch: int = 16,
|
||||
patience: int = 50,
|
||||
save_dir: str = "data/models",
|
||||
name: str = "custom_model",
|
||||
callbacks: Optional[Dict] = None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Train YOLO model with custom Float32 dataset for 16-bit TIFF support.
|
||||
|
||||
Uses a custom training loop to bypass Ultralytics' dataset pipeline,
|
||||
avoiding channel conversion issues.
|
||||
|
||||
Args:
|
||||
model_path: Path to base model weights (.pt file)
|
||||
data_yaml: Path to dataset YAML configuration
|
||||
epochs: Number of training epochs
|
||||
imgsz: Input image size
|
||||
batch: Batch size
|
||||
patience: Early stopping patience
|
||||
save_dir: Directory to save trained model
|
||||
name: Name for the training run
|
||||
callbacks: Optional callback dictionary (for progress reporting)
|
||||
**kwargs: Additional training arguments (lr0, freeze, device, etc.)
|
||||
|
||||
Returns:
|
||||
Dict with training results including model paths and metrics
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Starting Float32 custom training: {name}")
|
||||
logger.info(
|
||||
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
|
||||
)
|
||||
|
||||
# Parse data.yaml to get dataset paths
|
||||
with open(data_yaml, "r") as f:
|
||||
data_config = yaml.safe_load(f)
|
||||
|
||||
dataset_root = Path(data_config.get("path", Path(data_yaml).parent))
|
||||
train_images = dataset_root / data_config.get("train", "train/images")
|
||||
val_images = dataset_root / data_config.get("val", "val/images")
|
||||
|
||||
# Infer label directories
|
||||
train_labels = train_images.parent / "labels"
|
||||
val_labels = val_images.parent / "labels"
|
||||
|
||||
logger.info(f"Train images: {train_images}")
|
||||
logger.info(f"Train labels: {train_labels}")
|
||||
logger.info(f"Val images: {val_images}")
|
||||
logger.info(f"Val labels: {val_labels}")
|
||||
|
||||
# Create datasets
|
||||
train_dataset = Float32YOLODataset(
|
||||
str(train_images), str(train_labels), img_size=imgsz
|
||||
)
|
||||
val_dataset = Float32YOLODataset(
|
||||
str(val_images), str(val_labels), img_size=imgsz
|
||||
)
|
||||
|
||||
# Create data loaders
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch,
|
||||
shuffle=True,
|
||||
num_workers=4,
|
||||
pin_memory=True,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=batch,
|
||||
shuffle=False,
|
||||
num_workers=2,
|
||||
pin_memory=True,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
# Load model
|
||||
logger.info(f"Loading model from {model_path}")
|
||||
ul_model = YOLO(model_path)
|
||||
|
||||
# Get PyTorch model
|
||||
pt_model, loss_fn = _get_pytorch_model(ul_model)
|
||||
|
||||
# Setup device
|
||||
device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Configure model args for loss function
|
||||
from types import SimpleNamespace
|
||||
|
||||
# Required args for segmentation loss
|
||||
required_args = {
|
||||
"overlap_mask": True,
|
||||
"mask_ratio": 4,
|
||||
"task": "segment",
|
||||
"single_cls": False,
|
||||
"box": 7.5,
|
||||
"cls": 0.5,
|
||||
"dfl": 1.5,
|
||||
}
|
||||
|
||||
if not hasattr(pt_model, "args"):
|
||||
# No args - create SimpleNamespace
|
||||
pt_model.args = SimpleNamespace(**required_args)
|
||||
elif isinstance(pt_model.args, dict):
|
||||
# Args is dict - MUST convert to SimpleNamespace for attribute access
|
||||
# The loss function uses model.args.overlap_mask (attribute access)
|
||||
merged = {**pt_model.args, **required_args}
|
||||
pt_model.args = SimpleNamespace(**merged)
|
||||
logger.info(
|
||||
"Converted model.args from dict to SimpleNamespace for loss function compatibility"
|
||||
)
|
||||
else:
|
||||
# Args is SimpleNamespace or other - set attributes
|
||||
for key, value in required_args.items():
|
||||
if not hasattr(pt_model.args, key):
|
||||
setattr(pt_model.args, key, value)
|
||||
|
||||
pt_model.to(device)
|
||||
pt_model.train()
|
||||
|
||||
logger.info(f"Training on device: {device}")
|
||||
logger.info(f"PyTorch model type: {type(pt_model)}")
|
||||
logger.info(f"Model args configured for segmentation loss")
|
||||
|
||||
# Setup optimizer
|
||||
lr0 = kwargs.get("lr0", 0.01)
|
||||
optimizer = torch.optim.AdamW(pt_model.parameters(), lr=lr0)
|
||||
|
||||
# Training loop
|
||||
save_path = Path(save_dir) / name
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
weights_dir = save_path / "weights"
|
||||
weights_dir.mkdir(exist_ok=True)
|
||||
|
||||
best_loss = float("inf")
|
||||
patience_counter = 0
|
||||
|
||||
for epoch in range(epochs):
|
||||
epoch_start = time.time()
|
||||
running_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
logger.info(f"Epoch {epoch+1}/{epochs} starting...")
|
||||
|
||||
for batch_idx, (imgs, labels_list, names) in enumerate(train_loader):
|
||||
imgs = imgs.to(device) # (B, 3, H, W) float32
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Forward pass
|
||||
try:
|
||||
preds = pt_model(imgs)
|
||||
except Exception as e:
|
||||
# Try with labels
|
||||
preds = pt_model(imgs, labels_list)
|
||||
|
||||
# Compute loss
|
||||
# For Ultralytics models, the easiest approach is to construct a batch dict
|
||||
# and call the model in training mode which returns preds + loss
|
||||
batch_dict = {
|
||||
"img": imgs, # Already on device
|
||||
"batch_idx": (
|
||||
torch.cat(
|
||||
[
|
||||
torch.full((len(lab),), i, dtype=torch.long)
|
||||
for i, lab in enumerate(labels_list)
|
||||
]
|
||||
).to(device)
|
||||
if any(len(lab) > 0 for lab in labels_list)
|
||||
else torch.tensor([], dtype=torch.long, device=device)
|
||||
),
|
||||
"cls": (
|
||||
torch.cat(
|
||||
[
|
||||
torch.from_numpy(lab[:, 0:1])
|
||||
for lab in labels_list
|
||||
if len(lab) > 0
|
||||
]
|
||||
).to(device)
|
||||
if any(len(lab) > 0 for lab in labels_list)
|
||||
else torch.tensor([], dtype=torch.float32, device=device)
|
||||
),
|
||||
"bboxes": (
|
||||
torch.cat(
|
||||
[
|
||||
torch.from_numpy(lab[:, 1:5])
|
||||
for lab in labels_list
|
||||
if len(lab) > 0
|
||||
]
|
||||
).to(device)
|
||||
if any(len(lab) > 0 for lab in labels_list)
|
||||
else torch.tensor([], dtype=torch.float32, device=device)
|
||||
),
|
||||
"ori_shape": (imgs.shape[2], imgs.shape[3]), # H, W
|
||||
"resized_shape": (imgs.shape[2], imgs.shape[3]),
|
||||
}
|
||||
|
||||
# Add masks if segmentation labels exist
|
||||
if any(len(lab) > 5 for lab in labels_list if len(lab) > 0):
|
||||
masks = []
|
||||
for lab in labels_list:
|
||||
if len(lab) > 0 and lab.shape[1] > 5:
|
||||
# Has segmentation points
|
||||
masks.append(torch.from_numpy(lab[:, 5:]))
|
||||
if masks:
|
||||
batch_dict["masks"] = masks
|
||||
|
||||
# Call model loss (it will compute loss internally)
|
||||
try:
|
||||
loss_output = pt_model.loss(batch_dict, preds)
|
||||
if isinstance(loss_output, (tuple, list)):
|
||||
loss = loss_output[0]
|
||||
else:
|
||||
loss = loss_output
|
||||
except Exception as e:
|
||||
logger.error(f"Model loss computation failed: {e}")
|
||||
# Last resort: maybe preds is already a dict with 'loss'
|
||||
if isinstance(preds, dict) and "loss" in preds:
|
||||
loss = preds["loss"]
|
||||
else:
|
||||
raise RuntimeError(f"Cannot compute loss: {e}")
|
||||
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += loss.item()
|
||||
num_batches += 1
|
||||
|
||||
# Report progress via callback
|
||||
if callbacks and "on_fit_epoch_end" in callbacks:
|
||||
# Create a mock trainer object for callback
|
||||
class MockTrainer:
|
||||
def __init__(self, epoch):
|
||||
self.epoch = epoch
|
||||
self.loss_items = [loss.item()]
|
||||
|
||||
callbacks["on_fit_epoch_end"](MockTrainer(epoch))
|
||||
|
||||
epoch_loss = running_loss / max(1, num_batches)
|
||||
epoch_time = time.time() - epoch_start
|
||||
|
||||
logger.info(
|
||||
f"Epoch {epoch+1}/{epochs} completed. Avg Loss: {epoch_loss:.4f}, Time: {epoch_time:.1f}s"
|
||||
)
|
||||
|
||||
# Save checkpoint
|
||||
ckpt_path = weights_dir / f"epoch{epoch+1}.pt"
|
||||
torch.save(
|
||||
{
|
||||
"epoch": epoch + 1,
|
||||
"model_state_dict": pt_model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"loss": epoch_loss,
|
||||
},
|
||||
ckpt_path,
|
||||
)
|
||||
|
||||
# Save as last.pt
|
||||
last_path = weights_dir / "last.pt"
|
||||
torch.save(pt_model.state_dict(), last_path)
|
||||
|
||||
# Check for best model
|
||||
if epoch_loss < best_loss:
|
||||
best_loss = epoch_loss
|
||||
patience_counter = 0
|
||||
best_path = weights_dir / "best.pt"
|
||||
torch.save(pt_model.state_dict(), best_path)
|
||||
logger.info(f"New best model saved: {best_path}")
|
||||
else:
|
||||
patience_counter += 1
|
||||
|
||||
# Early stopping
|
||||
if patience_counter >= patience:
|
||||
logger.info(f"Early stopping triggered after {epoch+1} epochs")
|
||||
break
|
||||
|
||||
logger.info("Training completed successfully")
|
||||
|
||||
# Format results
|
||||
return {
|
||||
"success": True,
|
||||
"final_epoch": epoch + 1,
|
||||
"metrics": {
|
||||
"final_loss": epoch_loss,
|
||||
"best_loss": best_loss,
|
||||
},
|
||||
"best_model_path": str(weights_dir / "best.pt"),
|
||||
"last_model_path": str(weights_dir / "last.pt"),
|
||||
"save_dir": str(save_path),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during Float32 training: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
|
||||
def _get_pytorch_model(ul_model: YOLO) -> Tuple[torch.nn.Module, Optional[callable]]:
|
||||
"""
|
||||
Extract PyTorch model and loss function from Ultralytics YOLO wrapper.
|
||||
|
||||
Args:
|
||||
ul_model: Ultralytics YOLO model wrapper
|
||||
|
||||
Returns:
|
||||
Tuple of (pytorch_model, loss_function)
|
||||
"""
|
||||
# Try to get the underlying PyTorch model
|
||||
candidates = []
|
||||
|
||||
# Direct model attribute
|
||||
if hasattr(ul_model, "model"):
|
||||
candidates.append(ul_model.model)
|
||||
|
||||
# Sometimes nested
|
||||
if hasattr(ul_model, "model") and hasattr(ul_model.model, "model"):
|
||||
candidates.append(ul_model.model.model)
|
||||
|
||||
# The wrapper itself
|
||||
if isinstance(ul_model, torch.nn.Module):
|
||||
candidates.append(ul_model)
|
||||
|
||||
# Find a valid model
|
||||
pt_model = None
|
||||
loss_fn = None
|
||||
|
||||
for candidate in candidates:
|
||||
if candidate is None or not isinstance(candidate, torch.nn.Module):
|
||||
continue
|
||||
|
||||
pt_model = candidate
|
||||
|
||||
# Try to find loss function
|
||||
if hasattr(candidate, "loss") and callable(getattr(candidate, "loss")):
|
||||
loss_fn = getattr(candidate, "loss")
|
||||
elif hasattr(candidate, "compute_loss") and callable(
|
||||
getattr(candidate, "compute_loss")
|
||||
):
|
||||
loss_fn = getattr(candidate, "compute_loss")
|
||||
|
||||
break
|
||||
|
||||
if pt_model is None:
|
||||
raise RuntimeError("Could not extract PyTorch model from Ultralytics wrapper")
|
||||
|
||||
logger.info(f"Extracted PyTorch model: {type(pt_model)}")
|
||||
logger.info(
|
||||
f"Loss function: {type(loss_fn) if loss_fn else 'None (will attempt fallback)'}"
|
||||
)
|
||||
|
||||
return pt_model, loss_fn
|
||||
|
||||
|
||||
# Compatibility function (kept for backwards compatibility)
|
||||
def train_float32(model: YOLO, data_yaml: str, **train_kwargs) -> Any:
|
||||
"""
|
||||
Train YOLO model with Float32YOLODataset (alternative API).
|
||||
|
||||
Args:
|
||||
model: Initialized YOLO model instance
|
||||
data_yaml: Path to dataset YAML
|
||||
**train_kwargs: Training parameters
|
||||
|
||||
Returns:
|
||||
Training results dict
|
||||
"""
|
||||
return train_with_float32_loader(
|
||||
model_path=(
|
||||
model.model_path if hasattr(model, "model_path") else "yolov8s-seg.pt"
|
||||
),
|
||||
data_yaml=data_yaml,
|
||||
**train_kwargs,
|
||||
)
|
||||
182
tests/show_yolo_seg.py
Normal file
182
tests/show_yolo_seg.py
Normal file
@@ -0,0 +1,182 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
show_yolo_seg.py
|
||||
|
||||
Usage:
|
||||
python show_yolo_seg.py /path/to/image.jpg /path/to/labels.txt
|
||||
|
||||
Supports:
|
||||
- Segmentation polygons: "class x1 y1 x2 y2 ... xn yn"
|
||||
- YOLO bbox lines as fallback: "class x_center y_center width height"
|
||||
Coordinates can be normalized [0..1] or absolute pixels (auto-detected).
|
||||
"""
|
||||
import sys
|
||||
import cv2
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import random
|
||||
|
||||
|
||||
def parse_label_line(line):
|
||||
parts = line.strip().split()
|
||||
if not parts:
|
||||
return None
|
||||
cls = int(float(parts[0]))
|
||||
coords = [float(x) for x in parts[1:]]
|
||||
return cls, coords
|
||||
|
||||
|
||||
def coords_are_normalized(coords):
|
||||
# If every coordinate is between 0 and 1 (inclusive-ish), assume normalized
|
||||
if not coords:
|
||||
return False
|
||||
return max(coords) <= 1.001
|
||||
|
||||
|
||||
def yolo_bbox_to_xyxy(coords, img_w, img_h):
|
||||
# coords: [xc, yc, w, h] normalized or absolute
|
||||
xc, yc, w, h = coords[:4]
|
||||
if max(coords) <= 1.001:
|
||||
xc *= img_w
|
||||
yc *= img_h
|
||||
w *= img_w
|
||||
h *= img_h
|
||||
x1 = int(round(xc - w / 2))
|
||||
y1 = int(round(yc - h / 2))
|
||||
x2 = int(round(xc + w / 2))
|
||||
y2 = int(round(yc + h / 2))
|
||||
return x1, y1, x2, y2
|
||||
|
||||
|
||||
def poly_to_pts(coords, img_w, img_h):
|
||||
# coords: [x1 y1 x2 y2 ...] either normalized or absolute
|
||||
if coords_are_normalized(coords):
|
||||
coords = [
|
||||
coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))
|
||||
]
|
||||
pts = np.array(coords, dtype=np.int32).reshape(-1, 2)
|
||||
return pts
|
||||
|
||||
|
||||
def random_color_for_class(cls):
|
||||
random.seed(cls) # deterministic per class
|
||||
return tuple(int(x) for x in np.array([random.randint(0, 255) for _ in range(3)]))
|
||||
|
||||
|
||||
def draw_annotations(img, labels, alpha=0.4, draw_bbox_for_poly=True):
|
||||
# img: BGR numpy array
|
||||
overlay = img.copy()
|
||||
h, w = img.shape[:2]
|
||||
for cls, coords in labels:
|
||||
if not coords:
|
||||
continue
|
||||
# polygon case (>=6 coordinates)
|
||||
if len(coords) >= 6:
|
||||
pts = poly_to_pts(coords, w, h)
|
||||
color = random_color_for_class(cls)
|
||||
# fill on overlay
|
||||
cv2.fillPoly(overlay, [pts], color)
|
||||
# outline on base image
|
||||
cv2.polylines(img, [pts], isClosed=True, color=color, thickness=2)
|
||||
# put class text at first point
|
||||
x, y = int(pts[0, 0]), int(pts[0, 1]) - 6
|
||||
cv2.putText(
|
||||
img,
|
||||
str(cls),
|
||||
(x, max(6, y)),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.6,
|
||||
(255, 255, 255),
|
||||
2,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
if draw_bbox_for_poly:
|
||||
x, y, w_box, h_box = cv2.boundingRect(pts)
|
||||
cv2.rectangle(img, (x, y), (x + w_box, y + h_box), color, 1)
|
||||
# YOLO bbox case (4 coords)
|
||||
elif len(coords) == 4:
|
||||
x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords, w, h)
|
||||
color = random_color_for_class(cls)
|
||||
cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
|
||||
cv2.putText(
|
||||
img,
|
||||
str(cls),
|
||||
(x1, max(6, y1 - 4)),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.6,
|
||||
(255, 255, 255),
|
||||
2,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
else:
|
||||
# Unknown / invalid format, skip
|
||||
continue
|
||||
|
||||
# blend overlay for filled polygons
|
||||
cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img)
|
||||
return img
|
||||
|
||||
|
||||
def load_labels_file(label_path):
|
||||
labels = []
|
||||
with open(label_path, "r") as f:
|
||||
for raw in f:
|
||||
line = raw.strip()
|
||||
if not line:
|
||||
continue
|
||||
parsed = parse_label_line(line)
|
||||
if parsed:
|
||||
labels.append(parsed)
|
||||
return labels
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Show YOLO segmentation / polygon annotations"
|
||||
)
|
||||
parser.add_argument("image", type=str, help="Path to image file")
|
||||
parser.add_argument("labels", type=str, help="Path to YOLO label file (polygons)")
|
||||
parser.add_argument(
|
||||
"--alpha", type=float, default=0.4, help="Polygon fill alpha (0..1)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-bbox", action="store_true", help="Don't draw bounding boxes for polygons"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
img_path = Path(args.image)
|
||||
lbl_path = Path(args.labels)
|
||||
|
||||
if not img_path.exists():
|
||||
print("Image not found:", img_path)
|
||||
sys.exit(1)
|
||||
if not lbl_path.exists():
|
||||
print("Label file not found:", lbl_path)
|
||||
sys.exit(1)
|
||||
|
||||
img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
print("Could not load image:", img_path)
|
||||
sys.exit(1)
|
||||
|
||||
labels = load_labels_file(str(lbl_path))
|
||||
if not labels:
|
||||
print("No labels parsed from", lbl_path)
|
||||
# continue and just show image
|
||||
out = draw_annotations(
|
||||
img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox)
|
||||
)
|
||||
|
||||
# Convert BGR -> RGB for matplotlib display
|
||||
out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
|
||||
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
|
||||
plt.imshow(out_rgb)
|
||||
plt.axis("off")
|
||||
plt.title(f"{img_path.name} ({lbl_path.name})")
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
109
tests/test_16bit_tiff_loading.py
Normal file
109
tests/test_16bit_tiff_loading.py
Normal file
@@ -0,0 +1,109 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for 16-bit TIFF loading and normalization.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import tifffile
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add parent directory to path to import modules
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from src.utils.image import Image
|
||||
|
||||
|
||||
def create_test_16bit_tiff(output_path: str) -> str:
|
||||
"""Create a test 16-bit grayscale TIFF file.
|
||||
|
||||
Args:
|
||||
output_path: Path where to save the test TIFF
|
||||
|
||||
Returns:
|
||||
Path to the created TIFF file
|
||||
"""
|
||||
# Create a 16-bit grayscale test image (100x100)
|
||||
# With values ranging from 0 to 65535 (full 16-bit range)
|
||||
height, width = 100, 100
|
||||
|
||||
# Create a gradient pattern
|
||||
test_data = np.zeros((height, width), dtype=np.uint16)
|
||||
for i in range(height):
|
||||
for j in range(width):
|
||||
# Create a diagonal gradient
|
||||
test_data[i, j] = int((i + j) / (height + width - 2) * 65535)
|
||||
|
||||
# Save as TIFF
|
||||
tifffile.imwrite(output_path, test_data)
|
||||
print(f"Created test 16-bit TIFF: {output_path}")
|
||||
print(f" Shape: {test_data.shape}")
|
||||
print(f" Dtype: {test_data.dtype}")
|
||||
print(f" Min value: {test_data.min()}")
|
||||
print(f" Max value: {test_data.max()}")
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def test_image_loading():
|
||||
"""Test loading 16-bit TIFF with the Image class."""
|
||||
print("\n=== Testing Image Loading ===")
|
||||
|
||||
# Create temporary test file
|
||||
with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as tmp:
|
||||
test_path = tmp.name
|
||||
|
||||
try:
|
||||
# Create test image
|
||||
create_test_16bit_tiff(test_path)
|
||||
|
||||
# Load with Image class
|
||||
print("\nLoading with Image class...")
|
||||
img = Image(test_path)
|
||||
|
||||
print(f"Successfully loaded image:")
|
||||
print(f" Width: {img.width}")
|
||||
print(f" Height: {img.height}")
|
||||
print(f" Channels: {img.channels}")
|
||||
print(f" Dtype: {img.dtype}")
|
||||
print(f" Format: {img.format}")
|
||||
|
||||
# Test normalization
|
||||
print("\nTesting normalization to float32 [0-1]...")
|
||||
normalized = img.to_normalized_float32()
|
||||
|
||||
print(f"Normalized image:")
|
||||
print(f" Shape: {normalized.shape}")
|
||||
print(f" Dtype: {normalized.dtype}")
|
||||
print(f" Min value: {normalized.min():.6f}")
|
||||
print(f" Max value: {normalized.max():.6f}")
|
||||
print(f" Mean value: {normalized.mean():.6f}")
|
||||
|
||||
# Verify normalization
|
||||
assert normalized.dtype == np.float32, "Dtype should be float32"
|
||||
assert (
|
||||
0.0 <= normalized.min() <= normalized.max() <= 1.0
|
||||
), "Values should be in [0, 1]"
|
||||
|
||||
print("\n✓ All tests passed!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ Test failed with error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
if os.path.exists(test_path):
|
||||
os.remove(test_path)
|
||||
print(f"\nCleaned up test file: {test_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_image_loading()
|
||||
sys.exit(0 if success else 1)
|
||||
211
tests/test_float32_training_loader.py
Normal file
211
tests/test_float32_training_loader.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Test script for Float32 on-the-fly loading for 16-bit TIFFs.
|
||||
|
||||
This test verifies that:
|
||||
1. Float32YOLODataset can load 16-bit TIFF files
|
||||
2. Images are converted to float32 [0-1] in memory
|
||||
3. Grayscale is replicated to 3 channels (RGB)
|
||||
4. No disk caching is used
|
||||
5. Full 16-bit precision is preserved
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import tifffile
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
|
||||
|
||||
def create_test_dataset():
|
||||
"""Create a minimal test dataset with 16-bit TIFF images."""
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
dataset_dir = temp_dir / "test_dataset"
|
||||
|
||||
# Create directory structure
|
||||
train_images = dataset_dir / "train" / "images"
|
||||
train_labels = dataset_dir / "train" / "labels"
|
||||
train_images.mkdir(parents=True, exist_ok=True)
|
||||
train_labels.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create a 16-bit TIFF test image
|
||||
img_16bit = np.random.randint(0, 65536, (100, 100), dtype=np.uint16)
|
||||
img_path = train_images / "test_image.tif"
|
||||
tifffile.imwrite(str(img_path), img_16bit)
|
||||
|
||||
# Create a dummy label file
|
||||
label_path = train_labels / "test_image.txt"
|
||||
with open(label_path, "w") as f:
|
||||
f.write("0 0.5 0.5 0.2 0.2\n") # class_id x_center y_center width height
|
||||
|
||||
# Create data.yaml
|
||||
data_yaml = {
|
||||
"path": str(dataset_dir),
|
||||
"train": "train/images",
|
||||
"val": "train/images", # Use same for val in test
|
||||
"names": {0: "object"},
|
||||
"nc": 1,
|
||||
}
|
||||
yaml_path = dataset_dir / "data.yaml"
|
||||
with open(yaml_path, "w") as f:
|
||||
yaml.safe_dump(data_yaml, f)
|
||||
|
||||
print(f"✓ Created test dataset at: {dataset_dir}")
|
||||
print(f" - Image: {img_path} (shape={img_16bit.shape}, dtype={img_16bit.dtype})")
|
||||
print(f" - Min value: {img_16bit.min()}, Max value: {img_16bit.max()}")
|
||||
print(f" - data.yaml: {yaml_path}")
|
||||
|
||||
return dataset_dir, img_path, img_16bit
|
||||
|
||||
|
||||
def test_float32_dataset():
|
||||
"""Test the Float32YOLODataset class directly."""
|
||||
print("\n=== Testing Float32YOLODataset ===\n")
|
||||
|
||||
try:
|
||||
from src.utils.train_ultralytics_float import Float32YOLODataset
|
||||
|
||||
print("✓ Successfully imported Float32YOLODataset")
|
||||
except ImportError as e:
|
||||
print(f"✗ Failed to import Float32YOLODataset: {e}")
|
||||
return False
|
||||
|
||||
# Create test dataset
|
||||
dataset_dir, img_path, original_img = create_test_dataset()
|
||||
|
||||
try:
|
||||
# Initialize the dataset
|
||||
print("\nInitializing Float32YOLODataset...")
|
||||
dataset = Float32YOLODataset(
|
||||
images_dir=str(dataset_dir / "train" / "images"),
|
||||
labels_dir=str(dataset_dir / "train" / "labels"),
|
||||
img_size=640,
|
||||
)
|
||||
print(f"✓ Float32YOLODataset initialized with {len(dataset)} images")
|
||||
|
||||
# Get an item
|
||||
if len(dataset) > 0:
|
||||
print("\nGetting first item...")
|
||||
img_tensor, labels, filename = dataset[0]
|
||||
|
||||
print(f"✓ Item retrieved successfully")
|
||||
print(f" - Image tensor shape: {img_tensor.shape}")
|
||||
print(f" - Image tensor dtype: {img_tensor.dtype}")
|
||||
print(f" - Value range: [{img_tensor.min():.6f}, {img_tensor.max():.6f}]")
|
||||
print(f" - Filename: {filename}")
|
||||
print(f" - Labels: {len(labels)} annotations")
|
||||
if labels:
|
||||
print(
|
||||
f" - First label shape: {labels[0].shape if len(labels) > 0 else 'N/A'}"
|
||||
)
|
||||
|
||||
# Verify it's float32
|
||||
if img_tensor.dtype == torch.float32:
|
||||
print("✓ Correct dtype: float32")
|
||||
else:
|
||||
print(f"✗ Wrong dtype: {img_tensor.dtype} (expected float32)")
|
||||
return False
|
||||
|
||||
# Verify it's 3-channel in correct format (C, H, W)
|
||||
if len(img_tensor.shape) == 3 and img_tensor.shape[0] == 3:
|
||||
print(
|
||||
f"✓ Correct format: (C, H, W) = {img_tensor.shape} with 3 channels"
|
||||
)
|
||||
else:
|
||||
print(f"✗ Wrong shape: {img_tensor.shape} (expected (3, H, W))")
|
||||
return False
|
||||
|
||||
# Verify it's in [0, 1] range
|
||||
if 0.0 <= img_tensor.min() and img_tensor.max() <= 1.0:
|
||||
print("✓ Values in correct range: [0, 1]")
|
||||
else:
|
||||
print(
|
||||
f"✗ Values out of range: [{img_tensor.min()}, {img_tensor.max()}]"
|
||||
)
|
||||
return False
|
||||
|
||||
# Verify precision (should have many unique values)
|
||||
unique_values = len(torch.unique(img_tensor))
|
||||
print(f" - Unique values: {unique_values}")
|
||||
if unique_values > 256:
|
||||
print(f"✓ High precision maintained ({unique_values} > 256 levels)")
|
||||
else:
|
||||
print(f"⚠ Low precision: only {unique_values} unique values")
|
||||
|
||||
print("\n✓ All Float32YOLODataset tests passed!")
|
||||
return True
|
||||
else:
|
||||
print("✗ No items in dataset")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error during testing: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def test_integration():
|
||||
"""Test integration with train_with_float32_loader."""
|
||||
print("\n=== Testing Integration with train_with_float32_loader ===\n")
|
||||
|
||||
# Create test dataset
|
||||
dataset_dir, img_path, original_img = create_test_dataset()
|
||||
data_yaml = dataset_dir / "data.yaml"
|
||||
|
||||
print(f"\nTest dataset ready at: {data_yaml}")
|
||||
print("\nTo test full training, run:")
|
||||
print(f" from src.utils.train_ultralytics_float import train_with_float32_loader")
|
||||
print(f" results = train_with_float32_loader(")
|
||||
print(f" model_path='yolov8n-seg.pt',")
|
||||
print(f" data_yaml='{data_yaml}',")
|
||||
print(f" epochs=1,")
|
||||
print(f" batch=1,")
|
||||
print(f" imgsz=640")
|
||||
print(f" )")
|
||||
print("\nThis will use custom training loop with Float32YOLODataset")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all tests."""
|
||||
import torch # Import here to ensure torch is available
|
||||
|
||||
print("=" * 70)
|
||||
print("Float32 Training Loader Test Suite")
|
||||
print("=" * 70)
|
||||
|
||||
results = []
|
||||
|
||||
# Test 1: Float32YOLODataset
|
||||
results.append(("Float32YOLODataset", test_float32_dataset()))
|
||||
|
||||
# Test 2: Integration check
|
||||
results.append(("Integration Check", test_integration()))
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 70)
|
||||
print("Test Summary")
|
||||
print("=" * 70)
|
||||
for test_name, passed in results:
|
||||
status = "✓ PASSED" if passed else "✗ FAILED"
|
||||
print(f"{status}: {test_name}")
|
||||
|
||||
all_passed = all(passed for _, passed in results)
|
||||
print("=" * 70)
|
||||
if all_passed:
|
||||
print("✓ All tests passed!")
|
||||
else:
|
||||
print("✗ Some tests failed")
|
||||
print("=" * 70)
|
||||
|
||||
return all_passed
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
import torch # Make torch available
|
||||
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -27,7 +27,7 @@ class TestImage:
|
||||
|
||||
def test_supported_extensions(self):
|
||||
"""Test that supported extensions are correctly defined."""
|
||||
expected_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
|
||||
expected_extensions = Image.SUPPORTED_EXTENSIONS
|
||||
assert Image.SUPPORTED_EXTENSIONS == expected_extensions
|
||||
|
||||
def test_image_properties(self, tmp_path):
|
||||
|
||||
1774
tests/test_pyside_freehand_tool
Normal file
1774
tests/test_pyside_freehand_tool
Normal file
File diff suppressed because it is too large
Load Diff
142
tests/test_training_dataset_prep.py
Normal file
142
tests/test_training_dataset_prep.py
Normal file
@@ -0,0 +1,142 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for training dataset preparation with 16-bit TIFFs.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import tifffile
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import sys
|
||||
import os
|
||||
import shutil
|
||||
|
||||
# Add parent directory to path to import modules
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from src.utils.image import Image
|
||||
|
||||
|
||||
def test_float32_3ch_conversion():
|
||||
"""Test conversion of 16-bit TIFF to 16-bit RGB PNG."""
|
||||
print("\n=== Testing 16-bit RGB PNG Conversion ===")
|
||||
|
||||
# Create temporary directory structure
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir = Path(tmpdir)
|
||||
src_dir = tmpdir / "original"
|
||||
dst_dir = tmpdir / "converted"
|
||||
src_dir.mkdir()
|
||||
dst_dir.mkdir()
|
||||
|
||||
# Create test 16-bit TIFF
|
||||
test_data = np.zeros((100, 100), dtype=np.uint16)
|
||||
for i in range(100):
|
||||
for j in range(100):
|
||||
test_data[i, j] = int((i + j) / 198 * 65535)
|
||||
|
||||
test_file = src_dir / "test_16bit.tif"
|
||||
tifffile.imwrite(test_file, test_data)
|
||||
print(f"Created test 16-bit TIFF: {test_file}")
|
||||
print(f" Shape: {test_data.shape}")
|
||||
print(f" Dtype: {test_data.dtype}")
|
||||
print(f" Range: [{test_data.min()}, {test_data.max()}]")
|
||||
|
||||
# Simulate the conversion process (matching training_tab.py)
|
||||
print("\nConverting to 16-bit RGB PNG using PIL merge...")
|
||||
img_obj = Image(test_file)
|
||||
from PIL import Image as PILImage
|
||||
|
||||
# Get uint16 data
|
||||
uint16_data = img_obj.data
|
||||
|
||||
# Use PIL's merge method with 'I;16' channels (proper way for 16-bit RGB)
|
||||
if len(uint16_data.shape) == 2:
|
||||
# Grayscale - replicate to RGB
|
||||
r_img = PILImage.fromarray(uint16_data, mode="I;16")
|
||||
g_img = PILImage.fromarray(uint16_data, mode="I;16")
|
||||
b_img = PILImage.fromarray(uint16_data, mode="I;16")
|
||||
else:
|
||||
r_img = PILImage.fromarray(uint16_data[:, :, 0], mode="I;16")
|
||||
g_img = PILImage.fromarray(
|
||||
(
|
||||
uint16_data[:, :, 1]
|
||||
if uint16_data.shape[2] > 1
|
||||
else uint16_data[:, :, 0]
|
||||
),
|
||||
mode="I;16",
|
||||
)
|
||||
b_img = PILImage.fromarray(
|
||||
(
|
||||
uint16_data[:, :, 2]
|
||||
if uint16_data.shape[2] > 2
|
||||
else uint16_data[:, :, 0]
|
||||
),
|
||||
mode="I;16",
|
||||
)
|
||||
|
||||
# Merge channels into RGB
|
||||
rgb_img = PILImage.merge("RGB", (r_img, g_img, b_img))
|
||||
|
||||
# Save as PNG
|
||||
output_file = dst_dir / "test_16bit_rgb.png"
|
||||
rgb_img.save(output_file)
|
||||
print(f"Saved 16-bit RGB PNG: {output_file}")
|
||||
print(f" PIL mode after merge: {rgb_img.mode}")
|
||||
|
||||
# Verify the output - Load with OpenCV (as YOLO does)
|
||||
import cv2
|
||||
|
||||
loaded = cv2.imread(str(output_file), cv2.IMREAD_UNCHANGED)
|
||||
print(f"\nVerifying output (loaded with OpenCV):")
|
||||
print(f" Shape: {loaded.shape}")
|
||||
print(f" Dtype: {loaded.dtype}")
|
||||
print(f" Channels: {loaded.shape[2] if len(loaded.shape) == 3 else 1}")
|
||||
print(f" Range: [{loaded.min()}, {loaded.max()}]")
|
||||
print(f" Unique values: {len(np.unique(loaded[:,:,0]))}")
|
||||
|
||||
# Assertions
|
||||
assert loaded.dtype == np.uint16, f"Expected uint16, got {loaded.dtype}"
|
||||
assert loaded.shape[2] == 3, f"Expected 3 channels, got {loaded.shape[2]}"
|
||||
assert (
|
||||
loaded.min() >= 0 and loaded.max() <= 65535
|
||||
), f"Expected [0,65535] range, got [{loaded.min()}, {loaded.max()}]"
|
||||
|
||||
# Verify all channels are identical (replicated grayscale)
|
||||
assert np.array_equal(
|
||||
loaded[:, :, 0], loaded[:, :, 1]
|
||||
), "Channel 0 and 1 should be identical"
|
||||
assert np.array_equal(
|
||||
loaded[:, :, 0], loaded[:, :, 2]
|
||||
), "Channel 0 and 2 should be identical"
|
||||
|
||||
# Verify no data loss
|
||||
unique_vals = len(np.unique(loaded[:, :, 0]))
|
||||
print(f"\n Precision check:")
|
||||
print(f" Unique values in channel: {unique_vals}")
|
||||
print(f" Source unique values: {len(np.unique(test_data))}")
|
||||
|
||||
assert unique_vals == len(
|
||||
np.unique(test_data)
|
||||
), f"Expected {len(np.unique(test_data))} unique values, got {unique_vals}"
|
||||
|
||||
print("\n✓ All conversion tests passed!")
|
||||
print(" - uint16 dtype preserved")
|
||||
print(" - 3 channels created")
|
||||
print(" - Range [0-65535] maintained")
|
||||
print(" - No precision loss from conversion")
|
||||
print(" - Channels properly replicated")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
success = test_float32_3ch_conversion()
|
||||
sys.exit(0 if success else 1)
|
||||
except Exception as e:
|
||||
print(f"\n✗ Test failed with error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
150
tests/test_yolo_16bit_float32.py
Normal file
150
tests/test_yolo_16bit_float32.py
Normal file
@@ -0,0 +1,150 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for YOLO preprocessing of 16-bit TIFF images with float32 passthrough.
|
||||
Verifies that no uint8 conversion occurs and data is preserved.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import tifffile
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add parent directory to path to import modules
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from src.model.yolo_wrapper import YOLOWrapper
|
||||
|
||||
|
||||
def create_test_16bit_tiff(output_path: str) -> str:
|
||||
"""Create a test 16-bit grayscale TIFF file.
|
||||
|
||||
Args:
|
||||
output_path: Path where to save the test TIFF
|
||||
|
||||
Returns:
|
||||
Path to the created TIFF file
|
||||
"""
|
||||
# Create a 16-bit grayscale test image (200x200)
|
||||
# With specific values to test precision preservation
|
||||
height, width = 200, 200
|
||||
|
||||
# Create a gradient pattern with the full 16-bit range
|
||||
test_data = np.zeros((height, width), dtype=np.uint16)
|
||||
for i in range(height):
|
||||
for j in range(width):
|
||||
# Create a diagonal gradient using full 16-bit range
|
||||
test_data[i, j] = int((i + j) / (height + width - 2) * 65535)
|
||||
|
||||
# Save as TIFF
|
||||
tifffile.imwrite(output_path, test_data)
|
||||
print(f"Created test 16-bit TIFF: {output_path}")
|
||||
print(f" Shape: {test_data.shape}")
|
||||
print(f" Dtype: {test_data.dtype}")
|
||||
print(f" Min value: {test_data.min()}")
|
||||
print(f" Max value: {test_data.max()}")
|
||||
print(
|
||||
f" Sample values: {test_data[50, 50]}, {test_data[100, 100]}, {test_data[150, 150]}"
|
||||
)
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def test_float32_passthrough():
|
||||
"""Test that 16-bit TIFF preprocessing passes float32 directly without uint8 conversion."""
|
||||
print("\n=== Testing Float32 Passthrough (NO uint8) ===")
|
||||
|
||||
# Create temporary test file
|
||||
with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as tmp:
|
||||
test_path = tmp.name
|
||||
|
||||
try:
|
||||
# Create test image
|
||||
create_test_16bit_tiff(test_path)
|
||||
|
||||
# Create YOLOWrapper instance
|
||||
print("\nTesting YOLOWrapper._prepare_source() for float32 passthrough...")
|
||||
wrapper = YOLOWrapper()
|
||||
|
||||
# Call _prepare_source to preprocess the image
|
||||
prepared_source, cleanup_path = wrapper._prepare_source(test_path)
|
||||
|
||||
print(f"\nPreprocessing result:")
|
||||
print(f" Original path: {test_path}")
|
||||
print(f" Prepared source type: {type(prepared_source)}")
|
||||
|
||||
# Verify it returns a numpy array (not a file path)
|
||||
if isinstance(prepared_source, np.ndarray):
|
||||
print(
|
||||
f"\n✓ SUCCESS: Prepared source is a numpy array (float32 passthrough)"
|
||||
)
|
||||
print(f" Shape: {prepared_source.shape}")
|
||||
print(f" Dtype: {prepared_source.dtype}")
|
||||
print(f" Min value: {prepared_source.min():.6f}")
|
||||
print(f" Max value: {prepared_source.max():.6f}")
|
||||
print(f" Mean value: {prepared_source.mean():.6f}")
|
||||
|
||||
# Verify it's float32 in [0, 1] range
|
||||
assert (
|
||||
prepared_source.dtype == np.float32
|
||||
), f"Expected float32, got {prepared_source.dtype}"
|
||||
assert (
|
||||
0.0 <= prepared_source.min() <= prepared_source.max() <= 1.0
|
||||
), f"Expected values in [0, 1], got [{prepared_source.min()}, {prepared_source.max()}]"
|
||||
|
||||
# Verify it has 3 channels (RGB)
|
||||
assert (
|
||||
prepared_source.shape[2] == 3
|
||||
), f"Expected 3 channels (RGB), got {prepared_source.shape[2]}"
|
||||
|
||||
# Verify no quantization to 256 levels (would happen with uint8 conversion)
|
||||
unique_values = len(np.unique(prepared_source))
|
||||
print(f" Unique values: {unique_values}")
|
||||
|
||||
# With float32, we should have much more than 256 unique values
|
||||
if unique_values > 256:
|
||||
print(f"\n✓ SUCCESS: Data has {unique_values} unique values (> 256)")
|
||||
print(f" This confirms NO uint8 quantization occurred!")
|
||||
else:
|
||||
print(f"\n✗ WARNING: Data has only {unique_values} unique values")
|
||||
print(f" This might indicate uint8 quantization happened")
|
||||
|
||||
# Sample some values to show precision
|
||||
print(f"\n Sample normalized values:")
|
||||
print(f" [50, 50]: {prepared_source[50, 50, 0]:.8f}")
|
||||
print(f" [100, 100]: {prepared_source[100, 100, 0]:.8f}")
|
||||
print(f" [150, 150]: {prepared_source[150, 150, 0]:.8f}")
|
||||
|
||||
# No cleanup needed since we returned array directly
|
||||
assert (
|
||||
cleanup_path is None
|
||||
), "Cleanup path should be None for float32 pass through"
|
||||
|
||||
print("\n✓ All float32 passthrough tests passed!")
|
||||
return True
|
||||
|
||||
else:
|
||||
print(f"\n✗ FAILED: Prepared source is a file path: {prepared_source}")
|
||||
print(f" This means data was saved to disk, not passed as float32 array")
|
||||
if cleanup_path and os.path.exists(cleanup_path):
|
||||
os.remove(cleanup_path)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ Test failed with error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
if os.path.exists(test_path):
|
||||
os.remove(test_path)
|
||||
print(f"\nCleaned up test file: {test_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_float32_passthrough()
|
||||
sys.exit(0 if success else 1)
|
||||
126
tests/test_yolo_16bit_preprocessing.py
Normal file
126
tests/test_yolo_16bit_preprocessing.py
Normal file
@@ -0,0 +1,126 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for YOLO preprocessing of 16-bit TIFF images.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import tifffile
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add parent directory to path to import modules
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from src.model.yolo_wrapper import YOLOWrapper
|
||||
from src.utils.image import Image
|
||||
from PIL import Image as PILImage
|
||||
|
||||
|
||||
def create_test_16bit_tiff(output_path: str) -> str:
|
||||
"""Create a test 16-bit grayscale TIFF file.
|
||||
|
||||
Args:
|
||||
output_path: Path where to save the test TIFF
|
||||
|
||||
Returns:
|
||||
Path to the created TIFF file
|
||||
"""
|
||||
# Create a 16-bit grayscale test image (200x200)
|
||||
# With values ranging from 0 to 65535 (full 16-bit range)
|
||||
height, width = 200, 200
|
||||
|
||||
# Create a gradient pattern
|
||||
test_data = np.zeros((height, width), dtype=np.uint16)
|
||||
for i in range(height):
|
||||
for j in range(width):
|
||||
# Create a diagonal gradient
|
||||
test_data[i, j] = int((i + j) / (height + width - 2) * 65535)
|
||||
|
||||
# Save as TIFF
|
||||
tifffile.imwrite(output_path, test_data)
|
||||
print(f"Created test 16-bit TIFF: {output_path}")
|
||||
print(f" Shape: {test_data.shape}")
|
||||
print(f" Dtype: {test_data.dtype}")
|
||||
print(f" Min value: {test_data.min()}")
|
||||
print(f" Max value: {test_data.max()}")
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def test_yolo_preprocessing():
|
||||
"""Test YOLO preprocessing of 16-bit TIFF images."""
|
||||
print("\n=== Testing YOLO Preprocessing of 16-bit TIFF ===")
|
||||
|
||||
# Create temporary test file
|
||||
with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as tmp:
|
||||
test_path = tmp.name
|
||||
|
||||
try:
|
||||
# Create test image
|
||||
create_test_16bit_tiff(test_path)
|
||||
|
||||
# Create YOLOWrapper instance (no actual model loading needed for this test)
|
||||
print("\nTesting YOLOWrapper._prepare_source()...")
|
||||
wrapper = YOLOWrapper()
|
||||
|
||||
# Call _prepare_source to preprocess the image
|
||||
prepared_path, cleanup_path = wrapper._prepare_source(test_path)
|
||||
|
||||
print(f"\nPreprocessing complete:")
|
||||
print(f" Original path: {test_path}")
|
||||
print(f" Prepared path: {prepared_path}")
|
||||
print(f" Cleanup path: {cleanup_path}")
|
||||
|
||||
# Verify the prepared image exists
|
||||
assert os.path.exists(prepared_path), "Prepared image should exist"
|
||||
|
||||
# Load the prepared image and verify it's uint8 RGB
|
||||
prepared_img = PILImage.open(prepared_path)
|
||||
print(f"\nPrepared image properties:")
|
||||
print(f" Mode: {prepared_img.mode}")
|
||||
print(f" Size: {prepared_img.size}")
|
||||
print(f" Format: {prepared_img.format}")
|
||||
|
||||
# Convert to numpy to check values
|
||||
img_array = np.array(prepared_img)
|
||||
print(f" Shape: {img_array.shape}")
|
||||
print(f" Dtype: {img_array.dtype}")
|
||||
print(f" Min value: {img_array.min()}")
|
||||
print(f" Max value: {img_array.max()}")
|
||||
print(f" Mean value: {img_array.mean():.2f}")
|
||||
|
||||
# Verify it's RGB uint8
|
||||
assert prepared_img.mode == "RGB", "Prepared image should be RGB"
|
||||
assert img_array.dtype == np.uint8, "Prepared image should be uint8"
|
||||
assert img_array.shape[2] == 3, "Prepared image should have 3 channels"
|
||||
assert (
|
||||
0 <= img_array.min() <= img_array.max() <= 255
|
||||
), "Values should be in [0, 255]"
|
||||
|
||||
# Cleanup prepared file if needed
|
||||
if cleanup_path and os.path.exists(cleanup_path):
|
||||
os.remove(cleanup_path)
|
||||
print(f"\nCleaned up prepared image: {cleanup_path}")
|
||||
|
||||
print("\n✓ All YOLO preprocessing tests passed!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ Test failed with error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
if os.path.exists(test_path):
|
||||
os.remove(test_path)
|
||||
print(f"Cleaned up test file: {test_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_yolo_preprocessing()
|
||||
sys.exit(0 if success else 1)
|
||||
Reference in New Issue
Block a user