6 Commits

SHA1 Message Date
c7e1271193 Adding file 2025-12-13 09:42:00 +02:00
aec0fbf83c Adding standalone training script and update 2025-12-13 09:28:24 +02:00
908e9a5b82 Bug fix 2025-12-13 01:18:16 +02:00
edcd448a61 Update, cleanup 2025-12-13 01:06:40 +02:00
2411223a14 Adding test scripts 2025-12-13 00:32:32 +02:00
b3b1e3acff Implementing float32 data management 2025-12-13 00:31:23 +02:00
31 changed files with 4980 additions and 2691 deletions

config/app_config.yaml (new file, 57 lines)

@@ -0,0 +1,57 @@
database:
  path: data/detections.db
image_repository:
  base_path: ''
  allowed_extensions:
    - .jpg
    - .jpeg
    - .png
    - .tif
    - .tiff
    - .bmp
models:
  default_base_model: yolov8s-seg.pt
  models_directory: data/models
  base_model_choices:
    - yolov8s-seg.pt
    - yolo11s-seg.pt
training:
  default_epochs: 100
  default_batch_size: 16
  default_imgsz: 1024
  default_patience: 50
  default_lr0: 0.01
  two_stage:
    enabled: false
    stage1:
      epochs: 20
      lr0: 0.0005
      patience: 10
      freeze: 10
    stage2:
      epochs: 150
      lr0: 0.0003
      patience: 30
  last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml
  last_dataset_dir: /home/martin/code/object_detection/data/datasets
detection:
  default_confidence: 0.25
  default_iou: 0.45
  max_batch_size: 100
visualization:
  bbox_colors:
    organelle: '#FF6B6B'
    membrane_branch: '#4ECDC4'
    default: '#00FF00'
  bbox_thickness: 2
  font_size: 12
export:
  formats:
    - csv
    - json
    - excel
  default_format: csv
logging:
  level: INFO
  file: logs/app.log
  format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

docs/16BIT_TIFF_SUPPORT.md (new file, 300 lines)

@@ -0,0 +1,300 @@
# 16-bit TIFF Support for YOLO Object Detection
## Overview
This document describes the implementation of 16-bit grayscale TIFF support for YOLO object detection. The system loads 16-bit TIFF images, normalizes them to float32 [0-1], and handles them for both **inference** and **training** **without uint8 conversion**, preserving the full dynamic range and avoiding data loss.
## Key Features
✅ Reads 16-bit or float32 images using tifffile
✅ Converts to float32 [0-1] (NO uint8 conversion)
✅ Replicates grayscale → RGB (3 channels)
✅ **Inference**: Passes numpy arrays directly to YOLO (no file I/O)
✅ **Training**: On-the-fly float32 conversion (NO disk caching)
✅ Uses Ultralytics YOLOv8/v11 models
✅ Works with segmentation models
✅ No data loss, no double normalization, no silent clipping
## Changes Made
### 1. Dependencies ([`requirements.txt`](../requirements.txt:14))
- Added `tifffile>=2023.0.0` for reliable 16-bit TIFF loading
### 2. Image Loading ([`src/utils/image.py`](../src/utils/image.py))
#### Enhanced TIFF Loading
- Modified [`Image._load()`](../src/utils/image.py:87) to use `tifffile` for `.tif` and `.tiff` files
- Preserves original 16-bit data type during loading
- Properly handles both grayscale and multi-channel TIFF files
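For instance, loading with `tifffile` keeps the original dtype intact (a small illustrative check; the file name is hypothetical):

```python
import numpy as np
import tifffile

img = tifffile.imread("cells_16bit.tif")  # hypothetical 16-bit grayscale TIFF
print(img.dtype, img.shape)               # e.g. uint16 (2048, 2048)
assert img.dtype == np.uint16             # bit depth preserved at load time
```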
#### New Normalization Method
Added [`Image.to_normalized_float32()`](../src/utils/image.py:280) method that:
- Converts image data to `float32`
- Properly scales values to the [0, 1] range:
  - **16-bit images**: divides by 65535 (full dynamic range)
  - **8-bit images**: divides by 255
  - **Float images**: clips to [0, 1]
- Handles various data types automatically
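A minimal sketch of that normalization logic (illustrative only; the real method lives on the `Image` class and handles more edge cases):

```python
import numpy as np

def to_normalized_float32(data: np.ndarray) -> np.ndarray:
    """Scale image data to float32 in [0, 1] based on its dtype."""
    if data.dtype == np.uint16:
        return data.astype(np.float32) / 65535.0  # full 16-bit range
    if data.dtype == np.uint8:
        return data.astype(np.float32) / 255.0
    # Already float: clip instead of rescaling to avoid double normalization
    return np.clip(data.astype(np.float32), 0.0, 1.0)
```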
### 3. YOLO Preprocessing ([`src/model/yolo_wrapper.py`](../src/model/yolo_wrapper.py))
Enhanced [`YOLOWrapper._prepare_source()`](../src/model/yolo_wrapper.py:231) to:
1. Detect 16-bit TIFF files automatically
2. Load and normalize to float32 [0-1] using the new method
3. Replicate grayscale to RGB (3 channels)
4. **Return numpy array directly** (NO file saving, NO uint8 conversion)
5. Pass float32 array directly to YOLO for inference
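A hedged sketch of that flow (simplified; the helper name and branching here are illustrative, not the exact wrapper code):

```python
from pathlib import Path

import numpy as np
import tifffile

def prepare_source(source):
    """Return a float32 [0-1] RGB array for 16-bit TIFFs, else pass through."""
    path = Path(source)
    if path.suffix.lower() in {".tif", ".tiff"}:
        img = tifffile.imread(str(path))
        if img.dtype == np.uint16:
            img = img.astype(np.float32) / 65535.0  # normalize, no uint8 step
            if img.ndim == 2:
                img = np.stack([img] * 3, axis=-1)  # grayscale -> RGB
            return img                              # numpy array, no file I/O
    return source  # everything else goes to YOLO unchanged
```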
## Processing Pipeline
### For Inference (predict)
For 16-bit TIFF files during inference:
1. **Load**: File loaded using `tifffile` → preserves 16-bit uint16 data
2. **Normalize**: Convert to float32 and scale to [0, 1]
```python
float_data = uint16_data.astype(np.float32) / 65535.0
```
3. **RGB Conversion**: Replicate grayscale to 3 channels
```python
rgb_float = np.stack([float_data] * 3, axis=-1)
```
4. **Pass to YOLO**: Return float32 array directly (no uint8, no file I/O)
5. **Inference**: YOLO processes the float32 [0-1] RGB array
### For Training (train)
Training now uses a custom dataset loader with on-the-fly conversion (NO disk caching):
1. **Custom Dataset**: Uses `Float32Dataset` class that extends Ultralytics' `YOLODataset`
2. **Load On-The-Fly**: Each image is loaded and converted during training:
- Detect 16-bit TIFF files automatically
- Load with `tifffile` (preserves uint16)
- Convert to float32 [0-1] in memory
- Replicate to 3 channels (RGB)
3. **No Disk Cache**: Conversion happens in memory, no files written
4. **Train**: YOLO trains on float32 [0-1] RGB arrays directly
See [`src/utils/train_ultralytics_float.py`](../src/utils/train_ultralytics_float.py) for implementation.
### No Data Loss!
Unlike approaches that convert to uint8 (256 levels), this implementation:
- Preserves full 16-bit dynamic range (65536 levels)
- Maintains precision with float32 representation
- For inference: passes data directly without file conversions
- For training: uses float32 TIFFs (not uint8 PNGs)
## Usage
### Basic Image Loading
```python
from src.utils.image import Image
# Load a 16-bit TIFF file
img = Image("path/to/16bit_image.tif")
# Get normalized float32 data [0-1]
normalized = img.to_normalized_float32() # Shape: (H, W), dtype: float32
# Original data is preserved
original = img.data # Still uint16
```
### YOLO Inference
The preprocessing is automatic - just use YOLO as normal:
```python
from src.model.yolo_wrapper import YOLOWrapper
# Initialize model
yolo = YOLOWrapper("yolov8s-seg.pt")
yolo.load_model()
# Perform inference on 16-bit TIFF
# The image will be automatically normalized and passed as float32 [0-1]
detections = yolo.predict("path/to/16bit_image.tif", conf=0.25)
```
### With InferenceEngine
```python
from src.model.inference import InferenceEngine
from src.database.db_manager import DatabaseManager
# Setup
db = DatabaseManager("database.db")
engine = InferenceEngine("model.pt", db, model_id=1)
# Detect objects in 16-bit TIFF
result = engine.detect_single(
image_path="path/to/16bit_image.tif",
relative_path="images/16bit_image.tif",
conf=0.25
)
```
## Testing
Three test scripts are provided:
### 1. Image Loading Test
```bash
./venv/bin/python tests/test_16bit_tiff_loading.py
```
Tests:
- Loading 16-bit TIFF files with tifffile
- Normalization to float32 [0-1]
- Data type and value range verification
### 2. Float32 Passthrough Test (Most Important!)
```bash
./venv/bin/python tests/test_yolo_16bit_float32.py
```
Tests:
- YOLO preprocessing returns numpy array (not file path)
- Data is float32 [0-1] (not uint8)
- No quantization to 256 levels (proves no uint8 conversion)
- Sample output:
```
✓ SUCCESS: Prepared source is a numpy array (float32 passthrough)
Shape: (200, 200, 3)
Dtype: float32
Min value: 0.000000
Max value: 1.000000
Unique values: 399
✓ SUCCESS: Data has 399 unique values (> 256)
This confirms NO uint8 quantization occurred!
```
### 3. Legacy Test (Shows Old Behavior)
```bash
./venv/bin/python tests/test_yolo_16bit_preprocessing.py
```
This test shows the old behavior (uint8 conversion) - kept for comparison.
## Benefits
1. **No Data Loss**: Preserves full 16-bit dynamic range (65536 levels vs 256)
2. **High Precision**: Float32 maintains fine-grained intensity differences
3. **Automatic Processing**: No manual preprocessing needed
4. **YOLO Compatible**: Ultralytics YOLO accepts float32 [0-1] arrays
5. **Performance**: No intermediate file I/O for 16-bit TIFFs
6. **Backwards Compatible**: Regular images (8-bit PNG, JPEG, etc.) still work as before
## Technical Notes
### Float32 vs uint8
**With uint8 conversion (OLD - BAD):**
- 16-bit (65536 levels) → uint8 (256 levels) = **99.6% data loss!**
- Fine intensity differences are lost
- Quantization artifacts
**With float32 [0-1] (NEW - GOOD):**
- 16-bit (65536 levels) → float32 (continuous) = **No data loss**
- Full dynamic range preserved
- Smooth gradients maintained
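One quick way to see the difference (a self-contained snippet, not project code):

```python
import numpy as np

ramp = np.arange(0, 65536, dtype=np.uint16)   # every 16-bit level once
as_uint8 = (ramp / 257).astype(np.uint8)      # typical 16-to-8-bit scaling
as_float = ramp.astype(np.float32) / 65535.0

print(len(np.unique(as_uint8)))  # 256   -> 99.6% of levels collapsed
print(len(np.unique(as_float)))  # 65536 -> every level still distinct
```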
### Memory Considerations
For a 2048×2048 single-channel image:
| Format | Memory | Disk Space | Notes |
|--------|--------|------------|-------|
| Original 16-bit | 8 MB | ~8 MB | uint16 grayscale TIFF |
| Float32 grayscale | 16 MB | - | Intermediate |
| Float32 3-channel | 48 MB | - | In memory only (no disk cache) |
| uint8 RGB (old) | 12 MB | ~12 MB | OLD approach with data loss |
The float32 approach uses 4× the memory of uint8 RGB (48 MB vs 12 MB) during training but preserves **all information**.
**No Disk Cache**: The new on-the-fly approach eliminates the need for cached datasets on disk.
### Why Direct Numpy Array?
Passing numpy arrays directly to YOLO (instead of saving to file):
1. **Faster**: No disk I/O overhead
2. **No Quantization**: Avoids PNG/JPEG quantization
3. **Memory Efficient**: Single copy in memory
4. **Cleaner**: No temp file management
Ultralytics YOLO supports various input types:
- File paths (str): `"image.jpg"`
- Numpy arrays: `np.ndarray` ← **we use this**
- PIL Images: `PIL.Image`
- Torch tensors: `torch.Tensor`
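A minimal sketch of the numpy-array path (random data stands in for a real image; how plain Ultralytics normalizes float inputs can vary by version, which is why the wrapper owns the preprocessing):

```python
import numpy as np
from ultralytics import YOLO

model = YOLO("yolov8s-seg.pt")

# Hypothetical float32 [0-1] RGB array, e.g. produced by the preprocessing above
img = np.random.rand(640, 640, 3).astype(np.float32)

results = model.predict(img, conf=0.25)  # array in, no temp files written
```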
## Training with Float32 Dataset Loader
The system now includes a custom dataset loader for 16-bit TIFF training:
```python
from src.utils.train_ultralytics_float import train_with_float32_loader
# Train with on-the-fly float32 conversion
results = train_with_float32_loader(
model_path="yolov8s-seg.pt",
data_yaml="data/my_dataset/data.yaml",
epochs=100,
batch=16,
imgsz=640,
)
```
The `Float32Dataset` class automatically:
- Detects 16-bit TIFF files
- Loads with `tifffile` (not PIL/cv2)
- Converts to float32 [0-1] on-the-fly
- Replicates to 3 channels
- Integrates seamlessly with Ultralytics training pipeline
This is used automatically by the training tab in the GUI.
## Installation
Install the updated dependencies:
```bash
./venv/bin/pip install -r requirements.txt
```
Or install tifffile directly:
```bash
./venv/bin/pip install "tifffile>=2023.0.0"
```
## Example Test Output
```
=== Testing Float32 Passthrough (NO uint8) ===
Created test 16-bit TIFF: /tmp/tmpdt5hm0ab.tif
Shape: (200, 200)
Dtype: uint16
Min value: 0
Max value: 65535
Preprocessing result:
Prepared source type: <class 'numpy.ndarray'>
✓ SUCCESS: Prepared source is a numpy array (float32 passthrough)
Shape: (200, 200, 3)
Dtype: float32
Min value: 0.000000
Max value: 1.000000
Mean value: 0.499992
Unique values: 399
✓ SUCCESS: Data has 399 unique values (> 256)
This confirms NO uint8 quantization occurred!
✓ All float32 passthrough tests passed!
```

docs/TRAINING_16BIT_TIFF.md (new file, 269 lines)

@@ -0,0 +1,269 @@
# Training YOLO with 16-bit TIFF Datasets
## Quick Start
If your dataset contains 16-bit grayscale TIFF files, the training tab will automatically:
1. Detect 16-bit TIFF images in your dataset
2. Convert them to float32 [0-1] RGB **on-the-fly** during training
3. Train without any disk caching (memory-efficient)
**No manual intervention or disk space needed!**
## Why Float32 On-The-Fly Conversion?
### The Problem
YOLO's training expects:
- 3-channel images (RGB)
- Images loaded from disk by the dataloader
16-bit grayscale TIFFs are:
- 1-channel (grayscale), so they must be converted to 3-channel RGB
### The Solution
**NEW APPROACH (Current)**: On-the-fly float32 conversion
- Load 16-bit TIFF with `tifffile` (not PIL/cv2)
- Convert uint16 [0-65535] → float32 [0-1] in memory
- Replicate grayscale to 3 channels
- Pass directly to YOLO training pipeline
- **No disk caching required!**
**OLD APPROACH (Deprecated)**: Disk caching
- Created 16-bit RGB PNG cache files on disk
- Required ~2x dataset size in disk space
- Slower first training run
## How It Works
### Custom Dataset Loader
The system uses a custom `Float32Dataset` class that extends Ultralytics' `YOLODataset`:
```python
from src.utils.train_ultralytics_float import Float32Dataset
# This dataset loader:
# 1. Intercepts image loading
# 2. Detects 16-bit TIFFs
# 3. Converts to float32 [0-1] RGB on-the-fly
# 4. Passes to training pipeline
```
### Conversion Process
For each 16-bit grayscale TIFF during training:
```
1. Load with tifffile → uint16 [0, 65535]
2. Convert to float32 → img.astype(float32) / 65535.0
3. Replicate to RGB → np.stack([img] * 3, axis=-1)
4. Result: float32 [0, 1] RGB array, shape (H, W, 3)
```
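The same steps as a small function (a sketch, not the library code itself):

```python
import numpy as np
import tifffile

def load_16bit_as_float32_rgb(path: str) -> np.ndarray:
    img = tifffile.imread(path)             # uint16, shape (H, W)
    img = img.astype(np.float32) / 65535.0  # float32 in [0, 1]
    return np.stack([img] * 3, axis=-1)     # float32 RGB, shape (H, W, 3)
```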
### Memory vs Disk
| Aspect | On-the-fly (NEW) | Disk Cache (OLD) |
|--------|------------------|------------------|
| Disk Space | Dataset size only | ~2× dataset size |
| First Training | Fast | Slow (creates cache) |
| Subsequent Training | Fast | Fast |
| Data Loss | None | None |
| Setup Required | None | Cache creation |
## Data Preservation
### Float32 Precision
16-bit TIFF: 65,536 levels (0-65535)
Float32: ~7 decimal digits precision
**Conversion accuracy:**
```python
Original: 32768 (uint16, middle intensity)
Float32: 32768 / 65535 = 0.50000763 (every uint16 level maps to a distinct float32 value)
```
Full 16-bit precision is preserved in float32 representation.
### Comparison to uint8
| Approach | Precision Loss | Recommended |
|----------|----------------|-------------|
| **float32 [0-1]** | None | ✓ YES |
| uint16 RGB | None | ✓ YES (but disk-heavy) |
| uint8 | 99.6% data loss | ✗ NO |
**Why NO uint8:**
```
Original values: 32768, 32769, 32770 (distinct)
Converted to uint8: 128, 128, 128 (collapsed!)
```
Multiple 16-bit values collapse to the same uint8 value.
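The collapse is easy to reproduce (illustrative snippet using a simple high-byte conversion):

```python
import numpy as np

vals = np.array([32768, 32769, 32770], dtype=np.uint16)  # three distinct values
print((vals >> 8).astype(np.uint8))       # [128 128 128] -> collapsed
print(vals.astype(np.float32) / 65535.0)  # three distinct floats
```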
## Training Tab Behavior
When you click "Start Training" with a 16-bit TIFF dataset:
```
[01:23:45] Exported 150 annotations across 50 image(s).
[01:23:45] Using Float32 on-the-fly loader for 16-bit TIFF support (no disk caching)
[01:23:45] Starting training run 'my_model_v1' using yolov8s-seg.pt
[01:23:46] Using Float32Dataset loader for 16-bit TIFF support
```
Every training run uses the same approach - fast and efficient!
## Inference vs Training
| Operation | Input | Processing | Output to YOLO |
|-----------|-------|------------|----------------|
| **Inference** | 16-bit TIFF file | Load → float32 [0-1] → 3ch | numpy array (float32) |
| **Training** | 16-bit TIFF dataset | Load on-the-fly → float32 [0-1] → 3ch | numpy array (float32) |
Both preserve full 16-bit precision using float32 representation.
## Technical Details
### Custom Dataset Class
Located in `src/utils/train_ultralytics_float.py`:
```python
class Float32Dataset(YOLODataset):
"""
Extends Ultralytics YOLODataset to handle 16-bit TIFFs.
Key methods:
- load_image(): Intercepts image loading
- Detects .tif/.tiff with dtype == uint16
- Converts: uint16 → float32 [0-1] → RGB (3-channel)
"""
```
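A hedged sketch of what such an override could look like (the real class lives in `train_ultralytics_float.py`; the `load_image` signature and return shape follow recent Ultralytics releases and may differ in yours):

```python
import numpy as np
import tifffile
from ultralytics.data.dataset import YOLODataset  # import path may vary by version

class Float32Dataset(YOLODataset):
    def load_image(self, i, rect_mode=True):
        path = self.im_files[i]
        if path.lower().endswith((".tif", ".tiff")):
            img = tifffile.imread(path)
            if img.dtype == np.uint16:
                img = img.astype(np.float32) / 65535.0          # no uint8 step
                if img.ndim == 2:
                    img = np.repeat(img[..., None], 3, axis=2)  # 1ch -> 3ch
                h0, w0 = img.shape[:2]
                # NOTE: resizing/caching done by the base class is omitted here
                return img, (h0, w0), img.shape[:2]
        return super().load_image(i, rect_mode)  # default path for other files
```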
### Integration with YOLO
The `YOLOWrapper.train()` method automatically uses the custom loader:
```python
# In src/model/yolo_wrapper.py
def train(self, data_yaml, use_float32_loader=True, **kwargs):
if use_float32_loader:
# Use custom Float32Dataset
return train_with_float32_loader(...)
else:
# Standard YOLO training
return self.model.train(...)
```
### No PIL or cv2 for 16-bit
16-bit TIFF loading uses `tifffile` directly:
- PIL: Can load 16-bit but converts during processing
- cv2: Limited 16-bit TIFF support
- tifffile: Native 16-bit support, numpy output
## Advantages Over Disk Caching
### 1. No Disk Space Required
```
Dataset: 1000 images × 12 MB = 12 GB
Old cache: Additional 24 GB (16-bit RGB PNGs)
New approach: 0 GB additional (on-the-fly)
```
### 2. Faster Setup
```
Old: First training requires cache creation (minutes)
New: Start training immediately (seconds)
```
### 3. Always In Sync
```
Old: Cache could become stale if images change
New: Always loads current version from disk
```
### 4. Simpler Workflow
```
Old: Manage cache directory, cleanup, etc.
New: Just point to dataset and train
```
## Troubleshooting
### Error: "expected input to have 3 channels, but got 1"
This shouldn't happen with the new Float32Dataset, but if it does:
1. Check that `use_float32_loader=True` in training call
2. Verify `Float32Dataset` is being used (check logs)
3. Ensure `tifffile` is installed: `pip install tifffile`
### Memory Usage
On-the-fly conversion uses memory during training:
- Image loaded: ~8 MB (2048×2048 uint16)
- Converted float32 RGB: ~48 MB (temporary)
- Released after augmentation pipeline
**Mitigation:**
- Reduce batch size if OOM errors occur
- Images are processed one at a time during loading
- Only active batch kept in memory
### Slow Training
If training seems slow:
- Check disk I/O (slow disk can bottleneck loading)
- Verify images aren't being re-converted each epoch (should cache after first load)
- Monitor CPU usage during loading
## Migration from Old Approach
If you have existing cached datasets:
```bash
# Old cache location (safe to delete)
rm -rf data/datasets/_float32_cache/
# The new approach doesn't use this directory
```
Your original dataset structure remains unchanged:
```
data/my_dataset/
├── train/
│ ├── images/ (original 16-bit TIFFs)
│ └── labels/
├── val/
│ ├── images/
│ └── labels/
└── data.yaml
```
Just point to the same `data.yaml` and train!
## Performance Comparison
| Metric | Old (Disk Cache) | New (On-the-fly) |
|--------|------------------|------------------|
| First training setup | 5-10 min | 0 sec |
| Disk space overhead | 100% | 0% |
| Training speed | Fast | Fast |
| Subsequent runs | Fast | Fast |
| Data accuracy | 16-bit preserved | 16-bit preserved |
## Summary
**On-the-fly conversion**: Load and convert during training
**No disk caching**: Zero additional disk space
**Full precision**: Float32 preserves 16-bit dynamic range
**No PIL/cv2**: Direct tifffile loading
**Automatic**: Works transparently with training tab
**Fast**: Efficient memory-based conversion
The new approach is simpler, faster to set up, and requires no disk space overhead!

pyproject.toml (modified)

@@ -82,12 +82,12 @@ include-package-data = true
"src.database" = ["*.sql"]
[tool.black]
line-length = 120
line-length = 88
target-version = ['py38', 'py39', 'py310', 'py311']
include = '\.pyi?$'
[tool.pylint.messages_control]
max-line-length = 120
max-line-length = 88
[tool.mypy]
python_version = "3.8"

requirements.txt (modified)

@@ -11,6 +11,7 @@ pyqtgraph>=0.13.0
opencv-python>=4.8.0
Pillow>=10.0.0
numpy>=1.24.0
tifffile>=2023.0.0
# Database
sqlalchemy>=2.0.0

Standalone training script documentation (new file, 179 lines)

@@ -0,0 +1,179 @@
# Standalone Float32 Training Script for 16-bit TIFFs
## Overview
This standalone script (`train_float32_standalone.py`) trains YOLO models on 16-bit grayscale TIFF datasets with **no data loss**.
- Loads 16-bit TIFFs with `tifffile` (not PIL/cv2)
- Converts to float32 [0-1] on-the-fly (preserves full 16-bit precision)
- Replicates grayscale → 3-channel RGB in memory
- **No disk caching required**
- Uses custom PyTorch Dataset + training loop
## Quick Start
```bash
# Activate virtual environment
source venv/bin/activate
# Train on your 16-bit TIFF dataset
python scripts/train_float32_standalone.py \
--data data/my_dataset/data.yaml \
--weights yolov8s-seg.pt \
--epochs 100 \
--batch 16 \
--imgsz 640 \
--lr 0.0001 \
--save-dir runs/my_training \
--device cuda
```
## Arguments
| Argument | Required | Default | Description |
|----------|----------|---------|-------------|
| `--data` | Yes | - | Path to YOLO data.yaml file |
| `--weights` | No | yolov8s-seg.pt | Pretrained model weights |
| `--epochs` | No | 100 | Number of training epochs |
| `--batch` | No | 16 | Batch size |
| `--imgsz` | No | 640 | Input image size |
| `--lr` | No | 0.0001 | Learning rate |
| `--save-dir` | No | runs/train | Directory to save checkpoints |
| `--device` | No | cuda/cpu | Training device (auto-detected) |
## Dataset Format
Your data.yaml should follow standard YOLO format:
```yaml
path: /path/to/dataset
train: train/images
val: val/images
test: test/images # optional
names:
0: class1
1: class2
nc: 2
```
Directory structure:
```
dataset/
├── train/
│ ├── images/
│ │ ├── img1.tif (16-bit grayscale TIFF)
│ │ └── img2.tif
│ └── labels/
│ ├── img1.txt (YOLO format)
│ └── img2.txt
├── val/
│ ├── images/
│ └── labels/
└── data.yaml
```
## Output
The script saves:
- `epoch{N}.pt`: Checkpoint after each epoch
- `best.pt`: Best model weights (lowest loss)
- Training logs to console
## Features
**16-bit precision preserved**: Float32 [0-1] maintains full dynamic range
**No disk caching**: Conversion happens in memory
**No PIL/cv2**: Direct tifffile loading
**Variable-length labels**: Handles segmentation polygons
**Checkpoint saving**: Resume training if interrupted
**Best model tracking**: Automatically saves best weights
## Example
Train a segmentation model on microscopy data:
```bash
python scripts/train_float32_standalone.py \
--data data/microscopy/data.yaml \
--weights yolov11s-seg.pt \
--epochs 150 \
--batch 8 \
--imgsz 1024 \
--lr 0.0003 \
--save-dir data/models/microscopy_v1
```
## Troubleshooting
### Out of Memory (OOM)
Reduce batch size:
```bash
--batch 4
```
### Slow Loading
Reduce num_workers (edit script line 208):
```python
num_workers=2 # instead of 4
```
### Different Image Sizes
The script expects all images to have the same dimensions. For variable sizes:
1. Implement letterbox/resize in dataset's `_read_image()` (see the sketch after this list)
2. Or preprocess images to same size
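If you only need a shape fix (not true letterboxing), a center pad/crop sketch could look like this; note that cropping or padding changes geometry, so labels would have to be adjusted accordingly (purely illustrative):

```python
import numpy as np

def pad_or_crop(img: np.ndarray, size: int = 640) -> np.ndarray:
    """Center-crop then zero-pad a float32 (H, W, 3) image to (size, size, 3)."""
    img = img[:size, :size]                   # crop if larger than target
    h, w = img.shape[:2]
    canvas = np.zeros((size, size, 3), dtype=np.float32)
    top, left = (size - h) // 2, (size - w) // 2
    canvas[top:top + h, left:left + w] = img  # pad if smaller
    return canvas
```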
### Loss Computation Errors
If you see "Cannot determine loss", the script may need adjustment for your Ultralytics version. Check:
```python
# In train() function, the preds format may vary
# Current script assumes: preds is tuple with loss OR dict with 'loss' key
```
## vs GUI Training
| Feature | Standalone Script | GUI Training Tab |
|---------|------------------|------------------|
| Float32 conversion | ✓ Yes | ✓ Yes (automatic) |
| Disk caching | ✗ None | ✗ None |
| Progress UI | ✗ Console only | ✓ Visual progress bar |
| Dataset selection | Manual CLI args | ✓ GUI browsing |
| Multi-stage training | Manual runs | ✓ Built-in |
| Use case | Advanced users | General users |
## Technical Details
### Data Loading Pipeline
```
16-bit TIFF file
↓ (tifffile.imread)
uint16 [0-65535]
↓ (/ 65535.0)
float32 [0-1]
↓ (replicate channels)
float32 RGB (H,W,3) [0-1]
↓ (permute to C,H,W)
torch.Tensor (3,H,W) float32
↓ (DataLoader stack)
Batch (B,3,H,W) float32
YOLO Model
```
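The shape transitions can be checked in a few lines (random data stands in for a loaded image):

```python
import numpy as np
import torch

img = np.random.rand(640, 640, 3).astype(np.float32)     # float32 HWC in [0, 1]
t = torch.from_numpy(img).permute(2, 0, 1).contiguous()   # -> (3, H, W)
batch = torch.stack([t, t], dim=0)                        # -> (2, 3, H, W)
print(batch.shape, batch.dtype)  # torch.Size([2, 3, 640, 640]) torch.float32
```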
### Precision Comparison
| Method | Unique Values | Data Loss |
|--------|---------------|-----------|
| **float32 [0-1]** | ~65,536 | None ✓ |
| uint16 RGB | 65,536 | None ✓ |
| uint8 | 256 | 99.6% ✗ |
Example: Pixel value 32,768 (middle intensity)
- Float32: 32768 / 65535.0 = 0.50000763 (exact)
- uint8: 32768 → 128 → many values collapse!
## License
Same as main project.

scripts/train_float32_standalone.py (new file, 349 lines)

@@ -0,0 +1,349 @@
#!/usr/bin/env python3
"""
Standalone training script for YOLO with 16-bit TIFF float32 support.

This script trains YOLO models on 16-bit grayscale TIFF datasets without data loss.
Converts images to float32 [0-1] on-the-fly using tifffile (no PIL/cv2).

Usage:
    python scripts/train_float32_standalone.py \\
        --data path/to/data.yaml \\
        --weights yolov8s-seg.pt \\
        --epochs 100 \\
        --batch 16 \\
        --imgsz 640

Based on the custom dataset approach to avoid Ultralytics' channel conversion issues.
"""

import argparse
import os
import sys
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import tifffile
import yaml
from torch.utils.data import Dataset, DataLoader
from ultralytics import YOLO

# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from src.utils.logger import get_logger

logger = get_logger(__name__)


# ===================== Dataset =====================

class Float32YOLODataset(Dataset):
    """PyTorch dataset for 16-bit TIFF images with float32 conversion."""

    def __init__(self, images_dir, labels_dir, img_size=640):
        self.images_dir = Path(images_dir)
        self.labels_dir = Path(labels_dir)
        self.img_size = img_size

        # Find images
        extensions = {".tif", ".tiff", ".png", ".jpg", ".jpeg", ".bmp"}
        self.paths = sorted(
            [
                p
                for p in self.images_dir.rglob("*")
                if p.is_file() and p.suffix.lower() in extensions
            ]
        )
        if not self.paths:
            raise ValueError(f"No images found in {images_dir}")
        logger.info(f"Dataset: {len(self.paths)} images from {images_dir}")

    def __len__(self):
        return len(self.paths)

    def _read_image(self, path: Path) -> np.ndarray:
        """Load image as float32 [0-1] RGB."""
        # Load with tifffile (NOTE: tifffile reads TIFF; the other extensions
        # listed above would need a different reader)
        img = tifffile.imread(str(path))
        # Normalize to [0, 1] based on source bit depth (a max()-based
        # heuristic would mis-scale 8-bit images by 1/65535)
        if img.dtype == np.uint16:
            img = img.astype(np.float32) / 65535.0
        elif img.dtype == np.uint8:
            img = img.astype(np.float32) / 255.0
        else:
            img = img.astype(np.float32)
        img = np.clip(img, 0.0, 1.0)
        # Grayscale → RGB
        if img.ndim == 2:
            img = np.repeat(img[..., None], 3, axis=2)
        elif img.ndim == 3 and img.shape[2] == 1:
            img = np.repeat(img, 3, axis=2)
        return img  # float32 (H, W, 3) [0, 1]

    def _parse_label(self, path: Path) -> np.ndarray:
        """Parse YOLO label file; rows may have variable length (polygons)."""
        if not path.exists():
            return np.zeros((0, 5), dtype=np.float32)
        labels = []
        with open(path, "r") as f:
            for line in f:
                vals = line.strip().split()
                if len(vals) >= 5:
                    labels.append([float(v) for v in vals])
        if not labels:
            return np.zeros((0, 5), dtype=np.float32)
        # Segmentation rows can differ in length; pad with zeros so they
        # stack into a single float32 array
        width = max(len(row) for row in labels)
        out = np.zeros((len(labels), width), dtype=np.float32)
        for i, row in enumerate(labels):
            out[i, : len(row)] = row
        return out

    def __getitem__(self, idx):
        img_path = self.paths[idx]
        label_path = self.labels_dir / f"{img_path.stem}.txt"
        # Load & convert to tensor (C, H, W)
        img = self._read_image(img_path)
        img_t = torch.from_numpy(img).permute(2, 0, 1).contiguous()
        # Load labels
        labels = self._parse_label(label_path)
        return img_t, labels, str(img_path.name)


# ===================== Collate =====================

def collate_fn(batch):
    """Stack images, keep labels as list."""
    imgs = torch.stack([b[0] for b in batch], dim=0)
    labels = [b[1] for b in batch]
    names = [b[2] for b in batch]
    return imgs, labels, names


# ===================== Training =====================

def get_pytorch_model(ul_model):
    """Extract PyTorch model and loss from Ultralytics wrapper."""
    pt_model = None
    loss_fn = None
    # Try common patterns
    if hasattr(ul_model, "model"):
        pt_model = ul_model.model
    if pt_model and hasattr(pt_model, "model"):
        pt_model = pt_model.model
    # Find loss
    if pt_model and hasattr(pt_model, "loss"):
        loss_fn = pt_model.loss
    elif pt_model and hasattr(pt_model, "compute_loss"):
        loss_fn = pt_model.compute_loss
    if pt_model is None:
        raise RuntimeError("Could not extract PyTorch model")
    return pt_model, loss_fn


def train(args):
    """Main training function."""
    device = args.device
    logger.info(f"Device: {device}")

    # Parse data.yaml
    with open(args.data, "r") as f:
        data_config = yaml.safe_load(f)
    dataset_root = Path(data_config.get("path", Path(args.data).parent))
    train_img = dataset_root / data_config.get("train", "train/images")
    val_img = dataset_root / data_config.get("val", "val/images")
    train_lbl = train_img.parent / "labels"
    val_lbl = val_img.parent / "labels"

    # Load model
    logger.info(f"Loading {args.weights}")
    ul_model = YOLO(args.weights)
    pt_model, loss_fn = get_pytorch_model(ul_model)

    # Configure model args
    from types import SimpleNamespace

    if not hasattr(pt_model, "args"):
        pt_model.args = SimpleNamespace()
    if isinstance(pt_model.args, dict):
        pt_model.args = SimpleNamespace(**pt_model.args)
    # Set segmentation loss args
    pt_model.args.overlap_mask = getattr(pt_model.args, "overlap_mask", True)
    pt_model.args.mask_ratio = getattr(pt_model.args, "mask_ratio", 4)
    pt_model.args.task = "segment"
    pt_model.to(device)
    pt_model.train()

    # Create datasets
    train_ds = Float32YOLODataset(str(train_img), str(train_lbl), args.imgsz)
    val_ds = Float32YOLODataset(str(val_img), str(val_lbl), args.imgsz)
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch,
        shuffle=True,
        num_workers=4,
        pin_memory=(device == "cuda"),
        collate_fn=collate_fn,
    )
    # NOTE: val_loader is created for a future validation pass; it is not
    # used in the loop below
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch,
        shuffle=False,
        num_workers=2,
        pin_memory=(device == "cuda"),
        collate_fn=collate_fn,
    )

    # Optimizer
    optimizer = torch.optim.AdamW(pt_model.parameters(), lr=args.lr)

    # Training loop
    os.makedirs(args.save_dir, exist_ok=True)
    best_loss = float("inf")
    for epoch in range(args.epochs):
        t0 = time.time()
        running_loss = 0.0
        num_batches = 0
        for imgs, labels_list, names in train_loader:
            imgs = imgs.to(device)
            optimizer.zero_grad()
            num_batches += 1

            # Forward (simple approach - just use preds)
            preds = pt_model(imgs)

            # Try to compute loss.
            # Simplest fallback: if preds is tuple/list, assume last element is loss
            if isinstance(preds, (tuple, list)):
                # Often YOLO forward returns (preds, loss) in training mode
                if (
                    len(preds) >= 2
                    and isinstance(preds[-1], dict)
                    and "loss" in preds[-1]
                ):
                    loss = preds[-1]["loss"]
                elif len(preds) >= 2 and isinstance(preds[-1], torch.Tensor):
                    loss = preds[-1]
                else:
                    # Manually compute using loss_fn if available
                    if loss_fn:
                        # This may fail - see logs
                        try:
                            loss_out = loss_fn(preds, labels_list)
                            loss = (
                                loss_out[0]
                                if isinstance(loss_out, (tuple, list))
                                else loss_out
                            )
                        except Exception as e:
                            logger.error(f"Loss computation failed: {e}")
                            logger.error(
                                "Consider using Ultralytics .train() or check model/loss compatibility"
                            )
                            raise
                    else:
                        raise RuntimeError("Cannot determine loss from model output")
            elif isinstance(preds, dict) and "loss" in preds:
                loss = preds["loss"]
            else:
                raise RuntimeError(f"Unexpected preds format: {type(preds)}")

            # Backward
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if (num_batches % 10) == 0:
                logger.info(
                    f"Epoch {epoch+1} Batch {num_batches} Loss: {loss.item():.4f}"
                )

        epoch_loss = running_loss / max(1, num_batches)
        epoch_time = time.time() - t0
        logger.info(
            f"Epoch {epoch+1}/{args.epochs} - Loss: {epoch_loss:.4f}, Time: {epoch_time:.1f}s"
        )

        # Save checkpoint
        ckpt = Path(args.save_dir) / f"epoch{epoch+1}.pt"
        torch.save(
            {
                "epoch": epoch + 1,
                "model_state_dict": pt_model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": epoch_loss,
            },
            ckpt,
        )
        # Save best
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            best_ckpt = Path(args.save_dir) / "best.pt"
            torch.save(pt_model.state_dict(), best_ckpt)
            logger.info(f"New best: {best_ckpt}")

    logger.info("Training complete")


# ===================== Main =====================

def parse_args():
    parser = argparse.ArgumentParser(
        description="Train YOLO on 16-bit TIFF with float32"
    )
    parser.add_argument("--data", type=str, required=True, help="Path to data.yaml")
    parser.add_argument(
        "--weights", type=str, default="yolov8s-seg.pt", help="Pretrained weights"
    )
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument("--batch", type=int, default=16, help="Batch size")
    parser.add_argument("--imgsz", type=int, default=640, help="Image size")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument(
        "--save-dir", type=str, default="runs/train", help="Save directory"
    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    logger.info("=" * 70)
    logger.info("Float32 16-bit TIFF Training - Standalone Script")
    logger.info("=" * 70)
    logger.info(f"Data: {args.data}")
    logger.info(f"Weights: {args.weights}")
    logger.info(f"Epochs: {args.epochs}, Batch: {args.batch}, ImgSz: {args.imgsz}")
    logger.info(f"LR: {args.lr}, Device: {args.device}")
    logger.info("=" * 70)
    train(args)

src/database/db_manager.py (modified)

@@ -60,7 +60,9 @@ class DatabaseManager:
cursor = conn.cursor()
# Check if annotations table exists
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'")
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'"
)
if not cursor.fetchone():
# Table doesn't exist yet, no migration needed
return
@@ -201,28 +203,6 @@ class DatabaseManager:
finally:
conn.close()
def delete_model(self, model_id: int) -> bool:
"""Delete a model from the database.
Note: detections referencing this model are deleted automatically via
the `detections.model_id` foreign key (ON DELETE CASCADE).
Args:
model_id: ID of the model to delete.
Returns:
True if a model row was deleted, False otherwise.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM models WHERE id = ?", (model_id,))
conn.commit()
return cursor.rowcount > 0
finally:
conn.close()
# ==================== Image Operations ====================
def add_image(
@@ -262,7 +242,9 @@ class DatabaseManager:
return cursor.lastrowid
except sqlite3.IntegrityError:
# Image already exists, return its ID
cursor.execute("SELECT id FROM images WHERE relative_path = ?", (relative_path,))
cursor.execute(
"SELECT id FROM images WHERE relative_path = ?", (relative_path,)
)
row = cursor.fetchone()
return row["id"] if row else None
finally:
@@ -273,13 +255,17 @@ class DatabaseManager:
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM images WHERE relative_path = ?", (relative_path,))
cursor.execute(
"SELECT * FROM images WHERE relative_path = ?", (relative_path,)
)
row = cursor.fetchone()
return dict(row) if row else None
finally:
conn.close()
def get_or_create_image(self, relative_path: str, filename: str, width: int, height: int) -> int:
def get_or_create_image(
self, relative_path: str, filename: str, width: int, height: int
) -> int:
"""Get existing image or create new one."""
existing = self.get_image_by_path(relative_path)
if existing:
@@ -369,8 +355,16 @@ class DatabaseManager:
bbox[2],
bbox[3],
det["confidence"],
(json.dumps(det.get("segmentation_mask")) if det.get("segmentation_mask") else None),
(json.dumps(det.get("metadata")) if det.get("metadata") else None),
(
json.dumps(det.get("segmentation_mask"))
if det.get("segmentation_mask")
else None
),
(
json.dumps(det.get("metadata"))
if det.get("metadata")
else None
),
),
)
conn.commit()
@@ -415,13 +409,12 @@ class DatabaseManager:
if filters:
conditions = []
for key, value in filters.items():
if key.startswith("d.") or key.startswith("i.") or key.startswith("m."):
if "like" in value.lower():
conditions.append(f"{key} LIKE ?")
params.append(value.split(" ")[1])
else:
if (
key.startswith("d.")
or key.startswith("i.")
or key.startswith("m.")
):
conditions.append(f"{key} = ?")
params.append(value)
else:
conditions.append(f"d.{key} = ?")
params.append(value)
@@ -449,14 +442,18 @@ class DatabaseManager:
finally:
conn.close()
def get_detections_for_image(self, image_id: int, model_id: Optional[int] = None) -> List[Dict]:
def get_detections_for_image(
self, image_id: int, model_id: Optional[int] = None
) -> List[Dict]:
"""Get all detections for a specific image."""
filters = {"image_id": image_id}
if model_id:
filters["model_id"] = model_id
return self.get_detections(filters)
def delete_detections_for_image(self, image_id: int, model_id: Optional[int] = None) -> int:
def delete_detections_for_image(
self, image_id: int, model_id: Optional[int] = None
) -> int:
"""Delete detections tied to a specific image and optional model."""
conn = self.get_connection()
try:
@@ -484,22 +481,6 @@ class DatabaseManager:
finally:
conn.close()
def delete_all_detections(self) -> int:
"""Delete all detections from the database.
Returns:
Number of rows deleted.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM detections")
conn.commit()
return cursor.rowcount
finally:
conn.close()
# ==================== Statistics Operations ====================
def get_detection_statistics(
@@ -543,7 +524,9 @@ class DatabaseManager:
""",
params,
)
class_counts = {row["class_name"]: row["count"] for row in cursor.fetchall()}
class_counts = {
row["class_name"]: row["count"] for row in cursor.fetchall()
}
# Average confidence
cursor.execute(
@@ -600,7 +583,9 @@ class DatabaseManager:
# ==================== Export Operations ====================
def export_detections_to_csv(self, output_path: str, filters: Optional[Dict] = None) -> bool:
def export_detections_to_csv(
self, output_path: str, filters: Optional[Dict] = None
) -> bool:
"""Export detections to CSV file."""
try:
detections = self.get_detections(filters)
@@ -629,7 +614,9 @@ class DatabaseManager:
for det in detections:
row = {k: det[k] for k in fieldnames if k in det}
# Convert segmentation mask list to JSON string for CSV
if row.get("segmentation_mask") and isinstance(row["segmentation_mask"], list):
if row.get("segmentation_mask") and isinstance(
row["segmentation_mask"], list
):
row["segmentation_mask"] = json.dumps(row["segmentation_mask"])
writer.writerow(row)
@@ -638,7 +625,9 @@ class DatabaseManager:
print(f"Error exporting to CSV: {e}")
return False
def export_detections_to_json(self, output_path: str, filters: Optional[Dict] = None) -> bool:
def export_detections_to_json(
self, output_path: str, filters: Optional[Dict] = None
) -> bool:
"""Export detections to JSON file."""
try:
detections = self.get_detections(filters)
@@ -658,75 +647,6 @@ class DatabaseManager:
# ==================== Annotation Operations ====================
def get_annotated_images_summary(
self,
name_filter: Optional[str] = None,
order_by: str = "filename",
order_dir: str = "ASC",
limit: Optional[int] = None,
offset: int = 0,
) -> List[Dict]:
"""Return images that have at least one manual annotation.
Args:
name_filter: Optional substring filter applied to filename/relative_path.
order_by: One of: 'filename', 'relative_path', 'annotation_count', 'added_at'.
order_dir: 'ASC' or 'DESC'.
limit: Optional max number of rows.
offset: Pagination offset.
Returns:
List of dicts: {id, relative_path, filename, added_at, annotation_count}
"""
allowed_order_by = {
"filename": "i.filename",
"relative_path": "i.relative_path",
"annotation_count": "annotation_count",
"added_at": "i.added_at",
}
order_expr = allowed_order_by.get(order_by, "i.filename")
dir_norm = str(order_dir).upper().strip()
if dir_norm not in {"ASC", "DESC"}:
dir_norm = "ASC"
conn = self.get_connection()
try:
params: List[Any] = []
where_sql = ""
if name_filter:
# Case-insensitive substring search.
token = f"%{name_filter}%"
where_sql = "WHERE (i.filename LIKE ? OR i.relative_path LIKE ?)"
params.extend([token, token])
limit_sql = ""
if limit is not None:
limit_sql = " LIMIT ? OFFSET ?"
params.extend([int(limit), int(offset)])
query = f"""
SELECT
i.id,
i.relative_path,
i.filename,
i.added_at,
COUNT(a.id) AS annotation_count
FROM images i
JOIN annotations a ON a.image_id = i.id
{where_sql}
GROUP BY i.id
HAVING annotation_count > 0
ORDER BY {order_expr} {dir_norm}
{limit_sql}
"""
cursor = conn.cursor()
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
finally:
conn.close()
def add_annotation(
self,
image_id: int,
@@ -865,13 +785,17 @@ class DatabaseManager:
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM object_classes WHERE class_name = ?", (class_name,))
cursor.execute(
"SELECT * FROM object_classes WHERE class_name = ?", (class_name,)
)
row = cursor.fetchone()
return dict(row) if row else None
finally:
conn.close()
def add_object_class(self, class_name: str, color: str, description: Optional[str] = None) -> int:
def add_object_class(
self, class_name: str, color: str, description: Optional[str] = None
) -> int:
"""
Add a new object class.
@@ -1004,7 +928,8 @@ class DatabaseManager:
if not split_map[required]:
raise ValueError(
"Unable to determine %s image directory under %s. Provide it "
"explicitly via the 'splits' argument." % (required, dataset_root_path)
"explicitly via the 'splits' argument."
% (required, dataset_root_path)
)
yaml_splits: Dict[str, str] = {}
@@ -1030,7 +955,11 @@ class DatabaseManager:
if yaml_splits.get("test"):
payload["test"] = yaml_splits["test"]
output_path_obj = Path(output_path).expanduser() if output_path else dataset_root_path / "data.yaml"
output_path_obj = (
Path(output_path).expanduser()
if output_path
else dataset_root_path / "data.yaml"
)
output_path_obj.parent.mkdir(parents=True, exist_ok=True)
with open(output_path_obj, "w", encoding="utf-8") as handle:
@@ -1090,9 +1019,15 @@ class DatabaseManager:
for split_name, options in patterns.items():
for relative in options:
candidate = (dataset_root / relative).resolve()
if candidate.exists() and candidate.is_dir() and self._directory_has_images(candidate):
if (
candidate.exists()
and candidate.is_dir()
and self._directory_has_images(candidate)
):
try:
inferred[split_name] = candidate.relative_to(dataset_root).as_posix()
inferred[split_name] = candidate.relative_to(
dataset_root
).as_posix()
except ValueError:
inferred[split_name] = candidate.as_posix()
break

Database schema SQL (modified)

@@ -55,7 +55,10 @@ CREATE TABLE IF NOT EXISTS object_classes (
-- Insert default object classes
INSERT OR IGNORE INTO object_classes (class_name, color, description) VALUES
('terminal', '#FFFF00', 'Axon terminal');
('cell', '#FF0000', 'Cell object'),
('nucleus', '#00FF00', 'Cell nucleus'),
('mitochondria', '#0000FF', 'Mitochondria'),
('vesicle', '#FFFF00', 'Vesicle');
-- Annotations table: stores manual annotations
CREATE TABLE IF NOT EXISTS annotations (

src/gui/main_window.py (modified)

@@ -1,7 +1,6 @@
"""Main window for the microscopy object detection application."""
import shutil
from pathlib import Path
"""
Main window for the microscopy object detection application.
"""
from PySide6.QtWidgets import (
QMainWindow,
@@ -21,7 +20,6 @@ from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
from src.gui.dialogs.config_dialog import ConfigDialog
from src.gui.dialogs.delete_model_dialog import DeleteModelDialog
from src.gui.tabs.detection_tab import DetectionTab
from src.gui.tabs.training_tab import TrainingTab
from src.gui.tabs.validation_tab import ValidationTab
@@ -93,12 +91,6 @@ class MainWindow(QMainWindow):
db_stats_action.triggered.connect(self._show_database_stats)
tools_menu.addAction(db_stats_action)
tools_menu.addSeparator()
delete_model_action = QAction("Delete &Model…", self)
delete_model_action.triggered.connect(self._show_delete_model_dialog)
tools_menu.addAction(delete_model_action)
# Help menu
help_menu = menubar.addMenu("&Help")
@@ -125,10 +117,10 @@ class MainWindow(QMainWindow):
# Add tabs to widget
self.tab_widget.addTab(self.detection_tab, "Detection")
self.tab_widget.addTab(self.results_tab, "Results")
self.tab_widget.addTab(self.annotation_tab, "Annotation")
self.tab_widget.addTab(self.training_tab, "Training")
self.tab_widget.addTab(self.validation_tab, "Validation")
self.tab_widget.addTab(self.results_tab, "Results")
self.tab_widget.addTab(self.annotation_tab, "Annotation (Future)")
# Connect tab change signal
self.tab_widget.currentChanged.connect(self._on_tab_changed)
@@ -160,7 +152,9 @@ class MainWindow(QMainWindow):
"""Center window on screen."""
screen = self.screen().geometry()
size = self.geometry()
self.move((screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2)
self.move(
(screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2
)
def _restore_window_state(self):
"""Restore window geometry from settings or center window."""
@@ -199,10 +193,6 @@ class MainWindow(QMainWindow):
self.training_tab.refresh()
if hasattr(self, "results_tab"):
self.results_tab.refresh()
if hasattr(self, "annotation_tab"):
self.annotation_tab.refresh()
if hasattr(self, "validation_tab"):
self.validation_tab.refresh()
except Exception as e:
logger.error(f"Error applying settings: {e}")
@@ -219,14 +209,6 @@ class MainWindow(QMainWindow):
logger.debug(f"Switched to tab: {tab_name}")
self._update_status(f"Viewing: {tab_name}")
# Ensure the Annotation tab always shows up-to-date DB-backed lists.
try:
current_widget = self.tab_widget.widget(index)
if hasattr(self, "annotation_tab") and current_widget is self.annotation_tab:
self.annotation_tab.refresh()
except Exception as exc:
logger.debug(f"Failed to refresh annotation tab on selection: {exc}")
def _show_database_stats(self):
"""Show database statistics dialog."""
try:
@@ -249,230 +231,10 @@ class MainWindow(QMainWindow):
except Exception as e:
logger.error(f"Error getting database stats: {e}")
QMessageBox.warning(self, "Error", f"Failed to get database statistics:\n{str(e)}")
def _show_delete_model_dialog(self) -> None:
"""Open the model deletion dialog."""
dialog = DeleteModelDialog(self.db_manager, self)
if not dialog.exec():
return
model_ids = dialog.selected_model_ids
if not model_ids:
return
self._delete_models(model_ids)
def _delete_models(self, model_ids: list[int]) -> None:
"""Delete one or more models from the database and remove artifacts from disk."""
deleted_count = 0
removed_paths: list[str] = []
remove_errors: list[str] = []
for model_id in model_ids:
model = None
try:
model = self.db_manager.get_model_by_id(int(model_id))
except Exception as exc:
logger.error(f"Failed to load model {model_id} before deletion: {exc}")
if not model:
remove_errors.append(f"Model id {model_id} not found in database.")
continue
try:
deleted = self.db_manager.delete_model(int(model_id))
except Exception as exc:
logger.error(f"Failed to delete model {model_id}: {exc}")
remove_errors.append(f"Failed to delete model id {model_id} from DB: {exc}")
continue
if not deleted:
remove_errors.append(f"Model id {model_id} was not deleted (already removed?).")
continue
deleted_count += 1
removed, errors = self._delete_model_artifacts_from_disk(model)
removed_paths.extend(removed)
remove_errors.extend(errors)
# Refresh tabs to reflect the deletion(s).
try:
if hasattr(self, "detection_tab"):
self.detection_tab.refresh()
if hasattr(self, "results_tab"):
self.results_tab.refresh()
if hasattr(self, "validation_tab"):
self.validation_tab.refresh()
if hasattr(self, "training_tab"):
self.training_tab.refresh()
except Exception as exc:
logger.warning(f"Failed to refresh tabs after model deletion: {exc}")
details: list[str] = []
if removed_paths:
details.append("Removed from disk:\n" + "\n".join(removed_paths))
if remove_errors:
details.append("\nDisk cleanup warnings:\n" + "\n".join(remove_errors))
QMessageBox.information(
self,
"Delete Model",
f"Deleted {deleted_count} model(s) from database." + ("\n\n" + "\n".join(details) if details else ""),
QMessageBox.warning(
self, "Error", f"Failed to get database statistics:\n{str(e)}"
)
def _delete_model(self, model_id: int) -> None:
"""Delete a model from the database and remove its artifacts from disk."""
model = None
try:
model = self.db_manager.get_model_by_id(model_id)
except Exception as exc:
logger.error(f"Failed to load model {model_id} before deletion: {exc}")
if not model:
QMessageBox.warning(self, "Delete Model", "Selected model was not found in the database.")
return
model_path = str(model.get("model_path") or "")
try:
deleted = self.db_manager.delete_model(model_id)
except Exception as exc:
logger.error(f"Failed to delete model {model_id}: {exc}")
QMessageBox.critical(self, "Delete Model", f"Failed to delete model from database:\n{exc}")
return
if not deleted:
QMessageBox.warning(self, "Delete Model", "No model was deleted (it may have already been removed).")
return
removed_paths, remove_errors = self._delete_model_artifacts_from_disk(model)
# Refresh tabs to reflect the deletion.
try:
if hasattr(self, "detection_tab"):
self.detection_tab.refresh()
if hasattr(self, "results_tab"):
self.results_tab.refresh()
if hasattr(self, "validation_tab"):
self.validation_tab.refresh()
if hasattr(self, "training_tab"):
self.training_tab.refresh()
except Exception as exc:
logger.warning(f"Failed to refresh tabs after model deletion: {exc}")
details = []
if model_path:
details.append(f"Deleted model record for: {model_path}")
if removed_paths:
details.append("\nRemoved from disk:\n" + "\n".join(removed_paths))
if remove_errors:
details.append("\nDisk cleanup warnings:\n" + "\n".join(remove_errors))
QMessageBox.information(
self,
"Delete Model",
"Model deleted from database." + ("\n\n" + "\n".join(details) if details else ""),
)
def _delete_model_artifacts_from_disk(self, model: dict) -> tuple[list[str], list[str]]:
"""Best-effort removal of model artifacts on disk.
Strategy:
- Remove run directories inferred from:
- model.model_path (…/<run>/weights/*.pt => <run>)
- training_params.stage_results[].results.save_dir
but only if they are under the configured models directory.
- If the weights file itself exists and is outside the models directory, delete only the file.
Returns:
(removed_paths, errors)
"""
removed: list[str] = []
errors: list[str] = []
models_root = Path(self.config_manager.get_models_directory() or "data/models").expanduser()
try:
models_root_resolved = models_root.resolve()
except Exception:
models_root_resolved = models_root
inferred_dirs: list[Path] = []
# 1) From model_path
model_path_value = model.get("model_path")
if model_path_value:
try:
p = Path(str(model_path_value)).expanduser()
p_resolved = p.resolve() if p.exists() else p
if p_resolved.is_file():
if p_resolved.parent.name == "weights" and p_resolved.parent.parent.exists():
inferred_dirs.append(p_resolved.parent.parent)
elif p_resolved.parent.exists():
inferred_dirs.append(p_resolved.parent)
except Exception:
pass
# 2) From training_params.stage_results[].results.save_dir
training_params = model.get("training_params") or {}
if isinstance(training_params, dict):
stage_results = training_params.get("stage_results")
if isinstance(stage_results, list):
for stage in stage_results:
results = (stage or {}).get("results")
save_dir = (results or {}).get("save_dir") if isinstance(results, dict) else None
if not save_dir:
continue
try:
d = Path(str(save_dir)).expanduser()
if d.exists() and d.is_dir():
inferred_dirs.append(d)
except Exception:
continue
# Deduplicate inferred_dirs
unique_dirs: list[Path] = []
seen: set[str] = set()
for d in inferred_dirs:
try:
key = str(d.resolve())
except Exception:
key = str(d)
if key in seen:
continue
seen.add(key)
unique_dirs.append(d)
# Delete directories under models_root
for d in unique_dirs:
try:
d_resolved = d.resolve()
except Exception:
d_resolved = d
try:
if d_resolved.exists() and d_resolved.is_dir() and d_resolved.is_relative_to(models_root_resolved):
shutil.rmtree(d_resolved)
removed.append(str(d_resolved))
except Exception as exc:
errors.append(f"Failed to remove directory {d_resolved}: {exc}")
# If nothing matched (e.g., model_path outside models_root), delete just the file.
if model_path_value:
try:
p = Path(str(model_path_value)).expanduser()
if p.exists() and p.is_file():
p_resolved = p.resolve()
if not p_resolved.is_relative_to(models_root_resolved):
p_resolved.unlink()
removed.append(str(p_resolved))
except Exception as exc:
errors.append(f"Failed to remove model file {model_path_value}: {exc}")
return removed, errors
def _show_about(self):
"""Show about dialog."""
about_text = """
@@ -539,11 +301,6 @@ class MainWindow(QMainWindow):
if hasattr(self, "training_tab"):
self.training_tab.shutdown()
if hasattr(self, "annotation_tab"):
# Best-effort refresh so DB-backed UI state is consistent at shutdown.
try:
self.annotation_tab.refresh()
except Exception:
pass
self.annotation_tab.save_state()
logger.info("Application closing")

src/gui/tabs/annotation_tab.py (modified)

@@ -13,11 +13,6 @@ from PySide6.QtWidgets import (
QFileDialog,
QMessageBox,
QSplitter,
QLineEdit,
QTableWidget,
QTableWidgetItem,
QHeaderView,
QAbstractItemView,
)
from PySide6.QtCore import Qt, QSettings
from pathlib import Path
@@ -34,7 +29,9 @@ logger = get_logger(__name__)
class AnnotationTab(QWidget):
"""Annotation tab for manual image annotation."""
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
super().__init__(parent)
self.db_manager = db_manager
self.config_manager = config_manager
@@ -55,32 +52,6 @@ class AnnotationTab(QWidget):
self.main_splitter = QSplitter(Qt.Horizontal)
self.main_splitter.setHandleWidth(10)
# { Left-most pane: annotated images list
annotated_group = QGroupBox("Annotated Images")
annotated_layout = QVBoxLayout()
filter_row = QHBoxLayout()
filter_row.addWidget(QLabel("Filter:"))
self.annotated_filter_edit = QLineEdit()
self.annotated_filter_edit.setPlaceholderText("Type to filter by image name…")
self.annotated_filter_edit.textChanged.connect(self._refresh_annotated_images_list)
filter_row.addWidget(self.annotated_filter_edit, 1)
annotated_layout.addLayout(filter_row)
self.annotated_images_table = QTableWidget(0, 2)
self.annotated_images_table.setHorizontalHeaderLabels(["Image", "Annotations"])
self.annotated_images_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
self.annotated_images_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeToContents)
self.annotated_images_table.setSelectionBehavior(QAbstractItemView.SelectRows)
self.annotated_images_table.setSelectionMode(QAbstractItemView.SingleSelection)
self.annotated_images_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
self.annotated_images_table.setSortingEnabled(True)
self.annotated_images_table.itemSelectionChanged.connect(self._on_annotated_image_selected)
annotated_layout.addWidget(self.annotated_images_table, 1)
annotated_group.setLayout(annotated_layout)
# }
# { Left splitter for image display and zoom info
self.left_splitter = QSplitter(Qt.Vertical)
self.left_splitter.setHandleWidth(10)
@@ -91,9 +62,6 @@ class AnnotationTab(QWidget):
# Use the AnnotationCanvasWidget
self.annotation_canvas = AnnotationCanvasWidget()
# Auto-zoom so newly loaded images fill the available canvas viewport.
# (Matches the behavior used in ResultsTab.)
self.annotation_canvas.set_auto_fit_to_view(True)
self.annotation_canvas.zoom_changed.connect(self._on_zoom_changed)
self.annotation_canvas.annotation_drawn.connect(self._on_annotation_drawn)
# Selection of existing polylines (when tool is not in drawing mode)
@@ -104,7 +72,9 @@ class AnnotationTab(QWidget):
self.left_splitter.addWidget(canvas_group)
# Controls info
controls_info = QLabel("Zoom: Mouse wheel or +/- keys | Drawing: Enable pen and drag mouse")
controls_info = QLabel(
"Zoom: Mouse wheel or +/- keys | Drawing: Enable pen and drag mouse"
)
controls_info.setStyleSheet("QLabel { color: #888; font-style: italic; }")
self.left_splitter.addWidget(controls_info)
# }
@@ -115,20 +85,36 @@ class AnnotationTab(QWidget):
# Annotation tools section
self.annotation_tools = AnnotationToolsWidget(self.db_manager)
self.annotation_tools.polyline_enabled_changed.connect(self.annotation_canvas.set_polyline_enabled)
self.annotation_tools.polyline_pen_color_changed.connect(self.annotation_canvas.set_polyline_pen_color)
self.annotation_tools.polyline_pen_width_changed.connect(self.annotation_canvas.set_polyline_pen_width)
self.annotation_tools.polyline_enabled_changed.connect(
self.annotation_canvas.set_polyline_enabled
)
self.annotation_tools.polyline_pen_color_changed.connect(
self.annotation_canvas.set_polyline_pen_color
)
self.annotation_tools.polyline_pen_width_changed.connect(
self.annotation_canvas.set_polyline_pen_width
)
# Show / hide bounding boxes
self.annotation_tools.show_bboxes_changed.connect(self.annotation_canvas.set_show_bboxes)
self.annotation_tools.show_bboxes_changed.connect(
self.annotation_canvas.set_show_bboxes
)
# RDP simplification controls
self.annotation_tools.simplify_on_finish_changed.connect(self._on_simplify_on_finish_changed)
self.annotation_tools.simplify_epsilon_changed.connect(self._on_simplify_epsilon_changed)
self.annotation_tools.simplify_on_finish_changed.connect(
self._on_simplify_on_finish_changed
)
self.annotation_tools.simplify_epsilon_changed.connect(
self._on_simplify_epsilon_changed
)
# Class selection and class-color changes
self.annotation_tools.class_selected.connect(self._on_class_selected)
self.annotation_tools.class_color_changed.connect(self._on_class_color_changed)
self.annotation_tools.clear_annotations_requested.connect(self._on_clear_annotations)
self.annotation_tools.clear_annotations_requested.connect(
self._on_clear_annotations
)
# Delete selected annotation on canvas
self.annotation_tools.delete_selected_annotation_requested.connect(self._on_delete_selected_annotation)
self.annotation_tools.delete_selected_annotation_requested.connect(
self._on_delete_selected_annotation
)
self.right_splitter.addWidget(self.annotation_tools)
# Image loading section
@@ -151,13 +137,12 @@ class AnnotationTab(QWidget):
self.right_splitter.addWidget(load_group)
# }
# Add list + both splitters to the main horizontal splitter
self.main_splitter.addWidget(annotated_group)
# Add both splitters to the main horizontal splitter
self.main_splitter.addWidget(self.left_splitter)
self.main_splitter.addWidget(self.right_splitter)
# Set initial sizes: list (left), canvas (middle), controls (right)
self.main_splitter.setSizes([320, 650, 280])
# Set initial sizes: 75% for left (image), 25% for right (controls)
self.main_splitter.setSizes([750, 250])
layout.addWidget(self.main_splitter)
self.setLayout(layout)
@@ -165,9 +150,6 @@ class AnnotationTab(QWidget):
# Restore splitter positions from settings
self._restore_state()
# Populate list on startup.
self._refresh_annotated_images_list()
def _load_image(self):
"""Load and display an image file."""
# Get last opened directory from QSettings
@@ -198,24 +180,12 @@ class AnnotationTab(QWidget):
self.current_image_path = file_path
# Store the directory for next time
settings.setValue("annotation_tab/last_directory", str(Path(file_path).parent))
settings.setValue(
"annotation_tab/last_directory", str(Path(file_path).parent)
)
# Get or create image in database
repo_root = self.config_manager.get_image_repository_path()
relative_path: str
try:
if repo_root:
repo_root_path = Path(repo_root).expanduser().resolve()
file_resolved = Path(file_path).expanduser().resolve()
if file_resolved.is_relative_to(repo_root_path):
relative_path = file_resolved.relative_to(repo_root_path).as_posix()
else:
# Fallback: store filename only to avoid leaking absolute paths.
relative_path = file_resolved.name
else:
relative_path = str(Path(file_path).name)
except Exception:
relative_path = str(Path(file_path).name)
relative_path = str(Path(file_path).name) # Simplified for now
self.current_image_id = self.db_manager.get_or_create_image(
relative_path,
Path(file_path).name,
@@ -229,9 +199,6 @@ class AnnotationTab(QWidget):
# Load and display any existing annotations for this image
self._load_annotations_for_current_image()
# Update annotated images list (newly annotated image added/selected).
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
# Update info label
self._update_image_info()
@@ -239,7 +206,9 @@ class AnnotationTab(QWidget):
except ImageLoadError as e:
logger.error(f"Failed to load image: {e}")
QMessageBox.critical(self, "Error Loading Image", f"Failed to load image:\n{str(e)}")
QMessageBox.critical(
self, "Error Loading Image", f"Failed to load image:\n{str(e)}"
)
except Exception as e:
logger.error(f"Unexpected error loading image: {e}")
QMessageBox.critical(self, "Error", f"Unexpected error:\n{str(e)}")
@@ -327,9 +296,6 @@ class AnnotationTab(QWidget):
# Reload annotations from DB and redraw (respecting current class filter)
self._load_annotations_for_current_image()
# Update list counts.
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
except Exception as e:
logger.error(f"Failed to save annotation: {e}")
QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}")
@@ -374,7 +340,9 @@ class AnnotationTab(QWidget):
if not self.current_image_id:
return
logger.debug(f"Class color changed; reloading annotations for image ID {self.current_image_id}")
logger.debug(
f"Class color changed; reloading annotations for image ID {self.current_image_id}"
)
self._load_annotations_for_current_image()
def _on_class_selected(self, class_data):
@@ -387,7 +355,9 @@ class AnnotationTab(QWidget):
if class_data:
logger.debug(f"Object class selected: {class_data['class_name']}")
else:
logger.debug('No class selected ("-- Select Class --"), showing all annotations')
logger.debug(
'No class selected ("-- Select Class --"), showing all annotations'
)
# Changing the class filter invalidates any previous selection
self.selected_annotation_ids = []
@@ -420,7 +390,9 @@ class AnnotationTab(QWidget):
question = "Are you sure you want to delete the selected annotation?"
title = "Delete Annotation"
else:
question = f"Are you sure you want to delete the {count} selected annotations?"
question = (
f"Are you sure you want to delete the {count} selected annotations?"
)
title = "Delete Annotations"
reply = QMessageBox.question(
@@ -448,11 +420,13 @@ class AnnotationTab(QWidget):
QMessageBox.warning(
self,
"Partial Failure",
"Some annotations could not be deleted:\n" + ", ".join(str(a) for a in failed_ids),
"Some annotations could not be deleted:\n"
+ ", ".join(str(a) for a in failed_ids),
)
else:
logger.info(
f"Deleted {count} annotation(s): " + ", ".join(str(a) for a in self.selected_annotation_ids)
f"Deleted {count} annotation(s): "
+ ", ".join(str(a) for a in self.selected_annotation_ids)
)
# Clear selection and reload annotations for the current image from DB
@@ -460,9 +434,6 @@ class AnnotationTab(QWidget):
self.annotation_tools.set_has_selected_annotation(False)
self._load_annotations_for_current_image()
# Update list counts.
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
except Exception as e:
logger.error(f"Failed to delete annotations: {e}")
QMessageBox.critical(
@@ -485,13 +456,17 @@ class AnnotationTab(QWidget):
return
try:
self.current_annotations = self.db_manager.get_annotations_for_image(self.current_image_id)
self.current_annotations = self.db_manager.get_annotations_for_image(
self.current_image_id
)
# New annotations loaded; reset any selection
self.selected_annotation_ids = []
self.annotation_tools.set_has_selected_annotation(False)
self._redraw_annotations_for_current_filter()
except Exception as e:
logger.error(f"Failed to load annotations for image {self.current_image_id}: {e}")
logger.error(
f"Failed to load annotations for image {self.current_image_id}: {e}"
)
QMessageBox.critical(
self,
"Error",
@@ -515,7 +490,10 @@ class AnnotationTab(QWidget):
drawn_count = 0
for ann in self.current_annotations:
# Filter by class if one is selected
if selected_class_id is not None and ann.get("class_id") != selected_class_id:
if (
selected_class_id is not None
and ann.get("class_id") != selected_class_id
):
continue
if ann.get("segmentation_mask"):
@@ -567,176 +545,22 @@ class AnnotationTab(QWidget):
settings = QSettings("microscopy_app", "object_detection")
# Save main splitter state
settings.setValue("annotation_tab/main_splitter_state", self.main_splitter.saveState())
settings.setValue(
"annotation_tab/main_splitter_state", self.main_splitter.saveState()
)
# Save left splitter state
settings.setValue("annotation_tab/left_splitter_state", self.left_splitter.saveState())
settings.setValue(
"annotation_tab/left_splitter_state", self.left_splitter.saveState()
)
# Save right splitter state
settings.setValue("annotation_tab/right_splitter_state", self.right_splitter.saveState())
settings.setValue(
"annotation_tab/right_splitter_state", self.right_splitter.saveState()
)
logger.debug("Saved annotation tab splitter states")
def refresh(self):
"""Refresh the tab."""
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
# ==================== Annotated images list ====================
def _refresh_annotated_images_list(self, select_image_id: int | None = None) -> None:
"""Reload annotated-images list from the database."""
if not hasattr(self, "annotated_images_table"):
return
# Preserve selection if possible
desired_id = select_image_id if select_image_id is not None else self.current_image_id
name_filter = ""
if hasattr(self, "annotated_filter_edit"):
name_filter = self.annotated_filter_edit.text().strip()
try:
rows = self.db_manager.get_annotated_images_summary(name_filter=name_filter)
except Exception as exc:
logger.error(f"Failed to load annotated images summary: {exc}")
rows = []
sorting_enabled = self.annotated_images_table.isSortingEnabled()
self.annotated_images_table.setSortingEnabled(False)
self.annotated_images_table.blockSignals(True)
try:
self.annotated_images_table.setRowCount(len(rows))
for r, entry in enumerate(rows):
image_name = str(entry.get("filename") or "")
count = int(entry.get("annotation_count") or 0)
rel_path = str(entry.get("relative_path") or "")
name_item = QTableWidgetItem(image_name)
# Tooltip shows full path of the image (best-effort: repository_root + relative_path)
full_path = rel_path
repo_root = self.config_manager.get_image_repository_path()
if repo_root and rel_path and not Path(rel_path).is_absolute():
try:
full_path = str((Path(repo_root) / rel_path).resolve())
except Exception:
full_path = str(Path(repo_root) / rel_path)
name_item.setToolTip(full_path)
name_item.setData(Qt.UserRole, int(entry.get("id")))
name_item.setData(Qt.UserRole + 1, rel_path)
count_item = QTableWidgetItem()
# Use EditRole to ensure numeric sorting.
count_item.setData(Qt.EditRole, count)
count_item.setData(Qt.UserRole, int(entry.get("id")))
count_item.setData(Qt.UserRole + 1, rel_path)
self.annotated_images_table.setItem(r, 0, name_item)
self.annotated_images_table.setItem(r, 1, count_item)
# Re-select desired row
if desired_id is not None:
for r in range(self.annotated_images_table.rowCount()):
item = self.annotated_images_table.item(r, 0)
if item and item.data(Qt.UserRole) == desired_id:
self.annotated_images_table.selectRow(r)
break
finally:
self.annotated_images_table.blockSignals(False)
self.annotated_images_table.setSortingEnabled(sorting_enabled)
def _on_annotated_image_selected(self) -> None:
"""When user clicks an item in the list, load that image in the annotation canvas."""
selected = self.annotated_images_table.selectedItems()
if not selected:
return
# Row selection -> take the first column item
row = self.annotated_images_table.currentRow()
item = self.annotated_images_table.item(row, 0)
if not item:
return
image_id = item.data(Qt.UserRole)
rel_path = item.data(Qt.UserRole + 1) or ""
if not image_id:
return
image_path = self._resolve_image_path_for_relative_path(rel_path)
if not image_path:
QMessageBox.warning(
self,
"Image Not Found",
"Unable to locate image on disk for:\n"
f"{rel_path}\n\n"
"Tip: set Settings → Image repository path to the folder containing your images.",
)
return
try:
self.current_image = Image(image_path)
self.current_image_path = image_path
self.current_image_id = int(image_id)
self.annotation_canvas.load_image(self.current_image)
self._load_annotations_for_current_image()
self._update_image_info()
except ImageLoadError as exc:
logger.error(f"Failed to load image '{image_path}': {exc}")
QMessageBox.critical(self, "Error Loading Image", f"Failed to load image:\n{exc}")
except Exception as exc:
logger.error(f"Unexpected error loading image '{image_path}': {exc}")
QMessageBox.critical(self, "Error", f"Unexpected error:\n{exc}")
def _resolve_image_path_for_relative_path(self, relative_path: str) -> str | None:
"""Best-effort conversion from a DB relative_path to an on-disk file path."""
rel = (relative_path or "").strip()
if not rel:
return None
candidates: list[Path] = []
# 1) Repository root + relative
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
if repo_root:
candidates.append(Path(repo_root) / rel)
# 2) If the DB path is absolute, try it directly.
candidates.append(Path(rel))
# 3) Try the directory of the currently loaded image (helps when DB stores only filenames)
if self.current_image_path:
try:
candidates.append(Path(self.current_image_path).expanduser().resolve().parent / Path(rel).name)
except Exception:
pass
# 4) Try the last directory used by the annotation file picker
try:
settings = QSettings("microscopy_app", "object_detection")
last_dir = settings.value("annotation_tab/last_directory", None)
if last_dir:
candidates.append(Path(str(last_dir)) / Path(rel).name)
except Exception:
pass
for p in candidates:
try:
expanded = p.expanduser()
if expanded.exists() and expanded.is_file():
return str(expanded.resolve())
except Exception:
continue
# 5) Fallback: search by filename within repository root.
filename = Path(rel).name
if repo_root and filename:
root = Path(repo_root).expanduser()
try:
if root.exists():
for match in root.rglob(filename):
if match.is_file():
return str(match.resolve())
except Exception as exc:
logger.debug(f"Search for {filename} under {root} failed: {exc}")
return None
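# Illustrative resolution order (hypothetical paths): with repo_root
# "/data/images" and relative_path "train/img_001.tif", the candidates are
# "/data/images/train/img_001.tif", "train/img_001.tif" taken literally,
# <current image dir>/img_001.tif, and <last picker dir>/img_001.tif,
# before the final rglob("img_001.tif") search under the repository root.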

View File

@@ -3,7 +3,7 @@ Results tab for browsing stored detections and visualizing overlays.
"""
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional
from PySide6.QtWidgets import (
QWidget,
@@ -35,7 +35,9 @@ logger = get_logger(__name__)
class ResultsTab(QWidget):
"""Results tab showing detection history and preview overlays."""
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
super().__init__(parent)
self.db_manager = db_manager
self.config_manager = config_manager
@@ -65,32 +67,28 @@ class ResultsTab(QWidget):
self.refresh_btn = QPushButton("Refresh")
self.refresh_btn.clicked.connect(self.refresh)
controls_layout.addWidget(self.refresh_btn)
self.delete_all_btn = QPushButton("Delete All Detections")
self.delete_all_btn.setToolTip(
"Permanently delete ALL detections from the database.\n" "This cannot be undone."
)
self.delete_all_btn.clicked.connect(self._delete_all_detections)
controls_layout.addWidget(self.delete_all_btn)
self.export_labels_btn = QPushButton("Export Labels")
self.export_labels_btn.setToolTip(
"Export YOLO .txt labels for the selected image/model run.\n"
"Output path is inferred from the image path (images/ -> labels/)."
)
self.export_labels_btn.clicked.connect(self._export_labels_for_current_selection)
controls_layout.addWidget(self.export_labels_btn)
controls_layout.addStretch()
left_layout.addLayout(controls_layout)
self.results_table = QTableWidget(0, 5)
self.results_table.setHorizontalHeaderLabels(["Image", "Model", "Detections", "Classes", "Last Updated"])
self.results_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
self.results_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.Stretch)
self.results_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeToContents)
self.results_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.Stretch)
self.results_table.horizontalHeader().setSectionResizeMode(4, QHeaderView.ResizeToContents)
self.results_table.setHorizontalHeaderLabels(
["Image", "Model", "Detections", "Classes", "Last Updated"]
)
self.results_table.horizontalHeader().setSectionResizeMode(
0, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
1, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
2, QHeaderView.ResizeToContents
)
self.results_table.horizontalHeader().setSectionResizeMode(
3, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
4, QHeaderView.ResizeToContents
)
self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows)
self.results_table.setSelectionMode(QAbstractItemView.SingleSelection)
self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
@@ -108,8 +106,6 @@ class ResultsTab(QWidget):
preview_layout = QVBoxLayout()
self.preview_canvas = AnnotationCanvasWidget()
# Auto-zoom so newly loaded images fill the available preview viewport.
self.preview_canvas.set_auto_fit_to_view(True)
self.preview_canvas.set_polyline_enabled(False)
self.preview_canvas.set_show_bboxes(True)
preview_layout.addWidget(self.preview_canvas)
@@ -123,7 +119,9 @@ class ResultsTab(QWidget):
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
self.show_confidence_checkbox = QCheckBox("Show Confidence")
self.show_confidence_checkbox.setChecked(False)
self.show_confidence_checkbox.stateChanged.connect(self._apply_detection_overlays)
self.show_confidence_checkbox.stateChanged.connect(
self._apply_detection_overlays
)
toggles_layout.addWidget(self.show_masks_checkbox)
toggles_layout.addWidget(self.show_bboxes_checkbox)
toggles_layout.addWidget(self.show_confidence_checkbox)
@@ -146,41 +144,6 @@ class ResultsTab(QWidget):
layout.addWidget(splitter)
self.setLayout(layout)
def _delete_all_detections(self):
"""Delete all detections from the database after user confirmation."""
confirm = QMessageBox.warning(
self,
"Delete All Detections",
"This will permanently delete ALL detections from the database.\n\n"
"This action cannot be undone.\n\n"
"Do you want to continue?",
QMessageBox.Yes | QMessageBox.No,
QMessageBox.No,
)
if confirm != QMessageBox.Yes:
return
try:
deleted = self.db_manager.delete_all_detections()
except Exception as exc:
logger.error(f"Failed to delete all detections: {exc}")
QMessageBox.critical(
self,
"Error",
f"Failed to delete detections:\n{exc}",
)
return
QMessageBox.information(
self,
"Delete All Detections",
f"Deleted {deleted} detection(s) from the database.",
)
# Reset UI state.
self.refresh()
def refresh(self):
"""Refresh the detection list and preview."""
self._load_detection_summary()
@@ -190,8 +153,6 @@ class ResultsTab(QWidget):
self.current_detections = []
self.preview_canvas.clear()
self.summary_label.setText("Select a detection result to preview.")
if hasattr(self, "export_labels_btn"):
self.export_labels_btn.setEnabled(False)
def _load_detection_summary(self):
"""Load latest detection summaries grouped by image + model."""
@@ -208,7 +169,8 @@ class ResultsTab(QWidget):
"image_id": det["image_id"],
"model_id": det["model_id"],
"image_path": det.get("image_path"),
"image_filename": det.get("image_filename") or det.get("image_path"),
"image_filename": det.get("image_filename")
or det.get("image_path"),
"model_name": det.get("model_name", ""),
"model_version": det.get("model_version", ""),
"last_detected": det.get("detected_at"),
@@ -221,7 +183,8 @@ class ResultsTab(QWidget):
entry["count"] += 1
if det.get("detected_at") and (
not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
not entry.get("last_detected")
or str(det.get("detected_at")) > str(entry.get("last_detected"))
):
entry["last_detected"] = det.get("detected_at")
if det.get("class_name"):
@@ -251,7 +214,9 @@ class ResultsTab(QWidget):
for row, entry in enumerate(self.detection_summary):
model_label = f"{entry['model_name']} {entry['model_version']}".strip()
class_list = ", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
class_list = (
", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
)
items = [
QTableWidgetItem(entry.get("image_filename", "")),
@@ -311,231 +276,6 @@ class ResultsTab(QWidget):
self._load_detections_for_selection(entry)
self._apply_detection_overlays()
self._update_summary_label(entry)
if hasattr(self, "export_labels_btn"):
self.export_labels_btn.setEnabled(True)
def _export_labels_for_current_selection(self):
"""Export YOLO label file(s) for the currently selected image/model."""
if not self.current_selection:
QMessageBox.information(self, "Export Labels", "Select a detection result first.")
return
entry = self.current_selection
image_path_str = self._resolve_image_path(entry)
if not image_path_str:
QMessageBox.warning(
self,
"Export Labels",
"Unable to locate the image file for this detection; cannot infer labels path.",
)
return
# Ensure we have the detections for the selection.
if not self.current_detections:
self._load_detections_for_selection(entry)
if not self.current_detections:
QMessageBox.information(
self,
"Export Labels",
"No detections found for this image/model selection.",
)
return
image_path = Path(image_path_str)
try:
label_path = self._infer_yolo_label_path(image_path)
except Exception as exc:
logger.error(f"Failed to infer label path for {image_path}: {exc}")
QMessageBox.critical(
self,
"Export Labels",
f"Failed to infer export path for labels:\n{exc}",
)
return
class_map = self._build_detection_class_index_map(self.current_detections)
if not class_map:
QMessageBox.warning(
self,
"Export Labels",
"Unable to build class->index mapping (missing class names).",
)
return
lines_written = 0
skipped = 0
label_path.parent.mkdir(parents=True, exist_ok=True)
try:
with open(label_path, "w", encoding="utf-8") as handle:
print("writing to", label_path)
for det in self.current_detections:
yolo_line = self._format_detection_as_yolo_line(det, class_map)
if not yolo_line:
skipped += 1
continue
handle.write(yolo_line + "\n")
lines_written += 1
except OSError as exc:
logger.error(f"Failed to write labels file {label_path}: {exc}")
QMessageBox.critical(
self,
"Export Labels",
f"Failed to write label file:\n{label_path}\n\n{exc}",
)
return
# Optional: write a classes.txt next to the labels root to make the mapping discoverable.
# This is not required by Ultralytics (data.yaml usually holds class names), but helps reuse.
try:
classes_txt = label_path.parent.parent / "classes.txt"
classes_txt.parent.mkdir(parents=True, exist_ok=True)
inv = {idx: name for name, idx in class_map.items()}
with open(classes_txt, "w", encoding="utf-8") as handle:
for idx in range(len(inv)):
handle.write(f"{inv[idx]}\n")
except Exception:
# Non-fatal
pass
QMessageBox.information(
self,
"Export Labels",
f"Exported {lines_written} label line(s) to:\n{label_path}\n\nSkipped {skipped} invalid detection(s).",
)
def _infer_yolo_label_path(self, image_path: Path) -> Path:
"""Infer a YOLO label path from an image path.
If the image lives under an `images/` directory (anywhere in the path), we mirror the
subpath under a sibling `labels/` directory at the same level.
Example:
/dataset/train/images/sub/img.jpg -> /dataset/train/labels/sub/img.txt
"""
resolved = image_path.expanduser().resolve()
# Find the nearest ancestor directory named 'images'
images_dir: Optional[Path] = None
for parent in resolved.parents:
if parent.name.lower() == "images":
images_dir = parent
break
if images_dir is not None:
rel = resolved.relative_to(images_dir)
labels_dir = images_dir.parent / "labels"
return (labels_dir / rel).with_suffix(".txt")
# Fallback: create a local sibling labels folder next to the image.
return (resolved.parent / "labels" / resolved.name).with_suffix(".txt")
def _build_detection_class_index_map(self, detections: List[Dict]) -> Dict[str, int]:
"""Build a stable class_name -> YOLO class index mapping.
Preference order:
1) Database object_classes table (alphabetical class_name order)
2) Fallback to class_name values present in the detections (alphabetical)
"""
names: List[str] = []
try:
db_classes = self.db_manager.get_object_classes() or []
names = [str(row.get("class_name")) for row in db_classes if row.get("class_name")]
except Exception:
names = []
if not names:
observed = sorted({str(det.get("class_name")) for det in detections if det.get("class_name")})
names = list(observed)
return {name: idx for idx, name in enumerate(names)}
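# Example (hypothetical class names): ["membrane_branch", "organelle"]
# yields {"membrane_branch": 0, "organelle": 1}.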
def _format_detection_as_yolo_line(self, det: Dict, class_map: Dict[str, int]) -> Optional[str]:
"""Convert a detection row to a YOLO label line.
- If segmentation_mask is present, exports segmentation polygon format:
class x1 y1 x2 y2 ...
(normalized coordinates)
- Otherwise exports bbox format:
class x_center y_center width height
(normalized coordinates)
"""
class_name = det.get("class_name")
if not class_name or class_name not in class_map:
return None
class_idx = class_map[class_name]
mask = det.get("segmentation_mask")
polygon = self._convert_segmentation_mask_to_polygon(mask)
if polygon:
coords = " ".join(f"{value:.6f}" for value in polygon)
return f"{class_idx} {coords}".strip()
bbox = self._convert_bbox_to_yolo_xywh(det)
if bbox is None:
return None
x_center, y_center, width, height = bbox
return f"{class_idx} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
def _convert_bbox_to_yolo_xywh(self, det: Dict) -> Optional[Tuple[float, float, float, float]]:
"""Convert stored xyxy (normalized) bbox to YOLO xywh (normalized)."""
x_min = det.get("x_min")
y_min = det.get("y_min")
x_max = det.get("x_max")
y_max = det.get("y_max")
if any(v is None for v in (x_min, y_min, x_max, y_max)):
return None
try:
x_min_f = self._clamp01(float(x_min))
y_min_f = self._clamp01(float(y_min))
x_max_f = self._clamp01(float(x_max))
y_max_f = self._clamp01(float(y_max))
except (TypeError, ValueError):
return None
width = max(0.0, x_max_f - x_min_f)
height = max(0.0, y_max_f - y_min_f)
if width <= 0.0 or height <= 0.0:
return None
x_center = x_min_f + width / 2.0
y_center = y_min_f + height / 2.0
return x_center, y_center, width, height
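# Worked example (hypothetical values): xyxy (0.2, 0.3, 0.6, 0.7) gives
# width 0.4, height 0.4, and center (0.4, 0.5) -> (0.4, 0.5, 0.4, 0.4).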
def _convert_segmentation_mask_to_polygon(self, mask_data) -> List[float]:
"""Convert stored segmentation_mask [[x,y], ...] to YOLO polygon coords [x1,y1,...]."""
if not isinstance(mask_data, list):
return []
coords: List[float] = []
for point in mask_data:
if not isinstance(point, (list, tuple)) or len(point) < 2:
continue
try:
x = self._clamp01(float(point[0]))
y = self._clamp01(float(point[1]))
except (TypeError, ValueError):
continue
coords.extend([x, y])
# Need at least 3 points => 6 values.
return coords if len(coords) >= 6 else []
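# Example (hypothetical values): [[0.1, 0.2], [0.5, 0.2], [0.3, 0.6]]
# flattens to [0.1, 0.2, 0.5, 0.2, 0.3, 0.6]; fewer than 3 points -> [].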
@staticmethod
def _clamp01(value: float) -> float:
if value < 0.0:
return 0.0
if value > 1.0:
return 1.0
return value
def _load_detections_for_selection(self, entry: Dict):
"""Load detection records for the selected image/model pair."""

View File

@@ -3,14 +3,12 @@ Training tab for the microscopy object detection application.
Handles model training with YOLO.
"""
import hashlib
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import yaml
import numpy as np
import yaml
from PySide6.QtCore import Qt, QThread, Signal
from PySide6.QtWidgets import (
QWidget,
@@ -92,7 +90,10 @@ class TrainingWorker(QThread):
},
}
]
computed_total = sum(max(0, int((stage.get("params") or {}).get("epochs", 0))) for stage in self.stage_plan)
computed_total = sum(
max(0, int((stage.get("params") or {}).get("epochs", 0)))
for stage in self.stage_plan
)
self.total_epochs = total_epochs if total_epochs else computed_total or epochs
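# Example (hypothetical stage plan): stages of 20 and 150 epochs give
# computed_total = 170, used whenever total_epochs is not supplied.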
self._stop_requested = False
@@ -199,7 +200,9 @@ class TrainingWorker(QThread):
class TrainingTab(QWidget):
"""Training tab for model training."""
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
super().__init__(parent)
self.db_manager = db_manager
self.config_manager = config_manager
@@ -333,14 +336,18 @@ class TrainingTab(QWidget):
self.model_version_edit = QLineEdit("v1")
form_layout.addRow("Version:", self.model_version_edit)
default_base_model = self.config_manager.get("models.default_base_model", "yolov8s-seg.pt")
default_base_model = self.config_manager.get(
"models.default_base_model", "yolov8s-seg.pt"
)
base_model_choices = self.config_manager.get("models.base_model_choices", [])
self.base_model_combo = QComboBox()
self.base_model_combo.addItem("Custom path…", "")
for choice in base_model_choices:
self.base_model_combo.addItem(choice, choice)
self.base_model_combo.currentIndexChanged.connect(self._on_base_model_preset_changed)
self.base_model_combo.currentIndexChanged.connect(
self._on_base_model_preset_changed
)
form_layout.addRow("Base Model Preset:", self.base_model_combo)
base_model_layout = QHBoxLayout()
@@ -426,8 +433,12 @@ class TrainingTab(QWidget):
group_layout = QVBoxLayout()
self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
two_stage_defaults = training_defaults.get("two_stage", {}) if training_defaults else {}
self.two_stage_checkbox.setChecked(bool(two_stage_defaults.get("enabled", False)))
two_stage_defaults = (
training_defaults.get("two_stage", {}) if training_defaults else {}
)
self.two_stage_checkbox.setChecked(
bool(two_stage_defaults.get("enabled", False))
)
self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
group_layout.addWidget(self.two_stage_checkbox)
@@ -489,7 +500,9 @@ class TrainingTab(QWidget):
stage2_group.setLayout(stage2_form)
controls_layout.addWidget(stage2_group)
helper_label = QLabel("When enabled, staged hyperparameters override the global epochs/patience/lr.")
helper_label = QLabel(
"When enabled, staged hyperparameters override the global epochs/patience/lr."
)
helper_label.setWordWrap(True)
controls_layout.addWidget(helper_label)
@@ -534,7 +547,9 @@ class TrainingTab(QWidget):
if normalized == preset_value:
target_index = idx
break
if normalized.endswith(f"/{preset_value}") or normalized.endswith(f"\\{preset_value}"):
if normalized.endswith(f"/{preset_value}") or normalized.endswith(
f"\\{preset_value}"
):
target_index = idx
break
self.base_model_combo.blockSignals(True)
@@ -622,7 +637,9 @@ class TrainingTab(QWidget):
def _browse_dataset(self):
"""Open a file dialog to manually select data.yaml."""
start_dir = self.config_manager.get("training.last_dataset_dir", "data/datasets")
start_dir = self.config_manager.get(
"training.last_dataset_dir", "data/datasets"
)
start_path = Path(start_dir).expanduser()
if not start_path.exists():
start_path = Path.cwd()
@@ -658,7 +675,9 @@ class TrainingTab(QWidget):
return
except Exception as exc:
logger.exception("Unexpected error while generating data.yaml")
self._display_dataset_error("Unexpected error while generating data.yaml. Check logs for details.")
self._display_dataset_error(
"Unexpected error while generating data.yaml. Check logs for details."
)
QMessageBox.critical(
self,
"data.yaml Generation Failed",
@@ -735,9 +754,13 @@ class TrainingTab(QWidget):
self.selected_dataset = info
self.dataset_root_label.setText(info["root"]) # type: ignore[arg-type]
self.train_count_label.setText(self._format_split_info(info["splits"].get("train")))
self.train_count_label.setText(
self._format_split_info(info["splits"].get("train"))
)
self.val_count_label.setText(self._format_split_info(info["splits"].get("val")))
self.test_count_label.setText(self._format_split_info(info["splits"].get("test")))
self.test_count_label.setText(
self._format_split_info(info["splits"].get("test"))
)
self.num_classes_label.setText(str(info["num_classes"]))
class_names = ", ".join(info["class_names"]) or ""
self.class_names_label.setText(class_names)
@@ -791,12 +814,18 @@ class TrainingTab(QWidget):
if split_path.exists():
split_info["count"] = self._count_images(split_path)
if split_info["count"] == 0:
warnings.append(f"No images found for {split_name} split at {split_path}")
warnings.append(
f"No images found for {split_name} split at {split_path}"
)
else:
warnings.append(f"{split_name.capitalize()} path does not exist: {split_path}")
warnings.append(
f"{split_name.capitalize()} path does not exist: {split_path}"
)
else:
if split_name in ("train", "val"):
warnings.append(f"{split_name.capitalize()} split missing in data.yaml")
warnings.append(
f"{split_name.capitalize()} split missing in data.yaml"
)
splits[split_name] = split_info
names_list = self._normalize_class_names(data.get("names"))
@@ -814,7 +843,9 @@ class TrainingTab(QWidget):
if not names_list and nc_value:
names_list = [f"class_{idx}" for idx in range(int(nc_value))]
elif nc_value and len(names_list) not in (0, int(nc_value)):
warnings.append(f"Number of class names ({len(names_list)}) does not match nc={nc_value}")
warnings.append(
f"Number of class names ({len(names_list)}) does not match nc={nc_value}"
)
dataset_name = data.get("name") or base_path.name
@@ -866,12 +897,16 @@ class TrainingTab(QWidget):
class_index_map = self._build_class_index_map(dataset_info)
if not class_index_map:
self._append_training_log("Skipping label export: dataset classes do not match database entries.")
self._append_training_log(
"Skipping label export: dataset classes do not match database entries."
)
return
dataset_root_str = dataset_info.get("root")
dataset_yaml_path = dataset_info.get("yaml_path")
dataset_yaml = Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
dataset_yaml = (
Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
)
dataset_root: Optional[Path]
if dataset_root_str:
dataset_root = Path(dataset_root_str).resolve()
@@ -905,17 +940,12 @@ class TrainingTab(QWidget):
if stats["registered_images"]:
message += f" {stats['registered_images']} image(s) had database-backed annotations."
if stats["missing_records"]:
message += (
f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
)
message += f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
split_messages.append(message)
for msg in split_messages:
self._append_training_log(msg)
if dataset_yaml:
self._clear_rgb_cache_for_dataset(dataset_yaml)
def _export_labels_for_split(
self,
split_name: str,
@@ -939,7 +969,9 @@ class TrainingTab(QWidget):
continue
processed_images += 1
label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(".txt")
label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(
".txt"
)
label_path.parent.mkdir(parents=True, exist_ok=True)
found, annotation_entries = self._fetch_annotations_for_image(
@@ -955,23 +987,25 @@ class TrainingTab(QWidget):
for entry in annotation_entries:
polygon = entry.get("polygon") or []
if polygon:
coords = " ".join(f"{value:.6f}" for value in polygon)
handle.write(f"{entry['class_idx']} {coords}\n")
annotations_written += 1
elif entry.get("bbox"):
x_center, y_center, width, height = entry["bbox"]
handle.write(f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")
handle.write(
f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
)
annotations_written += 1
total_annotations += annotations_written
cache_reset_root = labels_dir.parent
self._invalidate_split_cache(cache_reset_root)
if processed_images == 0:
self._append_training_log(f"[{split_name}] No images found to export labels for.")
self._append_training_log(
f"[{split_name}] No images found to export labels for."
)
return None
return {
@@ -1097,10 +1131,6 @@ class TrainingTab(QWidget):
xs.append(x_val)
ys.append(y_val)
if any(np.abs(np.array(coords[:2]) - np.array(coords[-2:])) < 1e-5):
print("Closing polygon")
coords.extend(coords[:2])
if len(coords) < 6:
continue
@@ -1113,11 +1143,6 @@ class TrainingTab(QWidget):
+ abs((min(ys) if ys else 0.0) - y_min)
+ abs((max(ys) if ys else 0.0) - y_max)
)
width = max(0.0, x_max - x_min)
height = max(0.0, y_max - y_min)
x_center = x_min + width / 2.0
y_center = y_min + height / 2.0
score = (x_center, y_center, width, height)
candidates.append((score, coords))
@@ -1135,35 +1160,6 @@ class TrainingTab(QWidget):
return 1.0
return value
def _prepare_dataset_for_training(self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None) -> Path:
dataset_info = dataset_info or (
self.selected_dataset
if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
else self._parse_dataset_yaml(dataset_yaml)
)
train_split = dataset_info.get("splits", {}).get("train") or {}
images_path_str = train_split.get("path")
if not images_path_str:
return dataset_yaml
images_path = Path(images_path_str)
if not images_path.exists():
return dataset_yaml
if not self._dataset_requires_rgb_conversion(images_path):
return dataset_yaml
cache_root = self._get_rgb_cache_root(dataset_yaml)
rgb_yaml = cache_root / "data.yaml"
if rgb_yaml.exists():
self._append_training_log(f"Detected grayscale dataset; reusing RGB cache at {cache_root}")
return rgb_yaml
self._append_training_log(f"Detected grayscale dataset; creating RGB cache at {cache_root}")
self._build_rgb_dataset(cache_root, dataset_info)
return rgb_yaml
def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
two_stage = params.get("two_stage") or {}
base_stage = {
@@ -1248,113 +1244,6 @@ class TrainingTab(QWidget):
f"{stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
)
def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
cache_base = Path("data/datasets/_rgb_cache")
cache_base.mkdir(parents=True, exist_ok=True)
key = hashlib.md5(str(dataset_yaml.parent.resolve()).encode()).hexdigest()[:8]
return cache_base / f"{dataset_yaml.parent.name}_{key}"
def _clear_rgb_cache_for_dataset(self, dataset_yaml: Path):
cache_root = self._get_rgb_cache_root(dataset_yaml)
if cache_root.exists():
try:
shutil.rmtree(cache_root)
logger.debug(f"Removed RGB cache at {cache_root}")
except OSError as exc:
logger.warning(f"Failed to remove RGB cache {cache_root}: {exc}")
def _dataset_requires_rgb_conversion(self, images_dir: Path) -> bool:
sample_image = self._find_first_image(images_dir)
if not sample_image:
return False
# Do not force an RGB cache for TIFF datasets.
# We handle grayscale/16-bit TIFFs via runtime Ultralytics patches that:
# - load TIFFs with `tifffile`
# - replicate grayscale to 3 channels without quantization
# - normalize uint16 correctly during training
if sample_image.suffix.lower() in {".tif", ".tiff"}:
return False
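# A minimal sketch of that load path (assumes tifffile/numpy; illustrative,
# not the actual patch implementation):
#   raw = tifffile.imread(path)                  # uint16 preserved
#   arr = raw.astype(np.float32) / 65535.0       # 16-bit -> [0, 1]
#   rgb = np.repeat(arr[..., None], 3, axis=-1)  # grayscale -> 3 channels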
try:
img = Image(sample_image)
return img.pil_image.mode.upper() != "RGB"
except Exception as exc:
logger.warning(f"Failed to inspect image {sample_image}: {exc}")
return False
def _find_first_image(self, directory: Path) -> Optional[Path]:
if not directory.exists():
return None
for path in directory.rglob("*"):
if path.is_file() and path.suffix.lower() in self.allowed_extensions:
return path
return None
def _build_rgb_dataset(self, cache_root: Path, dataset_info: Dict[str, Any]):
if cache_root.exists():
shutil.rmtree(cache_root)
cache_root.mkdir(parents=True, exist_ok=True)
splits = dataset_info.get("splits", {})
for split_name in ("train", "val", "test"):
split_entry = splits.get(split_name)
if not split_entry:
continue
images_src = Path(split_entry.get("path", ""))
if not images_src.exists():
continue
images_dst = cache_root / split_name / "images"
self._convert_images_to_rgb(images_src, images_dst)
labels_src = self._infer_labels_dir(images_src)
if labels_src.exists():
labels_dst = cache_root / split_name / "labels"
self._copy_labels(labels_src, labels_dst)
class_names = dataset_info.get("class_names") or []
names_map = {idx: name for idx, name in enumerate(class_names)}
num_classes = dataset_info.get("num_classes") or len(class_names)
yaml_payload: Dict[str, Any] = {
"path": cache_root.as_posix(),
"names": names_map,
"nc": num_classes,
}
for split_name in ("train", "val", "test"):
images_dir = cache_root / split_name / "images"
if images_dir.exists():
yaml_payload[split_name] = f"{split_name}/images"
with open(cache_root / "data.yaml", "w", encoding="utf-8") as handle:
yaml.safe_dump(yaml_payload, handle, sort_keys=False)
def _convert_images_to_rgb(self, src_dir: Path, dst_dir: Path):
for src in src_dir.rglob("*"):
if not src.is_file() or src.suffix.lower() not in self.allowed_extensions:
continue
relative = src.relative_to(src_dir)
dst = dst_dir / relative
dst.parent.mkdir(parents=True, exist_ok=True)
try:
img_obj = Image(src)
pil_img = img_obj.pil_image
if len(pil_img.getbands()) == 1:
rgb_img = img_obj.convert_grayscale_to_rgb_preserve_range()
else:
rgb_img = pil_img.convert("RGB")
rgb_img.save(dst)
except Exception as exc:
logger.warning(f"Failed to convert {src} to RGB: {exc}")
def _copy_labels(self, labels_src: Path, labels_dst: Path):
label_files = list(labels_src.rglob("*.txt"))
for label_file in label_files:
relative = label_file.relative_to(labels_src)
dst = labels_dst / relative
dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(label_file, dst)
def _infer_labels_dir(self, images_dir: Path) -> Path:
return images_dir.parent / "labels"
@@ -1427,26 +1316,31 @@ class TrainingTab(QWidget):
dataset_path = Path(dataset_yaml).expanduser()
if not dataset_path.exists():
QMessageBox.warning(self, "Invalid Dataset", "Selected data.yaml file does not exist.")
QMessageBox.warning(
self, "Invalid Dataset", "Selected data.yaml file does not exist."
)
return
dataset_info = (
self.selected_dataset
if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_path)
if self.selected_dataset
and self.selected_dataset.get("yaml_path") == str(dataset_path)
else self._parse_dataset_yaml(dataset_path)
)
self.training_log.clear()
self._export_labels_from_database(dataset_info)
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
if dataset_to_use != dataset_path:
self._append_training_log(f"Using RGB-converted dataset at {dataset_to_use.parent}")
self._append_training_log(
"Using Float32 on-the-fly loader for 16-bit TIFF support (no disk caching)"
)
params = self._collect_training_params()
stage_plan = self._compose_stage_plan(params)
params["stage_plan"] = stage_plan
total_planned_epochs = self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
total_planned_epochs = (
self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
)
params["total_planned_epochs"] = total_planned_epochs
self._active_training_params = params
self._training_cancelled = False
@@ -1455,7 +1349,9 @@ class TrainingTab(QWidget):
self._append_training_log("Two-stage fine-tuning schedule:")
self._log_stage_plan(stage_plan)
self._append_training_log(f"Starting training run '{params['run_name']}' using {params['base_model']}")
self._append_training_log(
f"Starting training run '{params['run_name']}' using {params['base_model']}"
)
self.training_progress_bar.setVisible(True)
self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
@@ -1463,7 +1359,7 @@ class TrainingTab(QWidget):
self._set_training_state(True)
self.training_worker = TrainingWorker(
data_yaml=dataset_to_use.as_posix(),
data_yaml=dataset_path.as_posix(),
base_model=params["base_model"],
epochs=params["epochs"],
batch=params["batch"],
@@ -1483,7 +1379,9 @@ class TrainingTab(QWidget):
def _stop_training(self):
if self.training_worker and self.training_worker.isRunning():
self._training_cancelled = True
self._append_training_log("Stop requested. Waiting for the current epoch to finish...")
self._append_training_log(
"Stop requested. Waiting for the current epoch to finish..."
)
self.training_worker.stop()
self.stop_training_button.setEnabled(False)
@@ -1519,7 +1417,9 @@ class TrainingTab(QWidget):
if worker.isRunning():
if not worker.wait(wait_timeout_ms):
logger.warning("Training worker did not finish within %sms", wait_timeout_ms)
logger.warning(
"Training worker did not finish within %sms", wait_timeout_ms
)
worker.deleteLater()
@@ -1536,12 +1436,16 @@ class TrainingTab(QWidget):
self._set_training_state(False)
self.training_progress_bar.setVisible(False)
def _on_training_progress(self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]):
def _on_training_progress(
self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]
):
self.training_progress_bar.setMaximum(total_epochs)
self.training_progress_bar.setValue(current_epoch)
parts = [f"Epoch {current_epoch}/{total_epochs}"]
if metrics:
metric_text = ", ".join(f"{key}: {value:.4f}" for key, value in metrics.items())
metric_text = ", ".join(
f"{key}: {value:.4f}" for key, value in metrics.items()
)
parts.append(metric_text)
self._append_training_log(" | ".join(parts))
@@ -1568,7 +1472,9 @@ class TrainingTab(QWidget):
f"Model trained but not registered: {exc}",
)
else:
QMessageBox.information(self, "Training Complete", "Training finished successfully.")
QMessageBox.information(
self, "Training Complete", "Training finished successfully."
)
def _on_training_error(self, message: str):
self._cleanup_training_worker()
@@ -1614,7 +1520,9 @@ class TrainingTab(QWidget):
metrics=results.get("metrics"),
)
self._append_training_log(f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}")
self._append_training_log(
f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}"
)
self._active_training_params = None
def _set_training_state(self, is_training: bool):
@@ -1657,7 +1565,9 @@ class TrainingTab(QWidget):
def _browse_save_dir(self):
start_path = self.save_dir_edit.text().strip() or "data/models"
directory = QFileDialog.getExistingDirectory(self, "Select Save Directory", start_path)
directory = QFileDialog.getExistingDirectory(
self, "Select Save Directory", start_path
)
if directory:
self.save_dir_edit.setText(directory)

View File

@@ -2,554 +2,45 @@
Validation tab for the microscopy object detection application.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple
from PySide6.QtCore import Qt, QSize
from PySide6.QtGui import QPainter, QPixmap
from PySide6.QtWidgets import (
QWidget,
QVBoxLayout,
QLabel,
QGroupBox,
QHBoxLayout,
QPushButton,
QComboBox,
QFormLayout,
QScrollArea,
QGridLayout,
QFrame,
QTableWidget,
QTableWidgetItem,
QHeaderView,
QSplitter,
QListWidget,
QListWidgetItem,
QAbstractItemView,
QGraphicsView,
QGraphicsScene,
QGraphicsPixmapItem,
)
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
logger = get_logger(__name__)
@dataclass(frozen=True)
class _PlotItem:
label: str
path: Path
class _ZoomableImageView(QGraphicsView):
"""Zoomable image viewer.
- Mouse wheel: zoom in/out
- Left mouse drag: pan (ScrollHandDrag)
"""
def __init__(self, parent: Optional[QWidget] = None):
super().__init__(parent)
self._scene = QGraphicsScene(self)
self.setScene(self._scene)
self._pixmap_item = QGraphicsPixmapItem()
self._scene.addItem(self._pixmap_item)
# QGraphicsView render hints are QPainter.RenderHints.
self.setRenderHints(self.renderHints() | QPainter.RenderHint.SmoothPixmapTransform)
self.setDragMode(QGraphicsView.DragMode.ScrollHandDrag)
self.setTransformationAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
self.setResizeAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
self._has_pixmap = False
def clear(self) -> None:
self._pixmap_item.setPixmap(QPixmap())
self._scene.setSceneRect(0, 0, 1, 1)
self.resetTransform()
self._has_pixmap = False
def set_pixmap(self, pixmap: QPixmap, *, fit: bool = True) -> None:
self._pixmap_item.setPixmap(pixmap)
self._scene.setSceneRect(pixmap.rect())
self._has_pixmap = not pixmap.isNull()
self.resetTransform()
if fit and self._has_pixmap:
self.fitInView(self._pixmap_item, Qt.AspectRatioMode.KeepAspectRatio)
def wheelEvent(self, event) -> None: # type: ignore[override]
if not self._has_pixmap:
return
zoom_in_factor = 1.25
zoom_out_factor = 1.0 / zoom_in_factor
factor = zoom_in_factor if event.angleDelta().y() > 0 else zoom_out_factor
self.scale(factor, factor)
class ValidationTab(QWidget):
"""Validation tab that shows stored validation metrics + plots for a selected model."""
"""Validation tab placeholder."""
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
super().__init__(parent)
self.db_manager = db_manager
self.config_manager = config_manager
self._models: List[Dict[str, Any]] = []
self._selected_model_id: Optional[int] = None
self._plot_widgets: List[QWidget] = []
self._plot_items: List[_PlotItem] = []
self._setup_ui()
self.refresh()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout(self)
layout = QVBoxLayout()
# ===== Header controls =====
header = QGroupBox("Validation")
header_layout = QVBoxLayout()
header_row = QHBoxLayout()
group = QGroupBox("Validation")
group_layout = QVBoxLayout()
label = QLabel(
"Validation functionality will be implemented here.\n\n"
"Features:\n"
"- Model validation\n"
"- Metrics visualization\n"
"- Confusion matrix\n"
"- Precision-Recall curves"
)
group_layout.addWidget(label)
group.setLayout(group_layout)
header_row.addWidget(QLabel("Select model:"))
self.model_combo = QComboBox()
self.model_combo.setMinimumWidth(420)
self.model_combo.currentIndexChanged.connect(self._on_model_selected)
header_row.addWidget(self.model_combo, 1)
self.refresh_btn = QPushButton("Refresh")
self.refresh_btn.clicked.connect(self.refresh)
header_row.addWidget(self.refresh_btn)
header_row.addStretch()
header_layout.addLayout(header_row)
self.header_status = QLabel("No models loaded.")
self.header_status.setWordWrap(True)
header_layout.addWidget(self.header_status)
header.setLayout(header_layout)
layout.addWidget(header)
# ===== Metrics =====
metrics_group = QGroupBox("Validation Metrics")
metrics_layout = QVBoxLayout()
self.metrics_form = QFormLayout()
self.metric_labels: Dict[str, QLabel] = {}
for key in ("mAP50", "mAP50-95", "precision", "recall", "fitness"):
value_label = QLabel("")
value_label.setTextInteractionFlags(Qt.TextSelectableByMouse)
self.metric_labels[key] = value_label
self.metrics_form.addRow(f"{key}:", value_label)
metrics_layout.addLayout(self.metrics_form)
self.per_class_table = QTableWidget(0, 3)
self.per_class_table.setHorizontalHeaderLabels(["Class", "AP", "AP50"])
self.per_class_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
self.per_class_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeToContents)
self.per_class_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeToContents)
self.per_class_table.setEditTriggers(QTableWidget.NoEditTriggers)
self.per_class_table.setMinimumHeight(160)
metrics_layout.addWidget(QLabel("Per-class metrics (if available):"))
metrics_layout.addWidget(self.per_class_table)
metrics_group.setLayout(metrics_layout)
layout.addWidget(metrics_group)
# ===== Plots =====
plots_group = QGroupBox("Validation Plots")
plots_layout = QVBoxLayout()
self.plots_status = QLabel("Select a model to see validation plots.")
self.plots_status.setWordWrap(True)
plots_layout.addWidget(self.plots_status)
self.plots_splitter = QSplitter(Qt.Orientation.Horizontal)
# Left: selected image viewer
left_widget = QWidget()
left_layout = QVBoxLayout(left_widget)
left_layout.setContentsMargins(0, 0, 0, 0)
self.selected_plot_title = QLabel("No image selected.")
self.selected_plot_title.setWordWrap(True)
self.selected_plot_title.setTextInteractionFlags(Qt.TextSelectableByMouse)
left_layout.addWidget(self.selected_plot_title)
self.plot_view = _ZoomableImageView()
self.plot_view.setMinimumHeight(360)
left_layout.addWidget(self.plot_view, 1)
self.selected_plot_path = QLabel("")
self.selected_plot_path.setWordWrap(True)
self.selected_plot_path.setStyleSheet("color: #888;")
self.selected_plot_path.setTextInteractionFlags(Qt.TextSelectableByMouse)
left_layout.addWidget(self.selected_plot_path)
# Right: scrollable list
right_widget = QWidget()
right_layout = QVBoxLayout(right_widget)
right_layout.setContentsMargins(0, 0, 0, 0)
right_layout.addWidget(QLabel("Images:"))
self.plots_list = QListWidget()
self.plots_list.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection)
self.plots_list.setIconSize(QSize(160, 160))
self.plots_list.itemSelectionChanged.connect(self._on_plot_item_selected)
right_layout.addWidget(self.plots_list, 1)
self.plots_splitter.addWidget(left_widget)
self.plots_splitter.addWidget(right_widget)
self.plots_splitter.setStretchFactor(0, 3)
self.plots_splitter.setStretchFactor(1, 1)
plots_layout.addWidget(self.plots_splitter, 1)
plots_group.setLayout(plots_layout)
layout.addWidget(plots_group, 1)
layout.addStretch(0)
self._clear_metrics()
self._clear_plots()
# ==================== Public API ====================
layout.addWidget(group)
layout.addStretch()
self.setLayout(layout)
def refresh(self):
"""Refresh the tab."""
self._load_models()
self._populate_model_combo()
self._restore_or_select_default_model()
# ==================== Internal: models ====================
def _load_models(self) -> None:
try:
self._models = self.db_manager.get_models() or []
except Exception as exc:
logger.error("Failed to load models: %s", exc)
self._models = []
def _populate_model_combo(self) -> None:
self.model_combo.blockSignals(True)
self.model_combo.clear()
self.model_combo.addItem("Select a model…", None)
for model in self._models:
model_id = model.get("id")
name = (model.get("model_name") or "").strip()
version = (model.get("model_version") or "").strip()
created_at = model.get("created_at")
label = f"{name} {version}".strip()
if created_at:
label = f"{label} ({created_at})"
self.model_combo.addItem(label, model_id)
self.model_combo.blockSignals(False)
if self._models:
self.header_status.setText(f"Loaded {len(self._models)} model(s).")
else:
self.header_status.setText("No models found. Train a model first.")
def _restore_or_select_default_model(self) -> None:
if not self._models:
self._selected_model_id = None
self._clear_metrics()
self._clear_plots()
return
# Keep selection if still present.
if self._selected_model_id is not None:
for idx in range(1, self.model_combo.count()):
if self.model_combo.itemData(idx) == self._selected_model_id:
self.model_combo.setCurrentIndex(idx)
return
# Otherwise select the newest model (top of get_models ORDER BY created_at DESC).
first_model_id = self.model_combo.itemData(1) if self.model_combo.count() > 1 else None
if first_model_id is not None:
self.model_combo.setCurrentIndex(1)
def _on_model_selected(self, index: int) -> None:
model_id = self.model_combo.itemData(index)
if not model_id:
self._selected_model_id = None
self._clear_metrics()
self._clear_plots()
self.plots_status.setText("Select a model to see validation plots.")
return
self._selected_model_id = int(model_id)
model = self._get_model_by_id(self._selected_model_id)
if not model:
self._clear_metrics()
self._clear_plots()
self.plots_status.setText("Selected model not found.")
return
self._render_metrics(model)
self._render_plots(model)
def _get_model_by_id(self, model_id: int) -> Optional[Dict[str, Any]]:
for model in self._models:
if model.get("id") == model_id:
return model
try:
return self.db_manager.get_model_by_id(model_id)
except Exception:
return None
# ==================== Internal: metrics ====================
def _clear_metrics(self) -> None:
for label in self.metric_labels.values():
label.setText("")
self.per_class_table.setRowCount(0)
def _render_metrics(self, model: Dict[str, Any]) -> None:
self._clear_metrics()
metrics: Dict[str, Any] = model.get("metrics") or {}
# Training tab stores metrics under results['metrics'] in training results payload.
if isinstance(metrics, dict) and "metrics" in metrics and isinstance(metrics.get("metrics"), dict):
metrics = metrics.get("metrics") or {}
def set_metric(key: str, value: Any) -> None:
if key not in self.metric_labels:
return
if value is None:
self.metric_labels[key].setText("")
return
try:
self.metric_labels[key].setText(f"{float(value):.4f}")
except Exception:
self.metric_labels[key].setText(str(value))
set_metric("mAP50", metrics.get("mAP50"))
set_metric("mAP50-95", metrics.get("mAP50-95") or metrics.get("mAP50_95") or metrics.get("mAP50-95"))
set_metric("precision", metrics.get("precision"))
set_metric("recall", metrics.get("recall"))
set_metric("fitness", metrics.get("fitness"))
# Optional per-class metrics
class_metrics = metrics.get("class_metrics") if isinstance(metrics, dict) else None
if isinstance(class_metrics, dict) and class_metrics:
items = sorted(class_metrics.items(), key=lambda kv: str(kv[0]))
self.per_class_table.setRowCount(len(items))
for row, (cls_name, cls_stats) in enumerate(items):
ap = (cls_stats or {}).get("ap")
ap50 = (cls_stats or {}).get("ap50")
self.per_class_table.setItem(row, 0, QTableWidgetItem(str(cls_name)))
self.per_class_table.setItem(row, 1, QTableWidgetItem(self._format_float(ap)))
self.per_class_table.setItem(row, 2, QTableWidgetItem(self._format_float(ap50)))
else:
self.per_class_table.setRowCount(0)
@staticmethod
def _format_float(value: Any) -> str:
if value is None:
return ""
try:
return f"{float(value):.4f}"
except Exception:
return str(value)
# ==================== Internal: plots ====================
def _clear_plots(self) -> None:
# Remove legacy grid widgets (from the initial implementation).
for widget in self._plot_widgets:
widget.setParent(None)
widget.deleteLater()
self._plot_widgets = []
self._plot_items = []
if hasattr(self, "plots_list"):
self.plots_list.blockSignals(True)
self.plots_list.clear()
self.plots_list.blockSignals(False)
if hasattr(self, "plot_view"):
self.plot_view.clear()
if hasattr(self, "selected_plot_title"):
self.selected_plot_title.setText("No image selected.")
if hasattr(self, "selected_plot_path"):
self.selected_plot_path.setText("")
def _render_plots(self, model: Dict[str, Any]) -> None:
self._clear_plots()
plot_dirs = self._infer_run_directories(model)
plot_items = self._discover_plot_items(plot_dirs)
if not plot_items:
dirs_text = "\n".join(str(p) for p in plot_dirs if p)
self.plots_status.setText(
"No validation plot images found for this model.\n\n"
"Searched directories:\n" + (dirs_text or "(none)")
)
return
self._plot_items = list(plot_items)
self.plots_status.setText(f"Found {len(plot_items)} plot image(s). Select one to view/zoom.")
self.plots_list.blockSignals(True)
self.plots_list.clear()
for idx, item in enumerate(self._plot_items):
qitem = QListWidgetItem(item.label)
qitem.setData(Qt.ItemDataRole.UserRole, idx)
pix = QPixmap(str(item.path))
if not pix.isNull():
thumb = pix.scaled(
self.plots_list.iconSize(),
Qt.AspectRatioMode.KeepAspectRatio,
Qt.TransformationMode.SmoothTransformation,
)
qitem.setIcon(thumb)
self.plots_list.addItem(qitem)
self.plots_list.blockSignals(False)
if self.plots_list.count() > 0:
self.plots_list.setCurrentRow(0)
def _on_plot_item_selected(self) -> None:
if not self._plot_items:
return
selected = self.plots_list.selectedItems()
if not selected:
return
idx = selected[0].data(Qt.ItemDataRole.UserRole)
try:
idx_int = int(idx)
except Exception:
return
if idx_int < 0 or idx_int >= len(self._plot_items):
return
plot = self._plot_items[idx_int]
self.selected_plot_title.setText(plot.label)
self.selected_plot_path.setText(str(plot.path))
pix = QPixmap(str(plot.path))
if pix.isNull():
self.plot_view.clear()
return
self.plot_view.set_pixmap(pix, fit=True)
def _infer_run_directories(self, model: Dict[str, Any]) -> List[Path]:
dirs: List[Path] = []
# 1) Infer from model_path: .../<run>/weights/best.pt -> <run>
model_path = model.get("model_path")
if model_path:
try:
p = Path(str(model_path)).expanduser()
if p.name.lower().endswith(".pt"):
# If it lives under weights/, use parent.parent.
if p.parent.name == "weights" and p.parent.parent.exists():
dirs.append(p.parent.parent)
elif p.parent.exists():
dirs.append(p.parent)
except Exception:
pass
# 2) Look at training_params.stage_results[].results.save_dir
training_params = model.get("training_params") or {}
stage_results = None
if isinstance(training_params, dict):
stage_results = training_params.get("stage_results")
if isinstance(stage_results, list):
for stage in stage_results:
results = (stage or {}).get("results")
save_dir = (results or {}).get("save_dir") if isinstance(results, dict) else None
if save_dir:
try:
save_path = Path(str(save_dir)).expanduser()
if save_path.exists():
dirs.append(save_path)
except Exception:
continue
# Deduplicate while preserving order.
unique: List[Path] = []
seen: set[str] = set()
for d in dirs:
try:
resolved = str(d.resolve())
except Exception:
resolved = str(d)
if resolved not in seen and d.exists() and d.is_dir():
seen.add(resolved)
unique.append(d)
return unique
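# Minimal sketch of the weights-path inference above (run name hypothetical):
#   p = Path("data/models/run42/weights/best.pt")
#   run_dir = p.parent.parent if p.parent.name == "weights" else p.parent
#   # run_dir == Path("data/models/run42")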
def _discover_plot_items(self, directories: Sequence[Path]) -> List[_PlotItem]:
# Prefer canonical Ultralytics filenames first, then fall back to any png/jpg.
preferred_names = [
"results.png",
"results.jpg",
"confusion_matrix.png",
"confusion_matrix_normalized.png",
"labels.jpg",
"labels.png",
"BoxPR_curve.png",
"BoxP_curve.png",
"BoxR_curve.png",
"BoxF1_curve.png",
"MaskPR_curve.png",
"MaskP_curve.png",
"MaskR_curve.png",
"MaskF1_curve.png",
"val_batch0_pred.jpg",
"val_batch0_labels.jpg",
]
found: List[_PlotItem] = []
seen: set[str] = set()
for d in directories:
# 1) Preferred
for name in preferred_names:
p = d / name
if p.exists() and p.is_file():
key = str(p)
if key in seen:
continue
seen.add(key)
found.append(_PlotItem(label=f"{name} (from {d.name})", path=p))
# 2) Curated globs
for pattern in ("train_batch*.jpg", "val_batch*.jpg", "*curve*.png"):
for p in sorted(d.glob(pattern)):
if not p.is_file():
continue
key = str(p)
if key in seen:
continue
seen.add(key)
found.append(_PlotItem(label=f"{p.name} (from {d.name})", path=p))
# 3) Fallback: any top-level png/jpg (excluding weights dir contents)
for ext in ("*.png", "*.jpg", "*.jpeg", "*.webp"):
for p in sorted(d.glob(ext)):
if not p.is_file():
continue
key = str(p)
if key in seen:
continue
seen.add(key)
found.append(_PlotItem(label=f"{p.name} (from {d.name})", path=p))
# Keep list bounded to avoid UI overload for huge runs.
return found[:60]
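# For reference, a typical Ultralytics run directory that this discovery walks
# (assuming default save settings):
#   <run>/results.png, confusion_matrix.png, BoxPR_curve.png, MaskPR_curve.png, ...
#   <run>/train_batch0.jpg, val_batch0_labels.jpg, val_batch0_pred.jpg
#   <run>/weights/best.pt, last.pt  (weights are not plot images and are skipped)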

View File

@@ -18,7 +18,7 @@ from PySide6.QtGui import (
QPaintEvent,
QPolygonF,
)
from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect, QTimer
from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect
from typing import Any, Dict, List, Optional, Tuple
from src.utils.image import Image, ImageLoadError
@@ -79,7 +79,9 @@ def rdp(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float,
return [start, end]
def simplify_polyline(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float, float]]:
def simplify_polyline(
points: List[Tuple[float, float]], epsilon: float
) -> List[Tuple[float, float]]:
"""
Simplify a polyline with RDP while preserving closure semantics.
@@ -143,10 +145,6 @@ class AnnotationCanvasWidget(QWidget):
self.zoom_step = 0.1
self.zoom_wheel_step = 0.15
# Auto-fit behavior (opt-in): when enabled, newly loaded images (and resizes)
# will scale to fill the available viewport while preserving aspect ratio.
self._auto_fit_to_view: bool = False
# Drawing / interaction state
self.is_drawing = False
self.polyline_enabled = False
@@ -177,35 +175,6 @@ class AnnotationCanvasWidget(QWidget):
self._setup_ui()
def set_auto_fit_to_view(self, enabled: bool):
"""Enable/disable automatic zoom-to-fit behavior."""
self._auto_fit_to_view = bool(enabled)
if self._auto_fit_to_view and self.original_pixmap is not None:
QTimer.singleShot(0, self.fit_to_view)
def fit_to_view(self, padding_px: int = 6):
"""Zoom the image so it fits the scroll area's viewport (aspect preserved)."""
if self.original_pixmap is None:
return
viewport = self.scroll_area.viewport().size()
available_w = max(1, int(viewport.width()) - int(padding_px))
available_h = max(1, int(viewport.height()) - int(padding_px))
img_w = max(1, int(self.original_pixmap.width()))
img_h = max(1, int(self.original_pixmap.height()))
scale_w = available_w / img_w
scale_h = available_h / img_h
new_scale = min(scale_w, scale_h)
new_scale = max(self.zoom_min, min(self.zoom_max, float(new_scale)))
if abs(new_scale - self.zoom_scale) < 1e-4:
return
self.zoom_scale = new_scale
self._apply_zoom()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
@@ -218,7 +187,9 @@ class AnnotationCanvasWidget(QWidget):
self.canvas_label = QLabel("No image loaded")
self.canvas_label.setAlignment(Qt.AlignCenter)
self.canvas_label.setStyleSheet("QLabel { background-color: #2b2b2b; color: #888; }")
self.canvas_label.setStyleSheet(
"QLabel { background-color: #2b2b2b; color: #888; }"
)
self.canvas_label.setScaledContents(False)
self.canvas_label.setMouseTracking(True)
@@ -241,18 +212,9 @@ class AnnotationCanvasWidget(QWidget):
self.zoom_scale = 1.0
self.clear_annotations()
self._display_image()
# Defer fit-to-view until the widget has a valid viewport size.
if self._auto_fit_to_view:
QTimer.singleShot(0, self.fit_to_view)
logger.debug(f"Loaded image into annotation canvas: {image.width}x{image.height}")
def resizeEvent(self, event):
"""Optionally keep the image fitted when the widget is resized."""
super().resizeEvent(event)
if self._auto_fit_to_view and self.original_pixmap is not None:
QTimer.singleShot(0, self.fit_to_view)
logger.debug(
f"Loaded image into annotation canvas: {image.width}x{image.height}"
)
def clear(self):
"""Clear the displayed image and all annotations."""
@@ -288,10 +250,12 @@ class AnnotationCanvasWidget(QWidget):
# Get image data in a format compatible with Qt
if self.current_image.channels in (3, 4):
image_data = self.current_image.get_rgb()
else:
image_data = self.current_image.get_qt_rgb()
height, width = image_data.shape[:2]
else:
image_data = self.current_image.get_grayscale()
height, width = image_data.shape
image_data = np.ascontiguousarray(image_data)
bytes_per_line = image_data.strides[0]
qimage = QImage(
@@ -299,7 +263,7 @@ class AnnotationCanvasWidget(QWidget):
width,
height,
bytes_per_line,
QImage.Format_RGBX32FPx4, # self.current_image.qtimage_format,
self.current_image.qtimage_format,
).copy() # Copy so Qt owns the buffer even after numpy array goes out of scope
self.original_pixmap = QPixmap.fromImage(qimage)
@@ -327,14 +291,22 @@ class AnnotationCanvasWidget(QWidget):
scaled_width,
scaled_height,
Qt.KeepAspectRatio,
(Qt.SmoothTransformation if self.zoom_scale >= 1.0 else Qt.FastTransformation),
(
Qt.SmoothTransformation
if self.zoom_scale >= 1.0
else Qt.FastTransformation
),
)
scaled_annotations = self.annotation_pixmap.scaled(
scaled_width,
scaled_height,
Qt.KeepAspectRatio,
(Qt.SmoothTransformation if self.zoom_scale >= 1.0 else Qt.FastTransformation),
(
Qt.SmoothTransformation
if self.zoom_scale >= 1.0
else Qt.FastTransformation
),
)
# Composite image and annotations
@@ -420,11 +392,16 @@ class AnnotationCanvasWidget(QWidget):
y = (pos.y() - offset_y) / self.zoom_scale
# Check bounds
if 0 <= x < self.original_pixmap.width() and 0 <= y < self.original_pixmap.height():
if (
0 <= x < self.original_pixmap.width()
and 0 <= y < self.original_pixmap.height()
):
return (int(x), int(y))
return None
def _find_polyline_at(self, img_x: float, img_y: float, threshold_px: float = 5.0) -> Optional[int]:
def _find_polyline_at(
self, img_x: float, img_y: float, threshold_px: float = 5.0
) -> Optional[int]:
"""
Find index of polyline whose geometry is within threshold_px of (img_x, img_y).
Returns the index in self.polylines, or None if none is close enough.
@@ -446,7 +423,9 @@ class AnnotationCanvasWidget(QWidget):
# Precise distance to all segments
for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]):
d = perpendicular_distance((img_x, img_y), (float(x1), float(y1)), (float(x2), float(y2)))
d = perpendicular_distance(
(img_x, img_y), (float(x1), float(y1)), (float(x2), float(y2))
)
if d < best_dist:
best_dist = d
best_index = idx
@@ -647,7 +626,11 @@ class AnnotationCanvasWidget(QWidget):
def mouseMoveEvent(self, event: QMouseEvent):
"""Handle mouse move events for drawing."""
if not self.is_drawing or not self.polyline_enabled or self.annotation_pixmap is None:
if (
not self.is_drawing
or not self.polyline_enabled
or self.annotation_pixmap is None
):
super().mouseMoveEvent(event)
return
@@ -707,10 +690,15 @@ class AnnotationCanvasWidget(QWidget):
if len(simplified) >= 2:
# Store polyline and redraw all annotations
self._add_polyline(simplified, self.polyline_pen_color, self.polyline_pen_width)
self._add_polyline(
simplified, self.polyline_pen_color, self.polyline_pen_width
)
# Convert to normalized coordinates for metadata + signal
normalized_stroke = [self._image_to_normalized_coords(int(x), int(y)) for (x, y) in simplified]
normalized_stroke = [
self._image_to_normalized_coords(int(x), int(y))
for (x, y) in simplified
]
self.all_strokes.append(
{
"points": normalized_stroke,
@@ -723,7 +711,8 @@ class AnnotationCanvasWidget(QWidget):
# Emit signal with normalized coordinates
self.annotation_drawn.emit(normalized_stroke)
logger.debug(
f"Completed stroke with {len(simplified)} points " f"(normalized len={len(normalized_stroke)})"
f"Completed stroke with {len(simplified)} points "
f"(normalized len={len(normalized_stroke)})"
)
self.current_stroke = []
@@ -763,7 +752,9 @@ class AnnotationCanvasWidget(QWidget):
# Store polyline as [y_norm, x_norm] to match DB convention and
# the expectations of draw_saved_polyline().
normalized_polyline = [[y / img_height, x / img_width] for (x, y) in polyline]
normalized_polyline = [
[y / img_height, x / img_width] for (x, y) in polyline
]
logger.debug(
f"Polyline {idx}: {len(polyline)} points, "
@@ -783,7 +774,7 @@ class AnnotationCanvasWidget(QWidget):
self,
polyline: List[List[float]],
color: str,
width: int = 1,
width: int = 3,
annotation_id: Optional[int] = None,
):
"""
@@ -821,13 +812,17 @@ class AnnotationCanvasWidget(QWidget):
# Store and redraw using common pipeline
pen_color = QColor(color)
pen_color.setAlpha(255) # Add semi-transparency
pen_color.setAlpha(128) # Add semi-transparency
self._add_polyline(img_coords, pen_color, width, annotation_id=annotation_id)
# Store in all_strokes for consistency (uses normalized coordinates)
self.all_strokes.append({"points": polyline, "color": color, "alpha": 255, "width": width})
self.all_strokes.append(
{"points": polyline, "color": color, "alpha": 128, "width": width}
)
logger.debug(f"Drew saved polyline with {len(polyline)} points in color {color}")
logger.debug(
f"Drew saved polyline with {len(polyline)} points in color {color}"
)
def draw_saved_bbox(
self,
@@ -851,7 +846,9 @@ class AnnotationCanvasWidget(QWidget):
return
if len(bbox) != 4:
logger.warning(f"Invalid bounding box format: expected 4 values, got {len(bbox)}")
logger.warning(
f"Invalid bounding box format: expected 4 values, got {len(bbox)}"
)
return
# Convert normalized coordinates to image coordinates (for logging/debug)
@@ -872,11 +869,15 @@ class AnnotationCanvasWidget(QWidget):
# in _redraw_annotations() together with all polylines.
pen_color = QColor(color)
pen_color.setAlpha(128) # Add semi-transparency
self.bboxes.append([float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)])
self.bboxes.append(
[float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)]
)
self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label})
# Store in all_strokes for consistency
self.all_strokes.append({"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label})
self.all_strokes.append(
{"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label}
)
# Redraw overlay (polylines + all bounding boxes)
self._redraw_annotations()
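# Hedged sketch of the coordinate convention used above: polylines are stored
# as [y_norm, x_norm] pairs, so mapping back to pixels is
#   x_px = x_norm * img_width
#   y_px = y_norm * img_height
# while bboxes are stored as [x_min, y_min, x_max, y_max], all normalized.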

View File

@@ -1,21 +1,18 @@
"""YOLO model wrapper for the microscopy object detection application.
Notes on 16-bit TIFF support:
- Ultralytics training defaults assume 8-bit images and normalize by dividing by 255.
- This project can patch Ultralytics at runtime to decode TIFFs via `tifffile` and
normalize `uint16` correctly.
See [`apply_ultralytics_16bit_tiff_patches()`](src/utils/ultralytics_16bit_patch.py:1).
"""
YOLO model wrapper for the microscopy object detection application.
Provides a clean interface to YOLOv8 for training, validation, and inference.
"""
from ultralytics import YOLO
from pathlib import Path
from typing import Optional, List, Dict, Callable, Any
import torch
import tempfile
import os
from src.utils.image import Image
import numpy as np
from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range
from src.utils.logger import get_logger
from src.utils.ultralytics_16bit_patch import apply_ultralytics_16bit_tiff_patches
from src.utils.train_ultralytics_float import train_with_float32_loader
logger = get_logger(__name__)
@@ -36,9 +33,6 @@ class YOLOWrapper:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"YOLOWrapper initialized with device: {self.device}")
# Apply Ultralytics runtime patches early (before first import/instantiation of YOLO datasets/trainers).
apply_ultralytics_16bit_tiff_patches()
def load_model(self) -> bool:
"""
Load YOLO model from path.
@@ -48,9 +42,6 @@ class YOLOWrapper:
"""
try:
logger.info(f"Loading YOLO model from {self.model_path}")
# Import YOLO lazily to ensure runtime patches are applied first.
from ultralytics import YOLO
self.model = YOLO(self.model_path)
self.model.to(self.device)
logger.info("Model loaded successfully")
@@ -70,10 +61,11 @@ class YOLOWrapper:
name: str = "custom_model",
resume: bool = False,
callbacks: Optional[Dict[str, Callable]] = None,
use_float32_loader: bool = True,
**kwargs,
) -> Dict[str, Any]:
"""
Train the YOLO model.
Train the YOLO model with optional float32 loader for 16-bit TIFFs.
Args:
data_yaml: Path to data.yaml configuration file
@@ -85,30 +77,43 @@ class YOLOWrapper:
name: Name for the training run
resume: Resume training from last checkpoint
callbacks: Optional Ultralytics callback dictionary
use_float32_loader: Use custom Float32Dataset for 16-bit TIFFs (default: True)
**kwargs: Additional training arguments
Returns:
Dictionary with training results
"""
if 1:
logger.info(f"Starting training: {name}")
logger.info(
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
)
# Check if dataset has 16-bit TIFFs and use float32 loader
if use_float32_loader:
logger.info("Using Float32Dataset loader for 16-bit TIFF support")
return train_with_float32_loader(
model_path=self.model_path,
data_yaml=data_yaml,
epochs=epochs,
imgsz=imgsz,
batch=batch,
patience=patience,
save_dir=save_dir,
name=name,
callbacks=callbacks,
device=self.device,
resume=resume,
**kwargs,
)
else:
# Standard training (old behavior)
if self.model is None:
if not self.load_model():
raise RuntimeError(f"Failed to load model from {self.model_path}")
raise RuntimeError(
f"Failed to load model from {self.model_path}"
)
try:
logger.info(f"Starting training: {name}")
logger.info(f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}")
# Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255.
# Users can override by passing explicit kwargs.
kwargs.setdefault("mosaic", 0.0)
kwargs.setdefault("mixup", 0.0)
kwargs.setdefault("cutmix", 0.0)
kwargs.setdefault("copy_paste", 0.0)
kwargs.setdefault("hsv_h", 0.0)
kwargs.setdefault("hsv_s", 0.0)
kwargs.setdefault("hsv_v", 0.0)
# Train the model
results = self.model.train(
data=data_yaml,
epochs=epochs,
@@ -125,9 +130,9 @@ class YOLOWrapper:
logger.info("Training completed successfully")
return self._format_training_results(results)
except Exception as e:
logger.error(f"Error during training: {e}")
raise
# except Exception as e:
# logger.error(f"Error during training: {e}")
# raise
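# Hedged usage sketch (paths hypothetical; assumes the constructor takes the
# model path as its first argument):
#   wrapper = YOLOWrapper("yolov8s-seg.pt")
#   results = wrapper.train(
#       data_yaml="data/datasets/data.yaml",
#       epochs=100, batch=16, imgsz=1024,
#       use_float32_loader=True,  # routes through train_with_float32_loader
#   )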
def validate(self, data_yaml: str, split: str = "val", **kwargs) -> Dict[str, Any]:
"""
@@ -147,7 +152,9 @@ class YOLOWrapper:
try:
logger.info(f"Starting validation on {split} split")
results = self.model.val(data=data_yaml, split=split, device=self.device, **kwargs)
results = self.model.val(
data=data_yaml, split=split, device=self.device, **kwargs
)
logger.info("Validation completed successfully")
return self._format_validation_results(results)
@@ -186,18 +193,17 @@ class YOLOWrapper:
raise RuntimeError(f"Failed to load model from {self.model_path}")
prepared_source, cleanup_path = self._prepare_source(source)
imgsz = 1088
try:
logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
logger.info(f"Running inference on {source}")
results = self.model.predict(
source=source,
source=prepared_source,
conf=conf,
iou=iou,
save=save,
save_txt=save_txt,
save_conf=save_conf,
device=self.device,
imgsz=imgsz,
**kwargs,
)
@@ -209,13 +215,20 @@ class YOLOWrapper:
logger.error(f"Error during inference: {e}")
raise
finally:
if 0: # cleanup_path:
# Clean up temporary files (only for non-16-bit images)
# 16-bit TIFFs return numpy arrays directly, so cleanup_path is None
if cleanup_path:
try:
os.remove(cleanup_path)
logger.debug(f"Cleaned up temporary file: {cleanup_path}")
except OSError as cleanup_error:
logger.warning(f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}")
logger.warning(
f"Failed to delete temporary file {cleanup_path}: {cleanup_error}"
)
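# Hedged sketch of the 16-bit inference path (filename hypothetical): for a
# .tif source, _prepare_source() returns a float32 RGB array and cleanup_path
# stays None, so no temporary file is written or removed:
#   detections = wrapper.predict("sample_16bit.tif", conf=0.25, iou=0.45)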
def export(self, format: str = "onnx", output_path: Optional[str] = None, **kwargs) -> str:
def export(
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
) -> str:
"""
Export model to different format.
@@ -242,7 +255,14 @@ class YOLOWrapper:
raise
def _prepare_source(self, source):
"""Convert single-channel images to RGB temporarily for inference."""
"""Convert single-channel images to RGB for inference.
For 16-bit TIFF files, this will:
1. Load using tifffile
2. Normalize to float32 [0-1] (NO uint8 conversion to avoid data loss)
3. Replicate grayscale → RGB (3 channels)
4. Pass directly as numpy array to YOLO
"""
cleanup_path = None
if isinstance(source, (str, Path)):
@@ -250,13 +270,60 @@ class YOLOWrapper:
if source_path.is_file():
try:
img_obj = Image(source_path)
# Check if it's a 16-bit TIFF file
is_16bit_tiff = (
source_path.suffix.lower() in [".tif", ".tiff"]
and img_obj.dtype == np.uint16
)
if is_16bit_tiff:
# Process 16-bit TIFF: normalize to float32 [0-1]
# NO uint8 conversion - pass float32 directly to avoid data loss
normalized_float = img_obj.to_normalized_float32()
# Convert grayscale to RGB by replicating channels
if len(normalized_float.shape) == 2:
# Grayscale: H,W → H,W,3
rgb_float = np.stack([normalized_float] * 3, axis=-1)
elif (
len(normalized_float.shape) == 3
and normalized_float.shape[2] == 1
):
# Grayscale with channel dim: H,W,1 → H,W,3
rgb_float = np.repeat(normalized_float, 3, axis=2)
else:
# Already multi-channel
rgb_float = normalized_float
# Ensure contiguous array and float32
rgb_float = np.ascontiguousarray(rgb_float, dtype=np.float32)
logger.info(
f"Loaded 16-bit TIFF {source_path} as float32 [0-1] RGB "
f"(shape: {rgb_float.shape}, dtype: {rgb_float.dtype}, "
f"range: [{rgb_float.min():.4f}, {rgb_float.max():.4f}])"
)
# Return numpy array directly - YOLO can handle it
return rgb_float, cleanup_path
else:
# Standard processing for other images
pil_img = img_obj.pil_image
if len(pil_img.getbands()) == 1:
rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img)
else:
rgb_img = pil_img.convert("RGB")
suffix = source_path.suffix or ".png"
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
tmp_path = tmp.name
tmp.close()
img_obj.save(tmp_path)
rgb_img.save(tmp_path)
cleanup_path = tmp_path
logger.info(f"Converted image {source_path} to RGB for inference at {tmp_path}")
logger.info(
f"Converted image {source_path} to RGB for inference at {tmp_path}"
)
return tmp_path, cleanup_path
except Exception as convert_error:
logger.warning(
@@ -269,7 +336,9 @@ class YOLOWrapper:
"""Format training results into dictionary."""
try:
# Get the results dict
results_dict = results.results_dict if hasattr(results, "results_dict") else {}
results_dict = (
results.results_dict if hasattr(results, "results_dict") else {}
)
formatted = {
"success": True,
@@ -302,7 +371,9 @@ class YOLOWrapper:
"mAP50-95": float(box_metrics.map),
"precision": float(box_metrics.mp),
"recall": float(box_metrics.mr),
"fitness": (float(results.fitness) if hasattr(results, "fitness") else 0.0),
"fitness": (
float(results.fitness) if hasattr(results, "fitness") else 0.0
),
}
# Add per-class metrics if available
@@ -312,7 +383,11 @@ class YOLOWrapper:
if idx < len(box_metrics.ap):
class_metrics[name] = {
"ap": float(box_metrics.ap[idx]),
"ap50": (float(box_metrics.ap50[idx]) if hasattr(box_metrics, "ap50") else 0.0),
"ap50": (
float(box_metrics.ap50[idx])
if hasattr(box_metrics, "ap50")
else 0.0
),
}
formatted["class_metrics"] = class_metrics
@@ -345,15 +420,21 @@ class YOLOWrapper:
"class_id": int(boxes.cls[i]),
"class_name": result.names[int(boxes.cls[i])],
"confidence": float(boxes.conf[i]),
"bbox_normalized": [float(v) for v in xyxyn], # [x_min, y_min, x_max, y_max]
"bbox_absolute": [float(v) for v in boxes.xyxy[i].cpu().numpy()], # Absolute pixels
"bbox_normalized": [
float(v) for v in xyxyn
], # [x_min, y_min, x_max, y_max]
"bbox_absolute": [
float(v) for v in boxes.xyxy[i].cpu().numpy()
], # Absolute pixels
}
# Extract segmentation mask if available
if has_masks:
try:
# Get the mask for this detection
mask_data = result.masks.xy[i] # Polygon coordinates in absolute pixels
mask_data = result.masks.xy[
i
] # Polygon coordinates in absolute pixels
# Convert to normalized coordinates
if len(mask_data) > 0:
@@ -366,7 +447,9 @@ class YOLOWrapper:
else:
detection["segmentation_mask"] = None
except Exception as mask_error:
logger.warning(f"Error extracting mask for detection {i}: {mask_error}")
logger.warning(
f"Error extracting mask for detection {i}: {mask_error}"
)
detection["segmentation_mask"] = None
else:
detection["segmentation_mask"] = None
@@ -380,7 +463,9 @@ class YOLOWrapper:
return []
@staticmethod
def convert_bbox_format(bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy") -> List[float]:
def convert_bbox_format(
bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
) -> List[float]:
"""
Convert bounding box between formats.

View File

@@ -54,7 +54,7 @@ class ConfigManager:
"models_directory": "data/models",
"base_model_choices": [
"yolov8s-seg.pt",
"yolo11s-seg.pt",
"yolov11s-seg.pt",
],
},
"training": {
@@ -225,4 +225,6 @@ class ConfigManager:
def get_allowed_extensions(self) -> list:
"""Get list of allowed image file extensions."""
return self.get("image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS)
return self.get(
"image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS
)
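# Hedged usage sketch (assumes a no-arg constructor): get() takes a
# dot-separated key plus a default, mirroring get_allowed_extensions() above:
#   cfg = ConfigManager()
#   exts = cfg.get("image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS)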

View File

@@ -1,103 +0,0 @@
import numpy as np
from pathlib import Path
from skimage.draw import polygon
from tifffile import TiffFile
from src.database.db_manager import DatabaseManager
def read_image(image_path: Path) -> np.ndarray:
metadata = {}
with TiffFile(image_path) as tif:
image = tif.asarray()
metadata = tif.imagej_metadata
return image, metadata
def main():
polygon_vertices = np.array([[10, 10], [50, 10], [50, 50], [10, 50]])
image = np.zeros((100, 100), dtype=np.uint8)
rr, cc = polygon(polygon_vertices[:, 0], polygon_vertices[:, 1])
image[rr, cc] = 255
if __name__ == "__main__":
db = DatabaseManager()
model_name = "c17"
model_id = db.get_models(filters={"model_name": model_name})[0]["id"]
print(f"Model name {model_name}, id {model_id}")
detections = db.get_detections(filters={"model_id": model_id})
file_stems = set()
for detection in detections:
file_stems.add(detection["image_filename"].split("_")[0])
print("Files:", file_stems)
for stem in file_stems:
print(stem)
detections = db.get_detections(filters={"model_id": model_id, "i.filename": f"LIKE %{stem}%"})
annotations = []
for detection in detections:
source_path = Path(detection["metadata"]["source_path"])
image, metadata = read_image(source_path)
offset = np.array(list(map(int, metadata["tile_section"].split(","))))[::-1]
scale = np.array(list(map(int, metadata["patch_size"].split(","))))[::-1]
# tile_size = np.array(list(map(int, metadata["tile_size"].split(","))))
segmentation = np.array(detection["segmentation_mask"]) # * tile_size
# print(source_path, image, metadata, segmentation.shape)
# print(offset)
# print(scale)
# print(segmentation)
# segmentation = (segmentation + offset * tile_size) / (tile_size * scale)
segmentation = (segmentation + offset) / scale
yolo_annotation = f"{detection['metadata']['class_id']} " + " ".join(
[f"{x:.6f} {y:.6f}" for x, y in segmentation]
)
annotations.append(yolo_annotation)
# print(segmentation)
# print(yolo_annotation)
# aa
print(
" ",
detection["model_name"],
detection["image_id"],
detection["image_filename"],
source_path,
metadata["label_path"],
)
# section_i_section_j = detection["image_filename"].split("_")[1].split(".")[0]
# print(" ", section_i_section_j)
label_path = metadata["label_path"]
print(" ", label_path)
with open(label_path, "w") as f:
f.write("\n".join(annotations))
exit()
for detection in detections:
print(detection["model_name"], detection["image_id"], detection["image_filename"])
print(detections[0])
# polygon_vertices = np.array([[10, 10], [50, 10], [50, 50], [10, 50]])
# image = np.zeros((100, 100), dtype=np.uint8)
# rr, cc = polygon(polygon_vertices[:, 0], polygon_vertices[:, 1])
# image[rr, cc] = 255
# import matplotlib.pyplot as plt
# plt.imshow(image, cmap='gray')
# plt.show()

View File

@@ -6,55 +6,17 @@ import cv2
import numpy as np
from pathlib import Path
from typing import Optional, Tuple, Union
from PIL import Image as PILImage
import tifffile
from src.utils.logger import get_logger
from src.utils.file_utils import validate_file_path, is_image_file
from PySide6.QtGui import QImage
from tifffile import imread, imwrite
logger = get_logger(__name__)
def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
"""
Convert a grayscale image to a pseudo-RGB image using a gamma correction.
Args:
arr: Input grayscale image as numpy array
Returns:
Pseudo-RGB image as numpy array
"""
if arr.ndim != 2:
raise ValueError("Input array must be a grayscale image with shape (H, W)")
a1 = arr.copy().astype(np.float32)
a1 -= np.percentile(a1, 2)
a1[a1 < 0] = 0
p999 = np.percentile(a1, 99.9)
a1[a1 > p999] = p999
a1 /= a1.max()
if 1:
a2 = a1.copy()
a2 = a2**gamma
a2 /= a2.max()
a3 = a1.copy()
p9999 = np.percentile(a3, 99.99)
a3[a3 > p9999] = p9999
a3 /= a3.max()
# return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
out = np.stack([a1, a2, a3], axis=0)
# print(any(np.isnan(out).flatten()))
return out
class ImageLoadError(Exception):
"""Exception raised when an image cannot be loaded."""
@@ -93,6 +55,7 @@ class Image:
"""
self.path = Path(image_path)
self._data: Optional[np.ndarray] = None
self._pil_image: Optional[PILImage.Image] = None
self._width: int = 0
self._height: int = 0
self._channels: int = 0
@@ -118,34 +81,75 @@ class Image:
if not is_image_file(str(self.path), self.SUPPORTED_EXTENSIONS):
ext = self.path.suffix.lower()
raise ImageLoadError(
f"Unsupported image format: {ext}. " f"Supported formats: {', '.join(self.SUPPORTED_EXTENSIONS)}"
f"Unsupported image format: {ext}. "
f"Supported formats: {', '.join(self.SUPPORTED_EXTENSIONS)}"
)
try:
# Check if it's a TIFF file - use tifffile for better support
if self.path.suffix.lower() in [".tif", ".tiff"]:
self._data = imread(str(self.path))
else:
# raise NotImplementedError("RGB is not implemented")
# Load with OpenCV (returns BGR format)
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
self._data = tifffile.imread(str(self.path))
if self._data is None:
raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")
raise ImageLoadError(
f"Failed to load TIFF with tifffile: {self.path}"
)
# Extract metadata
# print(self._data.shape)
if len(self._data.shape) == 2:
self._height, self._width = self._data.shape[:2]
self._channels = 1
else:
self._height, self._width = self._data.shape[1:]
self._channels = self._data.shape[0]
# self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
self._height, self._width = (
self._data.shape[:2]
if len(self._data.shape) >= 2
else (self._data.shape[0], 1)
)
self._channels = (
self._data.shape[2] if len(self._data.shape) == 3 else 1
)
self._format = self.path.suffix.lower().lstrip(".")
self._size_bytes = self.path.stat().st_size
self._dtype = self._data.dtype
if 0:
# Load PIL version for compatibility
if self._channels == 1:
# Grayscale
self._pil_image = PILImage.fromarray(self._data)
else:
# Multi-channel (RGB or RGBA)
self._pil_image = PILImage.fromarray(self._data)
logger.info(
f"Successfully loaded TIFF image: {self.path.name} "
f"({self._width}x{self._height}, {self._channels} channels, "
f"dtype={self._dtype}, {self._format.upper()})"
)
else:
# Load with OpenCV (returns BGR format) for non-TIFF images
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
if self._data is None:
raise ImageLoadError(
f"Failed to load image with OpenCV: {self.path}"
)
# Extract metadata
self._height, self._width = self._data.shape[:2]
self._channels = (
self._data.shape[2] if len(self._data.shape) == 3 else 1
)
self._format = self.path.suffix.lower().lstrip(".")
self._size_bytes = self.path.stat().st_size
self._dtype = self._data.dtype
# Load PIL version for compatibility (convert BGR to RGB)
if self._channels == 3:
rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
self._pil_image = PILImage.fromarray(rgb_data)
elif self._channels == 4:
rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
self._pil_image = PILImage.fromarray(rgba_data)
else:
# Grayscale
self._pil_image = PILImage.fromarray(self._data)
logger.info(
f"Successfully loaded image: {self.path.name} "
f"({self._width}x{self._height}, {self._channels} channels, "
@@ -168,6 +172,18 @@ class Image:
raise ImageLoadError("Image data not available")
return self._data
@property
def pil_image(self) -> PILImage.Image:
"""
Get image data as PIL Image (RGB or grayscale).
Returns:
PIL Image object
"""
if self._pil_image is None:
raise ImageLoadError("PIL image not available")
return self._pil_image
@property
def width(self) -> int:
"""Get image width in pixels."""
@@ -212,7 +228,6 @@ class Image:
@property
def dtype(self) -> np.dtype:
"""Get the data type of the image array."""
if self._dtype is None:
raise ImageLoadError("Image dtype not available")
return self._dtype
@@ -232,10 +247,8 @@ class Image:
elif self._channels == 1:
if self._dtype == np.uint16:
return QImage.Format_Grayscale16
elif self._dtype == np.uint8:
else:
return QImage.Format_Grayscale8
elif self._dtype == np.float32:
return QImage.Format_BGR30
else:
raise ImageLoadError(f"Unsupported number of channels: {self._channels}")
@@ -246,36 +259,12 @@ class Image:
Returns:
Image data in RGB format as numpy array
"""
if self.channels == 1:
img = get_pseudo_rgb(self.data)
self._dtype = img.dtype
return img, True
elif self._channels == 3:
return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB), False
if self._channels == 3:
return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
elif self._channels == 4:
return cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA), False
return cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
else:
raise NotImplementedError
# else:
# return self._data
def get_qt_rgb(self) -> np.ascontiguousarray:
# we keep data as (C, H, W)
_img, pseudo = self.get_rgb()
if pseudo:
img = np.zeros((self.height, self.width, 4), dtype=np.float32)
img[..., 0] = _img[0] # R gradient
img[..., 1] = _img[1] # G gradient
img[..., 2] = _img[2] # B constant
img[..., 3] = 1.0 # A = 1.0 (opaque)
return np.ascontiguousarray(img)
else:
return np.ascontiguousarray(_img)
return self._data
def get_grayscale(self) -> np.ndarray:
"""
@@ -329,26 +318,49 @@ class Image:
"""
return self._channels >= 3
def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None:
def to_normalized_float32(self) -> np.ndarray:
"""
Convert image data to normalized float32 in range [0, 1].
if self.channels == 1:
if pseudo_rgb:
img = get_pseudo_rgb(self.data)
print("Image.save", img.shape)
For 16-bit images, this properly scales the full dynamic range.
For 8-bit images, divides by 255.
Images that are already float are clipped to [0, 1].
Returns:
Normalized image data as float32 numpy array [0, 1]
"""
data = self._data.astype(np.float32)
if self._dtype == np.uint16:
# 16-bit: normalize by max value (65535)
data = data / 65535.0
elif self._dtype == np.uint8:
# 8-bit: normalize by 255
data = data / 255.0
elif np.issubdtype(self._dtype, np.floating):
# Already float, just clip to [0, 1]
data = np.clip(data, 0.0, 1.0)
else:
img = np.repeat(self.data, 3, axis=2)
# Other integer types: use dtype info
if np.issubdtype(self._dtype, np.integer):
max_val = np.iinfo(self._dtype).max
data = data / float(max_val)
else:
raise NotImplementedError("Only grayscale images are supported for now.")
# Unknown type: attempt min-max normalization
min_val = data.min()
max_val = data.max()
if max_val > min_val:
data = (data - min_val) / (max_val - min_val)
else:
data = np.zeros_like(data)
imwrite(path, data=img)
return np.clip(data, 0.0, 1.0)
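# Worked example of the scaling above: a uint16 pixel of 32768 maps to
# 32768 / 65535 ≈ 0.5, and a uint8 pixel of 128 maps to 128 / 255 ≈ 0.502,
# so both dtypes land on a comparable [0, 1] scale without clipping.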
def __repr__(self) -> str:
"""String representation of the Image object."""
return (
f"Image(path='{self.path.name}', "
# Display as HxWxC to match the conventional NumPy shape semantics.
f"shape=({self._height}x{self._width}x{self._channels}), "
f"shape=({self._width}x{self._height}x{self._channels}), "
f"format={self._format}, "
f"size={self.size_mb:.2f}MB)"
)
@@ -358,13 +370,38 @@ class Image:
return self.__repr__()
if __name__ == "__main__":
import argparse
def convert_grayscale_to_rgb_preserve_range(
pil_image: PILImage.Image,
) -> PILImage.Image:
"""Convert a single-channel PIL image to RGB while preserving dynamic range.
parser = argparse.ArgumentParser()
parser.add_argument("--path", type=str, required=True)
args = parser.parse_args()
Args:
pil_image: Single-channel PIL image (e.g., 16-bit grayscale).
img = Image(args.path)
img.save(args.path + "test.tif")
print(img)
Returns:
PIL Image in RGB mode with intensities normalized to 0-255.
"""
if pil_image.mode == "RGB":
return pil_image
grayscale = np.array(pil_image)
if grayscale.ndim == 3:
grayscale = grayscale[:, :, 0]
original_dtype = grayscale.dtype
grayscale = grayscale.astype(np.float32)
if grayscale.size == 0:
return PILImage.new("RGB", pil_image.size, color=(0, 0, 0))
if np.issubdtype(original_dtype, np.integer):
denom = float(max(np.iinfo(original_dtype).max, 1))
else:
max_val = float(grayscale.max())
denom = max(max_val, 1.0)
grayscale = np.clip(grayscale / denom, 0.0, 1.0)
grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8)
rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
return PILImage.fromarray(rgb_arr, mode="RGB")

View File

@@ -12,38 +12,23 @@ class UT:
Operetta files along with rois drawn in ImageJ
"""
def __init__(self, roifile_fn: Path, no_labels: bool):
def __init__(self, roifile_fn: Path):
self.roifile_fn = roifile_fn
print("is file", self.roifile_fn.is_file())
self.rois = None
if no_labels:
self.rois = ImagejRoi.fromfile(self.roifile_fn)
print(self.roifile_fn.stem)
print(self.roifile_fn.parent.parts[-1])
if "Roi-" in self.roifile_fn.stem:
self.stem = self.roifile_fn.stem.split("Roi-")[1]
else:
self.stem = self.roifile_fn.parent.parts[-1]
else:
self.roifile_fn = roifile_fn / roifile_fn.parts[-1]
self.stem = self.roifile_fn.stem
print(self.roifile_fn)
print(self.stem)
# removesuffix (Py3.9+) drops the literal suffix; strip() would remove a character set
self.stem = self.roifile_fn.stem.removesuffix("-RoiSet")
self.image, self.image_props = self._load_images()
def _load_images(self):
"""Loading sequence of tif files
array sequence is CZYX
"""
print("Loading images:", self.roifile_fn.parent, self.stem)
fns = list(self.roifile_fn.parent.glob(f"{self.stem.lower()}*.tif*"))
print(self.roifile_fn.parent, self.stem)
fns = list(self.roifile_fn.parent.glob(f"{self.stem}*.tif*"))
stems = [fn.stem.split(self.stem)[-1] for fn in fns]
n_ch = len(set([stem.split("-ch")[-1].split("t")[0] for stem in stems]))
n_p = len(set([stem.split("-")[0] for stem in stems]))
n_t = len(set([stem.split("t")[1] for stem in stems]))
print(n_ch, n_p, n_t)
with TiffFile(fns[0]) as tif:
img = tif.asarray()
@@ -57,7 +42,6 @@ class UT:
"height": h,
"dtype": dtype,
}
print("Image props", self.image_props)
image_stack = np.zeros((n_ch, n_p, w, h), dtype=dtype)
for fn in fns:
@@ -65,7 +49,7 @@ class UT:
img = tif.asarray()
stem = fn.stem.split(self.stem)[-1]
ch = int(stem.split("-ch")[-1].split("t")[0])
p = int(stem.split("-")[0].split("p")[1])
p = int(stem.split("-")[0].lstrip("p"))
t = int(stem.split("t")[1])
print(fn.stem, "ch", ch, "p", p, "t", t)
image_stack[ch - 1, p - 1] = img
@@ -98,19 +82,10 @@ class UT:
):
"""Export rois to a file"""
with open(path / subfolder / f"{self.stem}.txt", "w") as f:
for i, roi in enumerate(self.rois):
rc = roi.subpixel_coordinates
if rc is None:
print(f"No coordinates: {self.roifile_fn}, element {i}, out of {len(self.rois)}")
continue
xmn, ymn = rc.min(axis=0)
xmx, ymx = rc.max(axis=0)
xc = (xmn + xmx) / 2
yc = (ymn + ymx) / 2
bw = xmx - xmn
bh = ymx - ymn
coords = f"{xc/self.width} {yc/self.height} {bw/self.width} {bh/self.height} "
for x, y in rc:
for roi in self.rois:
# TODO add image coordinates normalization
coords = ""
for x, y in roi.subpixel_coordinates:
coords += f"{x/self.width} {y/self.height} "
f.write(f"{class_index} {coords}\n")
@@ -129,7 +104,6 @@ class UT:
self.image = np.max(self.image[channel], axis=0)
print(self.image.shape)
print(path / subfolder / f"{self.stem}.tif")
with TiffWriter(path / subfolder / f"{self.stem}.tif") as tif:
tif.write(self.image)
@@ -138,31 +112,11 @@ if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input", nargs="*", type=Path)
parser.add_argument("-o", "--output", type=Path)
parser.add_argument(
"--no-labels",
action="store_false",
help="Source does not have labels, export only images",
)
parser.add_argument("input", type=Path)
parser.add_argument("output", type=Path)
args = parser.parse_args()
# print(args)
# aa
for path in args.input:
print("Path:", path)
if not args.no_labels:
print("No labels")
ut = UT(path, args.no_labels)
ut.export_image(args.output, plane_mode="max projection", channel=0)
else:
for rfn in Path(path).glob("*.zip"):
# if Path(path).suffix == ".zip":
print("Roi FN:", rfn)
ut = UT(rfn, args.no_labels)
for rfn in args.input.glob("*.zip"):
ut = UT(rfn)
ut.export_rois(args.output, class_index=0)
ut.export_image(args.output, plane_mode="max projection", channel=0)
print()

View File

@@ -1,368 +0,0 @@
import numpy as np
from pathlib import Path
from tifffile import imread, imwrite
from shapely.geometry import LineString
from copy import deepcopy
from scipy.ndimage import zoom
# debug
from src.utils.image import Image
from show_yolo_seg import draw_annotations
import pylab as plt
import cv2
class Label:
def __init__(self, yolo_annotation: str):
class_id, bbox, polygon = self.parse_yolo_annotation(yolo_annotation)
self.class_id = class_id
self.bbox = bbox
self.polygon = polygon
def parse_yolo_annotation(self, yolo_annotation: str):
class_id, *coords = yolo_annotation.split()
class_id = int(class_id)
bbox = np.array(coords[:4], dtype=np.float32)
polygon = np.array(coords[4:], dtype=np.float32).reshape(-1, 2) if len(coords) > 4 else None
if not any(np.isclose(polygon[0], polygon[-1])):
polygon = np.vstack([polygon, polygon[0]])
return class_id, bbox, polygon
def offset_label(
self,
img_w,
img_h,
distance: float = 1.0,
cap_style: int = 2,
join_style: int = 2,
):
if self.polygon is None:
self.bbox = np.array(
[
self.bbox[0] - distance if self.bbox[0] - distance > 0 else 0,
self.bbox[1] - distance if self.bbox[1] - distance > 0 else 0,
self.bbox[2] + distance if self.bbox[2] + distance < 1 else 1,
self.bbox[3] + distance if self.bbox[3] + distance < 1 else 1,
],
dtype=np.float32,
)
return self.bbox
def coords_are_normalized(coords):
# If every coordinate is between 0 and 1 (inclusive-ish), assume normalized
print(coords)
# if not coords:
# return False
return all(max(coords.flatten)) <= 1.001
def poly_to_pts(coords, img_w, img_h):
# coords: [x1 y1 x2 y2 ...] either normalized or absolute
# if coords_are_normalized(coords):
coords = [coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))]
pts = np.array(coords, dtype=np.int32).reshape(-1, 2)
return pts
pts = poly_to_pts(self.polygon, img_w, img_h)
line = LineString(pts)
# Buffer distance in pixels
buffered = line.buffer(distance=distance, cap_style=cap_style, join_style=join_style)
self.polygon = np.array(buffered.exterior.coords, dtype=np.float32) / (img_w, img_h)
xmn, ymn = self.polygon.min(axis=0)
xmx, ymx = self.polygon.max(axis=0)
xc = (xmn + xmx) / 2
yc = (ymn + ymx) / 2
bw = xmx - xmn
bh = ymx - ymn
self.bbox = np.array([xc, yc, bw, bh], dtype=np.float32)
return self.bbox, self.polygon
def translate(self, x, y, scale_x, scale_y):
self.bbox[0] -= x
self.bbox[0] *= scale_x
self.bbox[1] -= y
self.bbox[1] *= scale_y
self.bbox[2] *= scale_x
self.bbox[3] *= scale_y
if self.polygon is not None:
self.polygon[:, 0] -= x
self.polygon[:, 0] *= scale_x
self.polygon[:, 1] -= y
self.polygon[:, 1] *= scale_y
def in_range(self, hrange, wrange):
xc, yc, h, w = self.bbox
x1 = xc - w / 2
y1 = yc - h / 2
x2 = xc + w / 2
y2 = yc + h / 2
truth_val = (
xc >= wrange[0]
and x1 <= wrange[1]
and x2 >= wrange[0]
and x2 <= wrange[1]
and y1 >= hrange[0]
and y1 <= hrange[1]
and y2 >= hrange[0]
and y2 <= hrange[1]
)
print(x1, x2, wrange, y1, y2, hrange, truth_val)
return truth_val
def to_string(self, bbox: list = None, polygon: list = None):
coords = ""
if bbox is None:
bbox = self.bbox
# coords += " ".join([f"{x:.6f}" for x in self.bbox])
if polygon is None:
polygon = self.polygon
if self.polygon is not None:
coords += " " + " ".join([f"{x:.6f} {y:.6f}" for x, y in self.polygon])
return f"{self.class_id} {coords}"
def __str__(self):
return f"Class: {self.class_id}, BBox: {self.bbox}, Polygon: {self.polygon}"
class YoloLabelReader:
def __init__(self, label_path: Path):
self.label_path = label_path
self.labels = self._read_labels()
def _read_labels(self):
with open(self.label_path, "r") as f:
labels = [Label(line) for line in f.readlines()]
return labels
def get_labels(self, hrange, wrange):
"""hrange and wrange are tuples of (start, end) normalized to [0, 1]"""
labels = []
# print(hrange, wrange)
for lbl in self.labels:
# print(lbl)
if lbl.in_range(hrange, wrange):
labels.append(lbl)
return labels if len(labels) > 0 else None
def __get_item__(self, index):
return self.labels[index]
def __len__(self):
return len(self.labels)
def __iter__(self):
return iter(self.labels)
class ImageSplitter:
def __init__(self, image_path: Path, label_path: Path):
self.image = imread(image_path)
self.image_path = image_path
self.label_path = label_path
if not label_path.exists():
print(f"Label file {label_path} not found")
self.labels = None
else:
self.labels = YoloLabelReader(label_path)
def split_into_tiles(self, patch_size: tuple = (2, 2)):
"""Split image into patches of size patch_size"""
hstep, wstep = (
self.image.shape[0] // patch_size[0],
self.image.shape[1] // patch_size[1],
)
h, w = self.image.shape[:2]
for i in range(patch_size[0]):
for j in range(patch_size[1]):
metadata = {
"image_path": str(self.image_path),
"label_path": str(self.label_path),
"tile_section": f"{i}, {j}",
"tile_size": f"{hstep}, {wstep}",
"patch_size": f"{patch_size[0]}, {patch_size[1]}",
}
tile_reference = f"i{i}j{j}"
hrange = (i * hstep / h, (i + 1) * hstep / h)
wrange = (j * wstep / w, (j + 1) * wstep / w)
tile = self.image[i * hstep : (i + 1) * hstep, j * wstep : (j + 1) * wstep]
labels = None
if self.labels is not None:
labels = deepcopy(self.labels.get_labels(hrange, wrange))
print(id(labels))
if labels is not None:
print(hrange[0], wrange[0])
for l in labels:
print(l.bbox)
[l.translate(wrange[0], hrange[0], 2, 2) for l in labels]
print("translated")
for l in labels:
print(l.bbox)
# print(labels)
yield tile_reference, tile, labels, metadata
def split_respective_to_label(self, padding: int = 67):
if self.labels is None:
raise ValueError("No labels found. Only images having labels can be split.")
for i, label in enumerate(self.labels):
tile_reference = f"_lbl-{i+1:02d}"
# print(label.bbox)
metadata = {"image_path": str(self.image_path), "label_path": str(self.label_path), "label_index": str(i)}
xc_norm, yc_norm, h_norm, w_norm = label.bbox # normalized coords
xc, yc, h, w = [
int(np.round(f))
for f in [
xc_norm * self.image.shape[1],
yc_norm * self.image.shape[0],
h_norm * self.image.shape[0],
w_norm * self.image.shape[1],
]
] # image coords
# print("img coords:", xc, yc, h, w)
pad_xneg = padding + 1 # int(w / 2) + padding
pad_xpos = padding # int(w / 2) + padding
pad_yneg = padding + 1 # int(h / 2) + padding
pad_ypos = padding # int(h / 2) + padding
if xc - pad_xneg < 0:
pad_xneg = xc
if pad_xpos + xc > self.image.shape[1]:
pad_xpos = self.image.shape[1] - xc
if yc - pad_yneg < 0:
pad_yneg = yc
if pad_ypos + yc > self.image.shape[0]:
pad_ypos = self.image.shape[0] - yc
# print("pads:", pad_xneg, pad_xpos, pad_yneg, pad_ypos)
tile = self.image[
yc - pad_yneg : yc + pad_ypos,
xc - pad_xneg : xc + pad_xpos,
]
ny, nx = tile.shape
x_offset = pad_xneg
y_offset = pad_yneg
# print("tile shape:", tile.shape)
yolo_annotation = f"{label.class_id} " # {x_offset/nx} {y_offset/ny} {h_norm} {w_norm} "
yolo_annotation += " ".join(
[
f"{(x*self.image.shape[1]-(xc - x_offset))/nx:.6f} {(y*self.image.shape[0]-(yc-y_offset))/ny:.6f}"
for x, y in label.polygon
]
)
print(yolo_annotation)
new_label = Label(yolo_annotation=yolo_annotation)
yield tile_reference, tile, [new_label], metadata
def main(args):
if args.output:
args.output.mkdir(exist_ok=True, parents=True)
(args.output / "images").mkdir(exist_ok=True)
(args.output / "images-zoomed").mkdir(exist_ok=True)
(args.output / "labels").mkdir(exist_ok=True)
for image_path in (args.input / "images").glob("*.tif"):
data = ImageSplitter(
image_path=image_path,
label_path=(args.input / "labels" / image_path.stem).with_suffix(".txt"),
)
if args.split_around_label:
data = data.split_respective_to_label(padding=args.padding)
else:
data = data.split_into_tiles(patch_size=args.patch_size)
for tile_reference, tile, labels, metadata in data:
print()
print(tile_reference, tile.shape, labels, metadata) # len(labels) if labels else None)
# { debug
debug = False
if debug:
plt.figure(figsize=(10, 10 * tile.shape[0] / tile.shape[1]))
if labels is None:
plt.imshow(tile, cmap="gray")
plt.axis("off")
plt.title(f"{image_path.name} ({tile_reference})")
plt.show()
continue
print(labels[0].bbox)
# Draw annotations
out = draw_annotations(
cv2.cvtColor((tile / tile.max() * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR),
[l.to_string() for l in labels],
alpha=0.1,
)
# Convert BGR -> RGB for matplotlib display
out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
plt.imshow(out_rgb)
plt.axis("off")
plt.title(f"{image_path.name} ({tile_reference})")
plt.show()
# } debug
if args.output:
# imwrite(args.output / "images" / f"{image_path.stem}_{tile_reference}.tif", tile, metadata=metadata)
scale = 5
tile_zoomed = zoom(tile, zoom=scale)
metadata["scale"] = scale
imwrite(
args.output / "images" / f"{image_path.stem}_{tile_reference}.tif",
tile_zoomed,
metadata=metadata,
imagej=True,
)
if labels is not None:
with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f:
for label in labels:
# label.offset_label(tile.shape[1], tile.shape[0])
f.write(label.to_string() + "\n")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input", type=Path)
parser.add_argument("-o", "--output", type=Path)
parser.add_argument(
"-p",
"--patch-size",
nargs=2,
type=int,
default=[2, 2],
help="Number of patches along height and width, rows and columns, respectively",
)
parser.add_argument(
"-sal",
"--split-around-label",
action="store_true",
help="If enabled, the image will be split around the label and for each label, a separate image will be created.",
)
parser.add_argument(
"--padding",
type=int,
default=67,
help="Padding around the label when splitting around the label.",
)
args = parser.parse_args()
main(args)

View File

@@ -1 +0,0 @@
../../tests/show_yolo_seg.py

View File

@@ -0,0 +1,561 @@
"""
Custom YOLO training with on-the-fly float32 conversion for 16-bit grayscale images.
This module provides a custom dataset class and training function that:
1. Load 16-bit TIFF images directly with tifffile (no PIL/cv2)
2. Convert to float32 [0-1] on-the-fly (no data loss)
3. Replicate grayscale to 3-channel RGB in memory
4. Use custom training loop to bypass Ultralytics' dataset infrastructure
5. No disk caching required
"""
import numpy as np
import tifffile
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from typing import Optional, Dict, Any, List, Tuple
from ultralytics import YOLO
import yaml
import time
from src.utils.logger import get_logger
logger = get_logger(__name__)
class Float32YOLODataset(Dataset):
"""
Custom PyTorch dataset for YOLO that loads 16-bit grayscale TIFFs as float32 RGB.
This dataset:
- Loads with tifffile (not PIL/cv2)
- Converts uint16 → float32 [0-1] (preserves full dynamic range)
- Replicates grayscale to 3 channels
- Returns torch tensors in (C, H, W) format
"""
def __init__(self, images_dir: str, labels_dir: str, img_size: int = 640):
"""
Initialize dataset.
Args:
images_dir: Directory containing images
labels_dir: Directory containing YOLO label files (.txt)
img_size: Target image size (for reference, actual resizing done by model)
"""
self.images_dir = Path(images_dir)
self.labels_dir = Path(labels_dir)
self.img_size = img_size
# Find all image files
extensions = {".tif", ".tiff", ".png", ".jpg", ".jpeg", ".bmp"}
self.image_paths = sorted(
[
p
for p in self.images_dir.rglob("*")
if p.is_file() and p.suffix.lower() in extensions
]
)
if not self.image_paths:
raise ValueError(f"No images found in {images_dir}")
logger.info(
f"Float32YOLODataset initialized with {len(self.image_paths)} images from {images_dir}"
)
def __len__(self):
return len(self.image_paths)
def _read_image(self, img_path: Path) -> np.ndarray:
"""
Read image and convert to float32 [0-1] RGB.
Returns:
numpy array, shape (H, W, 3), dtype float32, range [0, 1]
"""
# Load image with tifffile
img = tifffile.imread(str(img_path))
# Convert to float32, normalizing by the source dtype so that uint8
# inputs are divided by 255 and uint16 inputs by 65535
if img.dtype == np.uint16:
img = img.astype(np.float32) / 65535.0
elif img.dtype == np.uint8:
img = img.astype(np.float32) / 255.0
else:
img = img.astype(np.float32)
# Ensure [0, 1] range
img = np.clip(img, 0.0, 1.0)
# Convert grayscale to RGB
if img.ndim == 2:
# H,W → H,W,3
img = np.repeat(img[..., None], 3, axis=2)
elif img.ndim == 3 and img.shape[2] == 1:
# H,W,1 → H,W,3
img = np.repeat(img, 3, axis=2)
return img # float32 (H, W, 3) in [0, 1]
def _parse_label(self, label_path: Path) -> List[np.ndarray]:
"""
Parse YOLO label file with variable-length rows (segmentation polygons).
Returns:
List of numpy arrays, one per annotation
"""
if not label_path.exists():
return []
labels = []
try:
with open(label_path, "r") as f:
for line in f:
line = line.strip()
if not line:
continue
# Parse space-separated values
values = line.split()
if len(values) >= 5: # At minimum: class_id x y w h
labels.append(
np.array([float(v) for v in values], dtype=np.float32)
)
except Exception as e:
logger.warning(f"Error parsing label {label_path}: {e}")
return []
return labels
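# Example of the accepted line formats (normalized coordinates, values hypothetical):
#   bbox row:    "0 0.50 0.50 0.20 0.10"                     -> class x y w h
#   polygon row: "1 0.10 0.10 0.50 0.10 0.50 0.50 0.10 0.50" -> class x1 y1 x2 y2 ...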
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, List[np.ndarray], str]:
"""
Get a single training sample.
Returns:
Tuple of (image_tensor, labels, filename)
- image_tensor: shape (3, H, W), dtype float32, range [0, 1]
- labels: list of numpy arrays with YOLO format labels (variable length for segmentation)
- filename: image filename
"""
img_path = self.image_paths[idx]
label_path = self.labels_dir / f"{img_path.stem}.txt"
# Load image as float32 RGB
img = self._read_image(img_path)
# Convert to tensor: (H, W, 3) → (3, H, W)
img_tensor = torch.from_numpy(img).permute(2, 0, 1).contiguous()
# Load labels (list of variable-length arrays for segmentation)
labels = self._parse_label(label_path)
return img_tensor, labels, img_path.name
def collate_fn(
batch: List[Tuple[torch.Tensor, List[np.ndarray], str]],
) -> Tuple[torch.Tensor, List[List[np.ndarray]], List[str]]:
"""
Collate function for DataLoader.
Args:
batch: List of (img_tensor, labels_list, filename) tuples
where labels_list is a list of variable-length numpy arrays
Returns:
Tuple of (stacked_images, list_of_labels_lists, list_of_filenames)
"""
imgs = [b[0] for b in batch]
labels = [b[1] for b in batch] # Each element is a list of arrays
names = [b[2] for b in batch]
# Stack images - requires same H,W
# For different sizes, implement letterbox/resize in dataset
imgs_batch = torch.stack(imgs, dim=0)
return imgs_batch, labels, names
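# Hedged usage sketch (directory names hypothetical):
#   ds = Float32YOLODataset("dataset/train/images", "dataset/train/labels", img_size=640)
#   dl = DataLoader(ds, batch_size=4, shuffle=True, collate_fn=collate_fn)
#   imgs, labels, names = next(iter(dl))  # imgs: (4, 3, H, W) float32 in [0, 1]
#   # note: torch.stack in collate_fn assumes all images share the same H, W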
def train_with_float32_loader(
model_path: str,
data_yaml: str,
epochs: int = 100,
imgsz: int = 640,
batch: int = 16,
patience: int = 50,
save_dir: str = "data/models",
name: str = "custom_model",
callbacks: Optional[Dict] = None,
**kwargs,
) -> Dict[str, Any]:
"""
Train YOLO model with custom Float32 dataset for 16-bit TIFF support.
Uses a custom training loop to bypass Ultralytics' dataset pipeline,
avoiding channel conversion issues.
Args:
model_path: Path to base model weights (.pt file)
data_yaml: Path to dataset YAML configuration
epochs: Number of training epochs
imgsz: Input image size
batch: Batch size
patience: Early stopping patience
save_dir: Directory to save trained model
name: Name for the training run
callbacks: Optional callback dictionary (for progress reporting)
**kwargs: Additional training arguments (lr0, freeze, device, etc.)
Returns:
Dict with training results including model paths and metrics
"""
try:
logger.info(f"Starting Float32 custom training: {name}")
logger.info(
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
)
# Parse data.yaml to get dataset paths
with open(data_yaml, "r") as f:
data_config = yaml.safe_load(f)
dataset_root = Path(data_config.get("path", Path(data_yaml).parent))
train_images = dataset_root / data_config.get("train", "train/images")
val_images = dataset_root / data_config.get("val", "val/images")
# Infer label directories
train_labels = train_images.parent / "labels"
val_labels = val_images.parent / "labels"
logger.info(f"Train images: {train_images}")
logger.info(f"Train labels: {train_labels}")
logger.info(f"Val images: {val_images}")
logger.info(f"Val labels: {val_labels}")
# Create datasets
train_dataset = Float32YOLODataset(
str(train_images), str(train_labels), img_size=imgsz
)
val_dataset = Float32YOLODataset(
str(val_images), str(val_labels), img_size=imgsz
)
# Create data loaders
train_loader = DataLoader(
train_dataset,
batch_size=batch,
shuffle=True,
num_workers=4,
pin_memory=True,
collate_fn=collate_fn,
)
val_loader = DataLoader(
val_dataset,
batch_size=batch,
shuffle=False,
num_workers=2,
pin_memory=True,
collate_fn=collate_fn,
)
# Load model
logger.info(f"Loading model from {model_path}")
ul_model = YOLO(model_path)
# Get PyTorch model
pt_model, loss_fn = _get_pytorch_model(ul_model)
# Setup device
device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
# Configure model args for loss function
from types import SimpleNamespace
# Required args for segmentation loss
required_args = {
"overlap_mask": True,
"mask_ratio": 4,
"task": "segment",
"single_cls": False,
"box": 7.5,
"cls": 0.5,
"dfl": 1.5,
}
if not hasattr(pt_model, "args"):
# No args - create SimpleNamespace
pt_model.args = SimpleNamespace(**required_args)
elif isinstance(pt_model.args, dict):
# Args is dict - MUST convert to SimpleNamespace for attribute access
# The loss function uses model.args.overlap_mask (attribute access)
merged = {**pt_model.args, **required_args}
pt_model.args = SimpleNamespace(**merged)
logger.info(
"Converted model.args from dict to SimpleNamespace for loss function compatibility"
)
else:
# Args is SimpleNamespace or other - set attributes
for key, value in required_args.items():
if not hasattr(pt_model.args, key):
setattr(pt_model.args, key, value)
pt_model.to(device)
pt_model.train()
logger.info(f"Training on device: {device}")
logger.info(f"PyTorch model type: {type(pt_model)}")
logger.info(f"Model args configured for segmentation loss")
# Setup optimizer
lr0 = kwargs.get("lr0", 0.01)
optimizer = torch.optim.AdamW(pt_model.parameters(), lr=lr0)
# Training loop
save_path = Path(save_dir) / name
save_path.mkdir(parents=True, exist_ok=True)
weights_dir = save_path / "weights"
weights_dir.mkdir(exist_ok=True)
best_loss = float("inf")
patience_counter = 0
for epoch in range(epochs):
epoch_start = time.time()
running_loss = 0.0
num_batches = 0
logger.info(f"Epoch {epoch+1}/{epochs} starting...")
for batch_idx, (imgs, labels_list, names) in enumerate(train_loader):
imgs = imgs.to(device) # (B, 3, H, W) float32
optimizer.zero_grad()
# Forward pass
try:
preds = pt_model(imgs)
except Exception:
# Some model variants expect targets at train time; retry with labels.
preds = pt_model(imgs, labels_list)
# Compute loss
# For Ultralytics models, the easiest approach is to construct a batch dict
# and call the model in training mode which returns preds + loss
batch_dict = {
"img": imgs, # Already on device
"batch_idx": (
torch.cat(
[
torch.full((len(lab),), i, dtype=torch.long)
for i, lab in enumerate(labels_list)
]
).to(device)
if any(len(lab) > 0 for lab in labels_list)
else torch.tensor([], dtype=torch.long, device=device)
),
"cls": (
torch.cat(
[
torch.from_numpy(lab[:, 0:1])
for lab in labels_list
if len(lab) > 0
]
).to(device)
if any(len(lab) > 0 for lab in labels_list)
else torch.tensor([], dtype=torch.float32, device=device)
),
"bboxes": (
torch.cat(
[
torch.from_numpy(lab[:, 1:5])
for lab in labels_list
if len(lab) > 0
]
).to(device)
if any(len(lab) > 0 for lab in labels_list)
else torch.tensor([], dtype=torch.float32, device=device)
),
"ori_shape": (imgs.shape[2], imgs.shape[3]), # H, W
"resized_shape": (imgs.shape[2], imgs.shape[3]),
}
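# Note: these keys mirror what the Ultralytics v8 segmentation loss expects in
# a training batch ("img", "batch_idx", "cls", "bboxes", plus "masks" below);
# "batch_idx" maps every label row back to the image it belongs to, which is
# how the variable-length label lists survive collation into flat tensors.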
# Add masks if segmentation labels exist
if any(lab.shape[1] > 5 for lab in labels_list if len(lab) > 0):
masks = []
for lab in labels_list:
if len(lab) > 0 and lab.shape[1] > 5:
# Has segmentation points
masks.append(torch.from_numpy(lab[:, 5:]))
if masks:
batch_dict["masks"] = masks
# Call model loss (it will compute loss internally)
try:
loss_output = pt_model.loss(batch_dict, preds)
if isinstance(loss_output, (tuple, list)):
loss = loss_output[0]
else:
loss = loss_output
except Exception as e:
logger.error(f"Model loss computation failed: {e}")
# Last resort: maybe preds is already a dict with 'loss'
if isinstance(preds, dict) and "loss" in preds:
loss = preds["loss"]
else:
raise RuntimeError(f"Cannot compute loss: {e}")
# Backward pass
loss.backward()
optimizer.step()
running_loss += loss.item()
num_batches += 1
# Report progress via callback (fires per batch, reusing the epoch-end hook)
if callbacks and "on_fit_epoch_end" in callbacks:
# Create a mock trainer object for callback
class MockTrainer:
def __init__(self, epoch):
self.epoch = epoch
self.loss_items = [loss.item()]
callbacks["on_fit_epoch_end"](MockTrainer(epoch))
epoch_loss = running_loss / max(1, num_batches)
epoch_time = time.time() - epoch_start
logger.info(
f"Epoch {epoch+1}/{epochs} completed. Avg Loss: {epoch_loss:.4f}, Time: {epoch_time:.1f}s"
)
# Save checkpoint
ckpt_path = weights_dir / f"epoch{epoch+1}.pt"
torch.save(
{
"epoch": epoch + 1,
"model_state_dict": pt_model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"loss": epoch_loss,
},
ckpt_path,
)
# Save as last.pt
last_path = weights_dir / "last.pt"
torch.save(pt_model.state_dict(), last_path)
# Check for best model
if epoch_loss < best_loss:
best_loss = epoch_loss
patience_counter = 0
best_path = weights_dir / "best.pt"
torch.save(pt_model.state_dict(), best_path)
logger.info(f"New best model saved: {best_path}")
else:
patience_counter += 1
# Early stopping
if patience_counter >= patience:
logger.info(f"Early stopping triggered after {epoch+1} epochs")
break
logger.info("Training completed successfully")
# Format results
return {
"success": True,
"final_epoch": epoch + 1,
"metrics": {
"final_loss": epoch_loss,
"best_loss": best_loss,
},
"best_model_path": str(weights_dir / "best.pt"),
"last_model_path": str(weights_dir / "last.pt"),
"save_dir": str(save_path),
}
except Exception as e:
logger.error(f"Error during Float32 training: {e}")
import traceback
logger.error(traceback.format_exc())
raise
def _get_pytorch_model(ul_model: YOLO) -> Tuple[torch.nn.Module, Optional[callable]]:
"""
Extract PyTorch model and loss function from Ultralytics YOLO wrapper.
Args:
ul_model: Ultralytics YOLO model wrapper
Returns:
Tuple of (pytorch_model, loss_function)
"""
# Try to get the underlying PyTorch model
candidates = []
# Direct model attribute
if hasattr(ul_model, "model"):
candidates.append(ul_model.model)
# Sometimes nested
if hasattr(ul_model, "model") and hasattr(ul_model.model, "model"):
candidates.append(ul_model.model.model)
# The wrapper itself
if isinstance(ul_model, torch.nn.Module):
candidates.append(ul_model)
# Find a valid model
pt_model = None
loss_fn = None
for candidate in candidates:
if candidate is None or not isinstance(candidate, torch.nn.Module):
continue
pt_model = candidate
# Try to find loss function
if hasattr(candidate, "loss") and callable(getattr(candidate, "loss")):
loss_fn = getattr(candidate, "loss")
elif hasattr(candidate, "compute_loss") and callable(
getattr(candidate, "compute_loss")
):
loss_fn = getattr(candidate, "compute_loss")
break
if pt_model is None:
raise RuntimeError("Could not extract PyTorch model from Ultralytics wrapper")
logger.info(f"Extracted PyTorch model: {type(pt_model)}")
logger.info(
f"Loss function: {type(loss_fn) if loss_fn else 'None (will attempt fallback)'}"
)
return pt_model, loss_fn
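# Sketch of the expected call pattern (as used by train_with_float32_loader above):
#
#     ul_model = YOLO("yolov8s-seg.pt")
#     pt_model, loss_fn = _get_pytorch_model(ul_model)
#     pt_model.train()  # plain torch.nn.Module from here on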
# Compatibility function (kept for backwards compatibility)
def train_float32(model: YOLO, data_yaml: str, **train_kwargs) -> Any:
"""
Train YOLO model with Float32YOLODataset (alternative API).
Args:
model: Initialized YOLO model instance
data_yaml: Path to dataset YAML
**train_kwargs: Training parameters
Returns:
Training results dict
"""
return train_with_float32_loader(
model_path=(
model.model_path if hasattr(model, "model_path") else "yolov8s-seg.pt"
),
data_yaml=data_yaml,
**train_kwargs,
)
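# Example call (a sketch; paths and hyperparameters are illustrative):
#
#     results = train_with_float32_loader(
#         model_path="yolov8s-seg.pt",
#         data_yaml="data/datasets/data.yaml",
#         epochs=100,
#         batch=16,
#         imgsz=1024,
#     )
#     print(results["best_model_path"])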

View File

@@ -1,157 +0,0 @@
"""Ultralytics runtime patches for 16-bit TIFF training.
Goals:
- Use `tifffile` to decode `.tif/.tiff` reliably (OpenCV can silently drop bit-depth depending on codec).
- Preserve 16-bit data through the dataloader as `uint16` tensors.
- Fix Ultralytics trainer normalization (default divides by 255) to scale `uint16` correctly.
- Avoid uint8-forcing augmentations by recommending/setting hyp values (handled by caller).
This module is intended to be imported/called **before** instantiating/using YOLO.
"""
from __future__ import annotations
from typing import Optional
from src.utils.logger import get_logger
logger = get_logger(__name__)
def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
"""Apply runtime monkey-patches to Ultralytics to better support 16-bit TIFFs.
This function is safe to call multiple times.
Args:
force: If True, re-apply patches even if already applied.
"""
# Import inside function to ensure patching occurs before YOLO model/dataset is created.
import os
import cv2
import numpy as np
import torch
from src.utils.image import Image
from ultralytics.utils import patches as ul_patches
already_patched = getattr(ul_patches.imread, "__name__", "") == "tifffile_imread"
if already_patched and not force:
return
_original_imread = ul_patches.imread
def tifffile_imread(filename: str, flags: int = cv2.IMREAD_COLOR, pseudo_rgb: bool = True) -> Optional[np.ndarray]:
"""Replacement for [`ultralytics.utils.patches.imread()`](venv/lib/python3.12/site-packages/ultralytics/utils/patches.py:20).
- For `.tif/.tiff`, loads via the project `Image` class and rescales to uint8 (note: this legacy path does not preserve uint16).
- For other formats, falls back to Ultralytics' original implementation.
- Always returns HWC (3 dims); grayscale comes back as (H, W, 1).
"""
# print("here")
# return _original_imread(filename, flags)
ext = os.path.splitext(filename)[1].lower()
if ext in (".tif", ".tiff"):
arr = Image(filename).get_qt_rgb()
if arr is None:
return None
# Normalize common shapes:
# - (C, H, W) -> (H, W, C) (heuristic)
# - (H, W) -> (H, W, 1)
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[0] < arr.shape[1]:
arr = np.transpose(arr, (1, 2, 0))
if arr.ndim == 2:
arr = arr[..., None]
arr = arr[:, :, :3]
arr = arr.astype(np.float32)
arr /= max(float(arr.max()), 1e-8)  # guard against all-zero images
arr *= 2**8 - 1
arr = arr.astype(np.uint8)
# Ensure a contiguous array for downstream OpenCV ops.
return np.ascontiguousarray(arr)
return _original_imread(filename, flags)
# Patch the canonical reference.
ul_patches.imread = tifffile_imread
# Patch common module-level imports (some Ultralytics modules do `from ... import imread`).
# Importing these modules is safe and helps ensure the patched function is used.
try:
import ultralytics.data.base as _ul_base
_ul_base.imread = tifffile_imread
except Exception:
pass
try:
import ultralytics.data.loaders as _ul_loaders
_ul_loaders.imread = tifffile_imread
except Exception:
pass
# Patch trainer normalization: default divides by 255 regardless of input dtype.
from ultralytics.models.yolo.detect import train as detect_train
_orig_preprocess_batch = detect_train.DetectionTrainer.preprocess_batch
def preprocess_batch_16bit(self, batch: dict) -> dict: # type: ignore[override]
# Start from upstream behavior to keep device placement + multiscale identical,
# but replace the 255 division with dtype-aware scaling.
# logger.info(f"Preprocessing batch with monkey-patched preprocess_batch")
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
img = batch.get("img")
if isinstance(img, torch.Tensor):
# Decide scaling denom based on dtype (avoid expensive reductions if possible).
if img.dtype == torch.uint8:
denom = 255.0
elif img.dtype == torch.uint16:
denom = 65535.0
elif img.dtype.is_floating_point:
# Assume already in 0-1 range if float.
denom = 1.0
else:
# Generic integer fallback.
try:
denom = float(torch.iinfo(img.dtype).max)
except Exception:
denom = 255.0
batch["img"] = img.float() / denom
# Multi-scale branch copied from upstream to avoid re-introducing `/255` scaling.
if getattr(self.args, "multi_scale", False):
import math
import random
import torch.nn as nn
imgs = batch["img"]
sz = (
random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
// self.stride
* self.stride
)
sf = sz / max(imgs.shape[2:])
if sf != 1:
ns = [math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]]
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
batch["img"] = imgs
return batch
detect_train.DetectionTrainer.preprocess_batch = preprocess_batch_16bit
# Tag function to make it easier to detect patch state.
setattr(detect_train.DetectionTrainer.preprocess_batch, "_ultralytics_16bit_patch", True)
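# Intended usage (sketch): apply the patches before constructing any YOLO
# object so the patched imread/preprocess_batch are the ones picked up.
#
#     from ultralytics import YOLO
#     apply_ultralytics_16bit_tiff_patches()
#     model = YOLO("yolov8s-seg.pt")
#     model.train(data="data.yaml", epochs=100)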

View File

@@ -17,9 +17,6 @@ import matplotlib.pyplot as plt
import argparse
from pathlib import Path
import random
from shapely.geometry import LineString
from src.utils.image import Image
def parse_label_line(line):
@@ -55,55 +52,36 @@ def yolo_bbox_to_xyxy(coords, img_w, img_h):
def poly_to_pts(coords, img_w, img_h):
# coords: [x1 y1 x2 y2 ...] either normalized or absolute
if coords_are_normalized(coords[4:]):
coords = [coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))]
if coords_are_normalized(coords):
coords = [
coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))
]
pts = np.array(coords, dtype=np.int32).reshape(-1, 2)
return pts
def random_color_for_class(cls):
random.seed(cls) # deterministic per class
return (
0,
0,
255,
) # tuple(int(x) for x in np.array([random.randint(0, 255) for _ in range(3)]))
return tuple(int(x) for x in np.array([random.randint(0, 255) for _ in range(3)]))
def draw_annotations(img, labels, alpha=0.4, draw_bbox_for_poly=True):
# img: BGR numpy array
overlay = img.copy()
h, w = img.shape[:2]
for line in labels:
if isinstance(line, str):
cls, coords = parse_label_line(line)
if isinstance(line, tuple):
cls, coords = line
for cls, coords in labels:
if not coords:
continue
# polygon case (>=6 coordinates)
if len(coords) >= 6:
pts = poly_to_pts(coords, w, h)
color = random_color_for_class(cls)
x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords[:4], w, h)
cv2.rectangle(img, (x1, y1), (x2, y2), color, 1)
pts = poly_to_pts(coords[4:], w, h)
# line = LineString(pts)
# # Buffer distance in pixels
# buffered = line.buffer(3, cap_style=2, join_style=2)
# coords = np.array(buffered.exterior.coords, dtype=np.int32)
# cv2.fillPoly(overlay, [coords], color=(255, 255, 255))
# fill on overlay
cv2.fillPoly(overlay, [pts], color)
# outline on base image
cv2.polylines(img, [pts], isClosed=True, color=color, thickness=1)
cv2.polylines(img, [pts], isClosed=True, color=color, thickness=2)
# put class text at first point
x, y = int(pts[0, 0]), int(pts[0, 1]) - 6
if 0:
cv2.putText(
img,
str(cls),
@@ -114,7 +92,9 @@ def draw_annotations(img, labels, alpha=0.4, draw_bbox_for_poly=True):
2,
cv2.LINE_AA,
)
if draw_bbox_for_poly:
x, y, w_box, h_box = cv2.boundingRect(pts)
cv2.rectangle(img, (x, y), (x + w_box, y + h_box), color, 1)
# YOLO bbox case (4 coords)
elif len(coords) == 4:
x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords, w, h)
@@ -153,21 +133,21 @@ def load_labels_file(label_path):
def main():
parser = argparse.ArgumentParser(description="Show YOLO segmentation / polygon annotations")
parser = argparse.ArgumentParser(
description="Show YOLO segmentation / polygon annotations"
)
parser.add_argument("image", type=str, help="Path to image file")
parser.add_argument("--labels", type=str, help="Path to YOLO label file (polygons)")
parser.add_argument("--alpha", type=float, default=0.4, help="Polygon fill alpha (0..1)")
parser.add_argument("--no-bbox", action="store_true", help="Don't draw bounding boxes for polygons")
parser.add_argument("labels", type=str, help="Path to YOLO label file (polygons)")
parser.add_argument(
"--alpha", type=float, default=0.4, help="Polygon fill alpha (0..1)"
)
parser.add_argument(
"--no-bbox", action="store_true", help="Don't draw bounding boxes for polygons"
)
args = parser.parse_args()
print(args)
img_path = Path(args.image)
if args.labels:
lbl_path = Path(args.labels)
else:
lbl_path = img_path.with_suffix(".txt")
lbl_path = Path(str(lbl_path).replace("images", "labels"))
if not img_path.exists():
print("Image not found:", img_path)
@@ -176,9 +156,7 @@ def main():
print("Label file not found:", lbl_path)
sys.exit(1)
# img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
img = (Image(img_path).get_qt_rgb() * 255).astype(np.uint8)
img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
if img is None:
print("Could not load image:", img_path)
sys.exit(1)
@@ -187,42 +165,15 @@ def main():
if not labels:
print("No labels parsed from", lbl_path)
# continue and just show image
out = draw_annotations(img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox))
out_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
if 0:
plt.imshow(out_rgb.transpose(1, 0, 2))
else:
plt.imshow(out_rgb)
for label in labels:
lclass, coords = label
# print(lclass, coords)
bbox = coords[:4]
# print("bbox", bbox)
bbox = np.array(bbox) * np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
yc, xc, h, w = bbox
# print("bbox", bbox)
# polyline = np.array(coords[4:]).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
polyline = np.array(coords).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
# print("pl", coords[4:])
# print("pl", polyline)
# Convert BGR -> RGB for matplotlib display
# out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
# out_rgb = Image()
plt.plot(polyline[:, 0], polyline[:, 1], "y", linewidth=2)
if 0:
plt.plot(
[yc - h / 2, yc - h / 2, yc + h / 2, yc + h / 2, yc - h / 2],
[xc - w / 2, xc + w / 2, xc + w / 2, xc - w / 2, xc - w / 2],
"r",
linewidth=2,
out = draw_annotations(
img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox)
)
# plt.axis("off")
# Convert BGR -> RGB for matplotlib display
out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
plt.imshow(out_rgb)
plt.axis("off")
plt.title(f"{img_path.name} ({lbl_path.name})")
plt.show()

View File

@@ -0,0 +1,109 @@
#!/usr/bin/env python3
"""
Test script for 16-bit TIFF loading and normalization.
"""
import numpy as np
import tifffile
from pathlib import Path
import tempfile
import sys
import os
# Add parent directory to path to import modules
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.utils.image import Image
def create_test_16bit_tiff(output_path: str) -> str:
"""Create a test 16-bit grayscale TIFF file.
Args:
output_path: Path where to save the test TIFF
Returns:
Path to the created TIFF file
"""
# Create a 16-bit grayscale test image (100x100)
# With values ranging from 0 to 65535 (full 16-bit range)
height, width = 100, 100
# Create a gradient pattern
test_data = np.zeros((height, width), dtype=np.uint16)
for i in range(height):
for j in range(width):
# Create a diagonal gradient
test_data[i, j] = int((i + j) / (height + width - 2) * 65535)
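# (Equivalent vectorized form, for reference:
#  test_data = (np.add.outer(np.arange(height), np.arange(width))
#               / (height + width - 2) * 65535).astype(np.uint16))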
# Save as TIFF
tifffile.imwrite(output_path, test_data)
print(f"Created test 16-bit TIFF: {output_path}")
print(f" Shape: {test_data.shape}")
print(f" Dtype: {test_data.dtype}")
print(f" Min value: {test_data.min()}")
print(f" Max value: {test_data.max()}")
return output_path
def test_image_loading():
"""Test loading 16-bit TIFF with the Image class."""
print("\n=== Testing Image Loading ===")
# Create temporary test file
with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as tmp:
test_path = tmp.name
try:
# Create test image
create_test_16bit_tiff(test_path)
# Load with Image class
print("\nLoading with Image class...")
img = Image(test_path)
print(f"Successfully loaded image:")
print(f" Width: {img.width}")
print(f" Height: {img.height}")
print(f" Channels: {img.channels}")
print(f" Dtype: {img.dtype}")
print(f" Format: {img.format}")
# Test normalization
print("\nTesting normalization to float32 [0-1]...")
normalized = img.to_normalized_float32()
print(f"Normalized image:")
print(f" Shape: {normalized.shape}")
print(f" Dtype: {normalized.dtype}")
print(f" Min value: {normalized.min():.6f}")
print(f" Max value: {normalized.max():.6f}")
print(f" Mean value: {normalized.mean():.6f}")
# Verify normalization
assert normalized.dtype == np.float32, "Dtype should be float32"
assert (
0.0 <= normalized.min() <= normalized.max() <= 1.0
), "Values should be in [0, 1]"
print("\n✓ All tests passed!")
return True
except Exception as e:
print(f"\n✗ Test failed with error: {e}")
import traceback
traceback.print_exc()
return False
finally:
# Cleanup
if os.path.exists(test_path):
os.remove(test_path)
print(f"\nCleaned up test file: {test_path}")
if __name__ == "__main__":
success = test_image_loading()
sys.exit(0 if success else 1)

View File

@@ -0,0 +1,211 @@
"""
Test script for Float32 on-the-fly loading for 16-bit TIFFs.
This test verifies that:
1. Float32YOLODataset can load 16-bit TIFF files
2. Images are converted to float32 [0-1] in memory
3. Grayscale is replicated to 3 channels (RGB)
4. No disk caching is used
5. Full 16-bit precision is preserved
"""
import tempfile
import numpy as np
import tifffile
from pathlib import Path
import yaml
import torch  # used by test_float32_dataset for dtype/unique checks
def create_test_dataset():
"""Create a minimal test dataset with 16-bit TIFF images."""
temp_dir = Path(tempfile.mkdtemp())
dataset_dir = temp_dir / "test_dataset"
# Create directory structure
train_images = dataset_dir / "train" / "images"
train_labels = dataset_dir / "train" / "labels"
train_images.mkdir(parents=True, exist_ok=True)
train_labels.mkdir(parents=True, exist_ok=True)
# Create a 16-bit TIFF test image
img_16bit = np.random.randint(0, 65536, (100, 100), dtype=np.uint16)
img_path = train_images / "test_image.tif"
tifffile.imwrite(str(img_path), img_16bit)
# Create a dummy label file
label_path = train_labels / "test_image.txt"
with open(label_path, "w") as f:
f.write("0 0.5 0.5 0.2 0.2\n") # class_id x_center y_center width height
# Create data.yaml
data_yaml = {
"path": str(dataset_dir),
"train": "train/images",
"val": "train/images", # Use same for val in test
"names": {0: "object"},
"nc": 1,
}
yaml_path = dataset_dir / "data.yaml"
with open(yaml_path, "w") as f:
yaml.safe_dump(data_yaml, f)
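# The resulting data.yaml looks roughly like this (key order may differ):
#
#     names:
#       0: object
#     nc: 1
#     path: /tmp/.../test_dataset
#     train: train/images
#     val: train/images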
print(f"✓ Created test dataset at: {dataset_dir}")
print(f" - Image: {img_path} (shape={img_16bit.shape}, dtype={img_16bit.dtype})")
print(f" - Min value: {img_16bit.min()}, Max value: {img_16bit.max()}")
print(f" - data.yaml: {yaml_path}")
return dataset_dir, img_path, img_16bit
def test_float32_dataset():
"""Test the Float32YOLODataset class directly."""
print("\n=== Testing Float32YOLODataset ===\n")
try:
from src.utils.train_ultralytics_float import Float32YOLODataset
print("✓ Successfully imported Float32YOLODataset")
except ImportError as e:
print(f"✗ Failed to import Float32YOLODataset: {e}")
return False
# Create test dataset
dataset_dir, img_path, original_img = create_test_dataset()
try:
# Initialize the dataset
print("\nInitializing Float32YOLODataset...")
dataset = Float32YOLODataset(
images_dir=str(dataset_dir / "train" / "images"),
labels_dir=str(dataset_dir / "train" / "labels"),
img_size=640,
)
print(f"✓ Float32YOLODataset initialized with {len(dataset)} images")
# Get an item
if len(dataset) > 0:
print("\nGetting first item...")
img_tensor, labels, filename = dataset[0]
print(f"✓ Item retrieved successfully")
print(f" - Image tensor shape: {img_tensor.shape}")
print(f" - Image tensor dtype: {img_tensor.dtype}")
print(f" - Value range: [{img_tensor.min():.6f}, {img_tensor.max():.6f}]")
print(f" - Filename: {filename}")
print(f" - Labels: {len(labels)} annotations")
if labels:
print(
f" - First label shape: {labels[0].shape if len(labels) > 0 else 'N/A'}"
)
# Verify it's float32
if img_tensor.dtype == torch.float32:
print("✓ Correct dtype: float32")
else:
print(f"✗ Wrong dtype: {img_tensor.dtype} (expected float32)")
return False
# Verify it's 3-channel in correct format (C, H, W)
if len(img_tensor.shape) == 3 and img_tensor.shape[0] == 3:
print(
f"✓ Correct format: (C, H, W) = {img_tensor.shape} with 3 channels"
)
else:
print(f"✗ Wrong shape: {img_tensor.shape} (expected (3, H, W))")
return False
# Verify it's in [0, 1] range
if 0.0 <= img_tensor.min() and img_tensor.max() <= 1.0:
print("✓ Values in correct range: [0, 1]")
else:
print(
f"✗ Values out of range: [{img_tensor.min()}, {img_tensor.max()}]"
)
return False
# Verify precision (should have many unique values)
unique_values = len(torch.unique(img_tensor))
print(f" - Unique values: {unique_values}")
if unique_values > 256:
print(f"✓ High precision maintained ({unique_values} > 256 levels)")
else:
print(f"⚠ Low precision: only {unique_values} unique values")
print("\n✓ All Float32YOLODataset tests passed!")
return True
else:
print("✗ No items in dataset")
return False
except Exception as e:
print(f"✗ Error during testing: {e}")
import traceback
traceback.print_exc()
return False
def test_integration():
"""Test integration with train_with_float32_loader."""
print("\n=== Testing Integration with train_with_float32_loader ===\n")
# Create test dataset
dataset_dir, img_path, original_img = create_test_dataset()
data_yaml = dataset_dir / "data.yaml"
print(f"\nTest dataset ready at: {data_yaml}")
print("\nTo test full training, run:")
print(f" from src.utils.train_ultralytics_float import train_with_float32_loader")
print(f" results = train_with_float32_loader(")
print(f" model_path='yolov8n-seg.pt',")
print(f" data_yaml='{data_yaml}',")
print(f" epochs=1,")
print(f" batch=1,")
print(f" imgsz=640")
print(f" )")
print("\nThis will use custom training loop with Float32YOLODataset")
return True
def main():
"""Run all tests."""
print("=" * 70)
print("Float32 Training Loader Test Suite")
print("=" * 70)
results = []
# Test 1: Float32YOLODataset
results.append(("Float32YOLODataset", test_float32_dataset()))
# Test 2: Integration check
results.append(("Integration Check", test_integration()))
# Summary
print("\n" + "=" * 70)
print("Test Summary")
print("=" * 70)
for test_name, passed in results:
status = "✓ PASSED" if passed else "✗ FAILED"
print(f"{status}: {test_name}")
all_passed = all(passed for _, passed in results)
print("=" * 70)
if all_passed:
print("✓ All tests passed!")
else:
print("✗ Some tests failed")
print("=" * 70)
return all_passed
if __name__ == "__main__":
import sys
success = main()
sys.exit(0 if success else 1)

File diff suppressed because it is too large

View File

@@ -0,0 +1,142 @@
#!/usr/bin/env python3
"""
Test script for training dataset preparation with 16-bit TIFFs.
"""
import numpy as np
import tifffile
from pathlib import Path
import tempfile
import sys
import os
import shutil
# Add parent directory to path to import modules
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.utils.image import Image
def test_float32_3ch_conversion():
"""Test conversion of 16-bit TIFF to 16-bit RGB PNG."""
print("\n=== Testing 16-bit RGB PNG Conversion ===")
# Create temporary directory structure
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
src_dir = tmpdir / "original"
dst_dir = tmpdir / "converted"
src_dir.mkdir()
dst_dir.mkdir()
# Create test 16-bit TIFF
test_data = np.zeros((100, 100), dtype=np.uint16)
for i in range(100):
for j in range(100):
test_data[i, j] = int((i + j) / 198 * 65535)
test_file = src_dir / "test_16bit.tif"
tifffile.imwrite(test_file, test_data)
print(f"Created test 16-bit TIFF: {test_file}")
print(f" Shape: {test_data.shape}")
print(f" Dtype: {test_data.dtype}")
print(f" Range: [{test_data.min()}, {test_data.max()}]")
# Simulate the conversion process (matching training_tab.py)
print("\nConverting to 16-bit RGB PNG using PIL merge...")
img_obj = Image(test_file)
from PIL import Image as PILImage
# Get uint16 data
uint16_data = img_obj.data
# Use PIL's merge method with 'I;16' channels (proper way for 16-bit RGB)
if len(uint16_data.shape) == 2:
# Grayscale - replicate to RGB
r_img = PILImage.fromarray(uint16_data, mode="I;16")
g_img = PILImage.fromarray(uint16_data, mode="I;16")
b_img = PILImage.fromarray(uint16_data, mode="I;16")
else:
r_img = PILImage.fromarray(uint16_data[:, :, 0], mode="I;16")
g_img = PILImage.fromarray(
(
uint16_data[:, :, 1]
if uint16_data.shape[2] > 1
else uint16_data[:, :, 0]
),
mode="I;16",
)
b_img = PILImage.fromarray(
(
uint16_data[:, :, 2]
if uint16_data.shape[2] > 2
else uint16_data[:, :, 0]
),
mode="I;16",
)
# Merge channels into RGB
rgb_img = PILImage.merge("RGB", (r_img, g_img, b_img))
# Save as PNG
output_file = dst_dir / "test_16bit_rgb.png"
rgb_img.save(output_file)
print(f"Saved 16-bit RGB PNG: {output_file}")
print(f" PIL mode after merge: {rgb_img.mode}")
# Verify the output - Load with OpenCV (as YOLO does)
import cv2
loaded = cv2.imread(str(output_file), cv2.IMREAD_UNCHANGED)
print(f"\nVerifying output (loaded with OpenCV):")
print(f" Shape: {loaded.shape}")
print(f" Dtype: {loaded.dtype}")
print(f" Channels: {loaded.shape[2] if len(loaded.shape) == 3 else 1}")
print(f" Range: [{loaded.min()}, {loaded.max()}]")
print(f" Unique values: {len(np.unique(loaded[:,:,0]))}")
# Assertions
assert loaded.dtype == np.uint16, f"Expected uint16, got {loaded.dtype}"
assert loaded.shape[2] == 3, f"Expected 3 channels, got {loaded.shape[2]}"
assert (
loaded.min() >= 0 and loaded.max() <= 65535
), f"Expected [0,65535] range, got [{loaded.min()}, {loaded.max()}]"
# Verify all channels are identical (replicated grayscale)
assert np.array_equal(
loaded[:, :, 0], loaded[:, :, 1]
), "Channel 0 and 1 should be identical"
assert np.array_equal(
loaded[:, :, 0], loaded[:, :, 2]
), "Channel 0 and 2 should be identical"
# Verify no data loss
unique_vals = len(np.unique(loaded[:, :, 0]))
print(f"\n Precision check:")
print(f" Unique values in channel: {unique_vals}")
print(f" Source unique values: {len(np.unique(test_data))}")
assert unique_vals == len(
np.unique(test_data)
), f"Expected {len(np.unique(test_data))} unique values, got {unique_vals}"
print("\n✓ All conversion tests passed!")
print(" - uint16 dtype preserved")
print(" - 3 channels created")
print(" - Range [0-65535] maintained")
print(" - No precision loss from conversion")
print(" - Channels properly replicated")
return True
if __name__ == "__main__":
try:
success = test_float32_3ch_conversion()
sys.exit(0 if success else 1)
except Exception as e:
print(f"\n✗ Test failed with error: {e}")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -0,0 +1,150 @@
#!/usr/bin/env python3
"""
Test script for YOLO preprocessing of 16-bit TIFF images with float32 passthrough.
Verifies that no uint8 conversion occurs and data is preserved.
"""
import numpy as np
import tifffile
from pathlib import Path
import tempfile
import sys
import os
# Add parent directory to path to import modules
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.model.yolo_wrapper import YOLOWrapper
def create_test_16bit_tiff(output_path: str) -> str:
"""Create a test 16-bit grayscale TIFF file.
Args:
output_path: Path where to save the test TIFF
Returns:
Path to the created TIFF file
"""
# Create a 16-bit grayscale test image (200x200)
# With specific values to test precision preservation
height, width = 200, 200
# Create a gradient pattern with the full 16-bit range
test_data = np.zeros((height, width), dtype=np.uint16)
for i in range(height):
for j in range(width):
# Create a diagonal gradient using full 16-bit range
test_data[i, j] = int((i + j) / (height + width - 2) * 65535)
# Save as TIFF
tifffile.imwrite(output_path, test_data)
print(f"Created test 16-bit TIFF: {output_path}")
print(f" Shape: {test_data.shape}")
print(f" Dtype: {test_data.dtype}")
print(f" Min value: {test_data.min()}")
print(f" Max value: {test_data.max()}")
print(
f" Sample values: {test_data[50, 50]}, {test_data[100, 100]}, {test_data[150, 150]}"
)
return output_path
def test_float32_passthrough():
"""Test that 16-bit TIFF preprocessing passes float32 directly without uint8 conversion."""
print("\n=== Testing Float32 Passthrough (NO uint8) ===")
# Create temporary test file
with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as tmp:
test_path = tmp.name
try:
# Create test image
create_test_16bit_tiff(test_path)
# Create YOLOWrapper instance
print("\nTesting YOLOWrapper._prepare_source() for float32 passthrough...")
wrapper = YOLOWrapper()
# Call _prepare_source to preprocess the image
prepared_source, cleanup_path = wrapper._prepare_source(test_path)
print(f"\nPreprocessing result:")
print(f" Original path: {test_path}")
print(f" Prepared source type: {type(prepared_source)}")
# Verify it returns a numpy array (not a file path)
if isinstance(prepared_source, np.ndarray):
print(
f"\n✓ SUCCESS: Prepared source is a numpy array (float32 passthrough)"
)
print(f" Shape: {prepared_source.shape}")
print(f" Dtype: {prepared_source.dtype}")
print(f" Min value: {prepared_source.min():.6f}")
print(f" Max value: {prepared_source.max():.6f}")
print(f" Mean value: {prepared_source.mean():.6f}")
# Verify it's float32 in [0, 1] range
assert (
prepared_source.dtype == np.float32
), f"Expected float32, got {prepared_source.dtype}"
assert (
0.0 <= prepared_source.min() <= prepared_source.max() <= 1.0
), f"Expected values in [0, 1], got [{prepared_source.min()}, {prepared_source.max()}]"
# Verify it has 3 channels (RGB)
assert (
prepared_source.shape[2] == 3
), f"Expected 3 channels (RGB), got {prepared_source.shape[2]}"
# Verify no quantization to 256 levels (would happen with uint8 conversion)
unique_values = len(np.unique(prepared_source))
print(f" Unique values: {unique_values}")
# With float32, we should have much more than 256 unique values
if unique_values > 256:
print(f"\n✓ SUCCESS: Data has {unique_values} unique values (> 256)")
print(f" This confirms NO uint8 quantization occurred!")
else:
print(f"\n✗ WARNING: Data has only {unique_values} unique values")
print(f" This might indicate uint8 quantization happened")
# Sample some values to show precision
print(f"\n Sample normalized values:")
print(f" [50, 50]: {prepared_source[50, 50, 0]:.8f}")
print(f" [100, 100]: {prepared_source[100, 100, 0]:.8f}")
print(f" [150, 150]: {prepared_source[150, 150, 0]:.8f}")
# No cleanup needed since we returned array directly
assert (
cleanup_path is None
), "Cleanup path should be None for float32 pass through"
print("\n✓ All float32 passthrough tests passed!")
return True
else:
print(f"\n✗ FAILED: Prepared source is a file path: {prepared_source}")
print(f" This means data was saved to disk, not passed as float32 array")
if cleanup_path and os.path.exists(cleanup_path):
os.remove(cleanup_path)
return False
except Exception as e:
print(f"\n✗ Test failed with error: {e}")
import traceback
traceback.print_exc()
return False
finally:
# Cleanup
if os.path.exists(test_path):
os.remove(test_path)
print(f"\nCleaned up test file: {test_path}")
if __name__ == "__main__":
success = test_float32_passthrough()
sys.exit(0 if success else 1)

View File

@@ -0,0 +1,126 @@
#!/usr/bin/env python3
"""
Test script for YOLO preprocessing of 16-bit TIFF images.
"""
import numpy as np
import tifffile
from pathlib import Path
import tempfile
import sys
import os
# Add parent directory to path to import modules
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.model.yolo_wrapper import YOLOWrapper
from src.utils.image import Image
from PIL import Image as PILImage
def create_test_16bit_tiff(output_path: str) -> str:
"""Create a test 16-bit grayscale TIFF file.
Args:
output_path: Path where to save the test TIFF
Returns:
Path to the created TIFF file
"""
# Create a 16-bit grayscale test image (200x200)
# With values ranging from 0 to 65535 (full 16-bit range)
height, width = 200, 200
# Create a gradient pattern
test_data = np.zeros((height, width), dtype=np.uint16)
for i in range(height):
for j in range(width):
# Create a diagonal gradient
test_data[i, j] = int((i + j) / (height + width - 2) * 65535)
# Save as TIFF
tifffile.imwrite(output_path, test_data)
print(f"Created test 16-bit TIFF: {output_path}")
print(f" Shape: {test_data.shape}")
print(f" Dtype: {test_data.dtype}")
print(f" Min value: {test_data.min()}")
print(f" Max value: {test_data.max()}")
return output_path
def test_yolo_preprocessing():
"""Test YOLO preprocessing of 16-bit TIFF images."""
print("\n=== Testing YOLO Preprocessing of 16-bit TIFF ===")
# Create temporary test file
with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as tmp:
test_path = tmp.name
try:
# Create test image
create_test_16bit_tiff(test_path)
# Create YOLOWrapper instance (no actual model loading needed for this test)
print("\nTesting YOLOWrapper._prepare_source()...")
wrapper = YOLOWrapper()
# Call _prepare_source to preprocess the image
prepared_path, cleanup_path = wrapper._prepare_source(test_path)
print(f"\nPreprocessing complete:")
print(f" Original path: {test_path}")
print(f" Prepared path: {prepared_path}")
print(f" Cleanup path: {cleanup_path}")
# Verify the prepared image exists
assert os.path.exists(prepared_path), "Prepared image should exist"
# Load the prepared image and verify it's uint8 RGB
prepared_img = PILImage.open(prepared_path)
print(f"\nPrepared image properties:")
print(f" Mode: {prepared_img.mode}")
print(f" Size: {prepared_img.size}")
print(f" Format: {prepared_img.format}")
# Convert to numpy to check values
img_array = np.array(prepared_img)
print(f" Shape: {img_array.shape}")
print(f" Dtype: {img_array.dtype}")
print(f" Min value: {img_array.min()}")
print(f" Max value: {img_array.max()}")
print(f" Mean value: {img_array.mean():.2f}")
# Verify it's RGB uint8
assert prepared_img.mode == "RGB", "Prepared image should be RGB"
assert img_array.dtype == np.uint8, "Prepared image should be uint8"
assert img_array.shape[2] == 3, "Prepared image should have 3 channels"
assert (
0 <= img_array.min() <= img_array.max() <= 255
), "Values should be in [0, 255]"
# Cleanup prepared file if needed
if cleanup_path and os.path.exists(cleanup_path):
os.remove(cleanup_path)
print(f"\nCleaned up prepared image: {cleanup_path}")
print("\n✓ All YOLO preprocessing tests passed!")
return True
except Exception as e:
print(f"\n✗ Test failed with error: {e}")
import traceback
traceback.print_exc()
return False
finally:
# Cleanup
if os.path.exists(test_path):
os.remove(test_path)
print(f"Cleaned up test file: {test_path}")
if __name__ == "__main__":
success = test_yolo_preprocessing()
sys.exit(0 if success else 1)