Adding standalone training script and update

This commit is contained in:
2025-12-13 09:28:24 +02:00
parent 908e9a5b82
commit aec0fbf83c
8 changed files with 1434 additions and 290 deletions

View File

@@ -3,14 +3,11 @@ Training tab for the microscopy object detection application.
Handles model training with YOLO.
"""
import hashlib
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import tifffile
import yaml
from PySide6.QtCore import Qt, QThread, Signal
from PySide6.QtWidgets import (
@@ -949,9 +946,6 @@ class TrainingTab(QWidget):
for msg in split_messages:
self._append_training_log(msg)
if dataset_yaml:
self._clear_rgb_cache_for_dataset(dataset_yaml)
def _export_labels_for_split(
self,
split_name: str,
@@ -1166,49 +1160,6 @@ class TrainingTab(QWidget):
return 1.0
return value
def _prepare_dataset_for_training(
self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
) -> Path:
"""Prepare dataset for training.
For 16-bit TIFF files: creates 3-channel float32 TIFF versions for training.
This is necessary because YOLO's training dataloader loads images directly
from disk and expects 3 channels.
"""
dataset_info = dataset_info or (
self.selected_dataset
if self.selected_dataset
and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
else self._parse_dataset_yaml(dataset_yaml)
)
train_split = dataset_info.get("splits", {}).get("train") or {}
images_path_str = train_split.get("path")
if not images_path_str:
return dataset_yaml
images_path = Path(images_path_str)
if not images_path.exists():
return dataset_yaml
# Check if dataset has 16-bit TIFF files that need 3-channel conversion
if not self._dataset_has_16bit_tiff(images_path):
return dataset_yaml
cache_root = self._get_float32_cache_root(dataset_yaml)
float32_yaml = cache_root / "data.yaml"
if float32_yaml.exists():
self._append_training_log(
f"Detected 16-bit TIFF dataset; reusing float32 3-channel cache at {cache_root}"
)
return float32_yaml
self._append_training_log(
f"Detected 16-bit TIFF dataset; creating float32 3-channel cache at {cache_root}"
)
self._build_float32_dataset(cache_root, dataset_info)
return float32_yaml
def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
two_stage = params.get("two_stage") or {}
base_stage = {
@@ -1293,140 +1244,6 @@ class TrainingTab(QWidget):
f"{stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
)
def _get_float32_cache_root(self, dataset_yaml: Path) -> Path:
"""Get cache directory for float32 3-channel converted datasets."""
cache_base = Path("data/datasets/_float32_cache")
cache_base.mkdir(parents=True, exist_ok=True)
key = hashlib.md5(str(dataset_yaml.parent.resolve()).encode()).hexdigest()[:8]
return cache_base / f"{dataset_yaml.parent.name}_{key}"
def _clear_rgb_cache_for_dataset(self, dataset_yaml: Path):
"""Clear float32 cache for dataset."""
cache_root = self._get_float32_cache_root(dataset_yaml)
if cache_root.exists():
try:
shutil.rmtree(cache_root)
logger.debug(f"Removed float32 cache at {cache_root}")
except OSError as exc:
logger.warning(f"Failed to remove float32 cache {cache_root}: {exc}")
def _dataset_has_16bit_tiff(self, images_dir: Path) -> bool:
"""Check if dataset contains 16-bit TIFF files."""
sample_image = self._find_first_image(images_dir)
if not sample_image:
return False
try:
if sample_image.suffix.lower() not in [".tif", ".tiff"]:
return False
img = Image(sample_image)
return img.dtype == np.uint16
except Exception as exc:
logger.warning(f"Failed to inspect image {sample_image}: {exc}")
return False
def _find_first_image(self, directory: Path) -> Optional[Path]:
"""Find first image in directory."""
if not directory.exists():
return None
for path in directory.rglob("*"):
if path.is_file() and path.suffix.lower() in self.allowed_extensions:
return path
return None
def _build_float32_dataset(self, cache_root: Path, dataset_info: Dict[str, Any]):
"""Build float32 3-channel version of 16-bit TIFF dataset."""
if cache_root.exists():
shutil.rmtree(cache_root)
cache_root.mkdir(parents=True, exist_ok=True)
splits = dataset_info.get("splits", {})
for split_name in ("train", "val", "test"):
split_entry = splits.get(split_name)
if not split_entry:
continue
images_src = Path(split_entry.get("path", ""))
if not images_src.exists():
continue
images_dst = cache_root / split_name / "images"
self._convert_16bit_to_float32_3ch(images_src, images_dst)
labels_src = self._infer_labels_dir(images_src)
if labels_src.exists():
labels_dst = cache_root / split_name / "labels"
self._copy_labels(labels_src, labels_dst)
class_names = dataset_info.get("class_names") or []
names_map = {idx: name for idx, name in enumerate(class_names)}
num_classes = dataset_info.get("num_classes") or len(class_names)
yaml_payload: Dict[str, Any] = {
"path": cache_root.as_posix(),
"names": names_map,
"nc": num_classes,
}
for split_name in ("train", "val", "test"):
images_dir = cache_root / split_name / "images"
if images_dir.exists():
yaml_payload[split_name] = f"{split_name}/images"
with open(cache_root / "data.yaml", "w", encoding="utf-8") as handle:
yaml.safe_dump(yaml_payload, handle, sort_keys=False)
def _convert_16bit_to_float32_3ch(self, src_dir: Path, dst_dir: Path):
"""Convert 16-bit TIFF images to float32 [0-1] 3-channel TIFFs.
This preserves the full dynamic range (no uint8 conversion) while
creating the 3-channel format that YOLO training expects.
"""
for src in src_dir.rglob("*"):
if not src.is_file() or src.suffix.lower() not in self.allowed_extensions:
continue
relative = src.relative_to(src_dir)
dst = dst_dir / relative.with_suffix(".tif")
dst.parent.mkdir(parents=True, exist_ok=True)
try:
img_obj = Image(src)
# Check if it's a 16-bit TIFF
is_16bit_tiff = (
src.suffix.lower() in [".tif", ".tiff"]
and img_obj.dtype == np.uint16
)
if is_16bit_tiff:
# Convert to float32 [0-1]
float_data = img_obj.to_normalized_float32()
# Replicate to 3 channels
if len(float_data.shape) == 2:
# H,W → H,W,3
float_3ch = np.stack([float_data] * 3, axis=-1)
elif len(float_data.shape) == 3 and float_data.shape[2] == 1:
# H,W,1 → H,W,3
float_3ch = np.repeat(float_data, 3, axis=2)
else:
# Already multi-channel
float_3ch = float_data
# Save as float32 TIFF (preserves full precision)
tifffile.imwrite(dst, float_3ch.astype(np.float32))
logger.debug(f"Converted {src} to float32 3-channel TIFF at {dst}")
else:
# For non-16-bit images, just copy
shutil.copy2(src, dst)
except Exception as exc:
logger.warning(f"Failed to convert {src}: {exc}")
def _copy_labels(self, labels_src: Path, labels_dst: Path):
label_files = list(labels_src.rglob("*.txt"))
for label_file in label_files:
relative = label_file.relative_to(labels_src)
dst = labels_dst / relative
dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(label_file, dst)
def _infer_labels_dir(self, images_dir: Path) -> Path:
return images_dir.parent / "labels"
@@ -1514,11 +1331,9 @@ class TrainingTab(QWidget):
self.training_log.clear()
self._export_labels_from_database(dataset_info)
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
if dataset_to_use != dataset_path:
self._append_training_log(
f"Using float32 3-channel dataset at {dataset_to_use.parent}"
)
self._append_training_log(
"Using Float32 on-the-fly loader for 16-bit TIFF support (no disk caching)"
)
params = self._collect_training_params()
stage_plan = self._compose_stage_plan(params)
@@ -1544,7 +1359,7 @@ class TrainingTab(QWidget):
self._set_training_state(True)
self.training_worker = TrainingWorker(
data_yaml=dataset_to_use.as_posix(),
data_yaml=dataset_path.as_posix(),
base_model=params["base_model"],
epochs=params["epochs"],
batch=params["batch"],