Adding standalone training script and update
This commit is contained in:
@@ -3,14 +3,11 @@ Training tab for the microscopy object detection application.
|
||||
Handles model training with YOLO.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import tifffile
|
||||
import yaml
|
||||
from PySide6.QtCore import Qt, QThread, Signal
|
||||
from PySide6.QtWidgets import (
|
||||
@@ -949,9 +946,6 @@ class TrainingTab(QWidget):
|
||||
for msg in split_messages:
|
||||
self._append_training_log(msg)
|
||||
|
||||
if dataset_yaml:
|
||||
self._clear_rgb_cache_for_dataset(dataset_yaml)
|
||||
|
||||
def _export_labels_for_split(
|
||||
self,
|
||||
split_name: str,
|
||||
@@ -1166,49 +1160,6 @@ class TrainingTab(QWidget):
|
||||
return 1.0
|
||||
return value
|
||||
|
||||
def _prepare_dataset_for_training(
    self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
) -> Path:
    """Return the dataset YAML that training should be pointed at.

    If ``_dataset_has_16bit_tiff`` reports that the train split holds
    16-bit TIFF data, a float32 3-channel copy of the dataset is used:
    it is built under the float32 cache root, or reused when its
    ``data.yaml`` is already present, and the path of that cached YAML is
    returned. In every other case ``dataset_yaml`` is returned unchanged.

    Args:
        dataset_yaml: Path to the dataset's original YAML file.
        dataset_info: Parsed dataset description. When falsy, the currently
            selected dataset is used if its ``yaml_path`` equals
            ``str(dataset_yaml)``; otherwise the YAML file is parsed.

    Returns:
        ``dataset_yaml`` itself, or the cached ``data.yaml`` path.
    """
    info = dataset_info
    if not info:
        selected = self.selected_dataset
        if selected and selected.get("yaml_path") == str(dataset_yaml):
            info = selected
        else:
            info = self._parse_dataset_yaml(dataset_yaml)

    # Without a usable train image folder there is nothing to inspect.
    train_entry = info.get("splits", {}).get("train") or {}
    train_images = train_entry.get("path")
    if not train_images:
        return dataset_yaml

    train_dir = Path(train_images)
    if not train_dir.exists() or not self._dataset_has_16bit_tiff(train_dir):
        return dataset_yaml

    cache_root = self._get_float32_cache_root(dataset_yaml)
    cached_yaml = cache_root / "data.yaml"
    if not cached_yaml.exists():
        self._append_training_log(
            f"Detected 16-bit TIFF dataset; creating float32 3-channel cache at {cache_root}"
        )
        self._build_float32_dataset(cache_root, info)
    else:
        self._append_training_log(
            f"Detected 16-bit TIFF dataset; reusing float32 3-channel cache at {cache_root}"
        )
    return cached_yaml
|
||||
|
||||
def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
two_stage = params.get("two_stage") or {}
|
||||
base_stage = {
|
||||
@@ -1293,140 +1244,6 @@ class TrainingTab(QWidget):
|
||||
f" • {stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
|
||||
)
|
||||
|
||||
def _get_float32_cache_root(self, dataset_yaml: Path) -> Path:
    """Return the cache directory for a dataset's float32 3-channel copy.

    The directory name is the dataset folder's name followed by the first
    eight hex digits of the MD5 of that folder's resolved path. The shared
    cache base directory is created as a side effect; the returned
    per-dataset directory itself is not created here.
    """
    dataset_dir = dataset_yaml.parent

    base_dir = Path("data/datasets/_float32_cache")
    base_dir.mkdir(parents=True, exist_ok=True)

    digest = hashlib.md5(str(dataset_dir.resolve()).encode()).hexdigest()
    return base_dir / f"{dataset_dir.name}_{digest[:8]}"
|
||||
|
||||
def _clear_rgb_cache_for_dataset(self, dataset_yaml: Path):
    """Delete the float32 cache directory belonging to ``dataset_yaml``.

    Does nothing when the cache directory does not exist. An ``OSError``
    during removal is logged as a warning instead of being raised.
    """
    cache_dir = self._get_float32_cache_root(dataset_yaml)
    if not cache_dir.exists():
        return

    try:
        shutil.rmtree(cache_dir)
        logger.debug(f"Removed float32 cache at {cache_dir}")
    except OSError as exc:
        logger.warning(f"Failed to remove float32 cache {cache_dir}: {exc}")
|
||||
|
||||
def _dataset_has_16bit_tiff(self, images_dir: Path) -> bool:
    """Report whether the dataset looks like a 16-bit TIFF dataset.

    Only the first image returned by ``_find_first_image`` is inspected.
    The result is True when that file has a ``.tif``/``.tiff`` suffix and
    its loaded dtype is ``np.uint16``. Any error raised while inspecting
    the file is logged as a warning and treated as "not 16-bit".
    """
    sample = self._find_first_image(images_dir)
    if not sample:
        return False

    try:
        has_tiff_suffix = sample.suffix.lower() in (".tif", ".tiff")
        # Short-circuit so non-TIFF files are never loaded.
        return has_tiff_suffix and Image(sample).dtype == np.uint16
    except Exception as exc:
        logger.warning(f"Failed to inspect image {sample}: {exc}")
        return False
|
||||
|
||||
def _find_first_image(self, directory: Path) -> Optional[Path]:
    """Return the first file under ``directory`` with an allowed extension.

    The search is recursive and compares each file's lower-cased suffix
    against ``self.allowed_extensions``. Returns None when the directory
    does not exist or contains no matching file.
    """
    if not directory.exists():
        return None

    matches = (
        entry
        for entry in directory.rglob("*")
        if entry.is_file() and entry.suffix.lower() in self.allowed_extensions
    )
    return next(matches, None)
|
||||
|
||||
def _build_float32_dataset(self, cache_root: Path, dataset_info: Dict[str, Any]):
    """Build a float32 3-channel copy of a 16-bit TIFF dataset.

    Any existing ``cache_root`` is deleted first. For each of the train,
    val and test splits that has a non-empty ``path`` pointing at an
    existing directory, images are converted into
    ``<cache_root>/<split>/images`` and, when the directory reported by
    ``_infer_labels_dir`` exists, its label files are copied into
    ``<cache_root>/<split>/labels``. Finally ``<cache_root>/data.yaml`` is
    written, listing only the splits for which an images directory was
    actually produced.

    Args:
        cache_root: Destination directory for the converted dataset.
        dataset_info: Parsed dataset description. Reads ``splits`` (split
            name -> entry with a ``path``), ``class_names`` and
            ``num_classes``.
    """
    if cache_root.exists():
        shutil.rmtree(cache_root)
    cache_root.mkdir(parents=True, exist_ok=True)

    split_names = ("train", "val", "test")
    splits = dataset_info.get("splits") or {}
    for split_name in split_names:
        split_entry = splits.get(split_name)
        if not split_entry:
            continue

        # A missing/empty path must be skipped explicitly: Path("") is the
        # current directory, which always exists, so it would otherwise be
        # scanned and converted as if it were this split's image folder.
        images_path_str = split_entry.get("path")
        if not images_path_str:
            continue

        images_src = Path(images_path_str)
        if not images_src.exists():
            continue

        self._convert_16bit_to_float32_3ch(
            images_src, cache_root / split_name / "images"
        )

        labels_src = self._infer_labels_dir(images_src)
        if labels_src.exists():
            self._copy_labels(labels_src, cache_root / split_name / "labels")

    class_names = dataset_info.get("class_names") or []
    yaml_payload: Dict[str, Any] = {
        "path": cache_root.as_posix(),
        "names": dict(enumerate(class_names)),
        "nc": dataset_info.get("num_classes") or len(class_names),
    }

    # Only advertise splits whose images directory was really created;
    # key order is kept (sort_keys=False) so the YAML reads path/names/nc/splits.
    for split_name in split_names:
        if (cache_root / split_name / "images").exists():
            yaml_payload[split_name] = f"{split_name}/images"

    with open(cache_root / "data.yaml", "w", encoding="utf-8") as handle:
        yaml.safe_dump(yaml_payload, handle, sort_keys=False)
|
||||
|
||||
def _convert_16bit_to_float32_3ch(self, src_dir: Path, dst_dir: Path):
    """Mirror ``src_dir`` into ``dst_dir``, converting 16-bit TIFFs on the way.

    Every file under ``src_dir`` (recursively) whose lower-cased suffix is
    in ``self.allowed_extensions`` is processed:

    * ``.tif``/``.tiff`` files whose loaded dtype is ``np.uint16`` are
      converted with ``Image.to_normalized_float32()``, expanded to three
      channels when single-channel, and written as a float32 ``.tif``.
    * All other images are copied unchanged under their original file
      name, so the extension keeps matching the actual file format.

    The relative sub-directory layout is preserved. A failure on one file
    is logged as a warning and does not stop the remaining files.

    Args:
        src_dir: Directory holding the original images.
        dst_dir: Directory that receives the converted/copied images.
    """
    tiff_suffixes = (".tif", ".tiff")

    for src in src_dir.rglob("*"):
        if not src.is_file() or src.suffix.lower() not in self.allowed_extensions:
            continue

        relative = src.relative_to(src_dir)
        (dst_dir / relative).parent.mkdir(parents=True, exist_ok=True)

        try:
            # Only TIFF files can need conversion, so other formats are
            # never decoded just to be copied.
            float_data = None
            if src.suffix.lower() in tiff_suffixes:
                img_obj = Image(src)
                if img_obj.dtype == np.uint16:
                    float_data = img_obj.to_normalized_float32()

            if float_data is None:
                # Plain copy: keep the source file name (and thus its real
                # extension) instead of relabelling the file as ".tif".
                shutil.copy2(src, dst_dir / relative)
                continue

            if float_data.ndim == 2:
                # H,W -> H,W,3
                float_3ch = np.stack([float_data] * 3, axis=-1)
            elif float_data.ndim == 3 and float_data.shape[2] == 1:
                # H,W,1 -> H,W,3
                float_3ch = np.repeat(float_data, 3, axis=2)
            else:
                # Already multi-channel; left as loaded.
                float_3ch = float_data

            dst = dst_dir / relative.with_suffix(".tif")
            tifffile.imwrite(dst, float_3ch.astype(np.float32))
            logger.debug(f"Converted {src} to float32 3-channel TIFF at {dst}")
        except Exception as exc:
            logger.warning(f"Failed to convert {src}: {exc}")
|
||||
|
||||
def _copy_labels(self, labels_src: Path, labels_dst: Path):
    """Mirror every ``.txt`` file under ``labels_src`` into ``labels_dst``.

    The search is recursive, the relative sub-directory layout is kept,
    and missing destination folders are created as needed.
    """
    # Take the full listing up front, before any file is written.
    sources = list(labels_src.rglob("*.txt"))
    for source in sources:
        target = labels_dst / source.relative_to(labels_src)
        target.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(source, target)
|
||||
|
||||
def _infer_labels_dir(self, images_dir: Path) -> Path:
    """Return the ``labels`` directory located beside ``images_dir``."""
    split_root = images_dir.parent
    return split_root.joinpath("labels")
|
||||
|
||||
@@ -1514,11 +1331,9 @@ class TrainingTab(QWidget):
|
||||
self.training_log.clear()
|
||||
self._export_labels_from_database(dataset_info)
|
||||
|
||||
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
|
||||
if dataset_to_use != dataset_path:
|
||||
self._append_training_log(
|
||||
f"Using float32 3-channel dataset at {dataset_to_use.parent}"
|
||||
)
|
||||
self._append_training_log(
|
||||
"Using Float32 on-the-fly loader for 16-bit TIFF support (no disk caching)"
|
||||
)
|
||||
|
||||
params = self._collect_training_params()
|
||||
stage_plan = self._compose_stage_plan(params)
|
||||
@@ -1544,7 +1359,7 @@ class TrainingTab(QWidget):
|
||||
self._set_training_state(True)
|
||||
|
||||
self.training_worker = TrainingWorker(
|
||||
data_yaml=dataset_to_use.as_posix(),
|
||||
data_yaml=dataset_path.as_posix(),
|
||||
base_model=params["base_model"],
|
||||
epochs=params["epochs"],
|
||||
batch=params["batch"],
|
||||
|
||||
Reference in New Issue
Block a user