Update, cleanup

This commit is contained in:
2025-12-13 01:06:40 +02:00
parent 2411223a14
commit edcd448a61
2 changed files with 11 additions and 133 deletions

View File

@@ -3,7 +3,6 @@ Training tab for the microscopy object detection application.
Handles model training with YOLO.
"""
import hashlib
import shutil
from datetime import datetime
from pathlib import Path
@@ -34,7 +33,7 @@ from PySide6.QtWidgets import (
from src.database.db_manager import DatabaseManager
from src.model.yolo_wrapper import YOLOWrapper
from src.utils.config_manager import ConfigManager
from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range
from src.utils.image import Image
from src.utils.logger import get_logger
@@ -947,9 +946,6 @@ class TrainingTab(QWidget):
for msg in split_messages:
self._append_training_log(msg)
if dataset_yaml:
self._clear_rgb_cache_for_dataset(dataset_yaml)
def _export_labels_for_split(
self,
split_name: str,
@@ -1167,38 +1163,12 @@ class TrainingTab(QWidget):
def _prepare_dataset_for_training(
self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
) -> Path:
dataset_info = dataset_info or (
self.selected_dataset
if self.selected_dataset
and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
else self._parse_dataset_yaml(dataset_yaml)
)
"""Prepare dataset for training.
train_split = dataset_info.get("splits", {}).get("train") or {}
images_path_str = train_split.get("path")
if not images_path_str:
return dataset_yaml
images_path = Path(images_path_str)
if not images_path.exists():
return dataset_yaml
if not self._dataset_requires_rgb_conversion(images_path):
return dataset_yaml
cache_root = self._get_rgb_cache_root(dataset_yaml)
rgb_yaml = cache_root / "data.yaml"
if rgb_yaml.exists():
self._append_training_log(
f"Detected grayscale dataset; reusing RGB cache at {cache_root}"
)
return rgb_yaml
self._append_training_log(
f"Detected grayscale dataset; creating RGB cache at {cache_root}"
)
self._build_rgb_dataset(cache_root, dataset_info)
return rgb_yaml
Note: With proper 16-bit TIFF support in YOLOWrapper, we no longer need
to create RGB-converted copies of the dataset. Images are handled directly.
"""
return dataset_yaml
def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
two_stage = params.get("two_stage") or {}
@@ -1284,97 +1254,6 @@ class TrainingTab(QWidget):
f"{stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
)
def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
cache_base = Path("data/datasets/_rgb_cache")
cache_base.mkdir(parents=True, exist_ok=True)
key = hashlib.md5(str(dataset_yaml.parent.resolve()).encode()).hexdigest()[:8]
return cache_base / f"{dataset_yaml.parent.name}_{key}"
def _clear_rgb_cache_for_dataset(self, dataset_yaml: Path):
cache_root = self._get_rgb_cache_root(dataset_yaml)
if cache_root.exists():
try:
shutil.rmtree(cache_root)
logger.debug(f"Removed RGB cache at {cache_root}")
except OSError as exc:
logger.warning(f"Failed to remove RGB cache {cache_root}: {exc}")
def _dataset_requires_rgb_conversion(self, images_dir: Path) -> bool:
sample_image = self._find_first_image(images_dir)
if not sample_image:
return False
try:
img = Image(sample_image)
return img.pil_image.mode.upper() != "RGB"
except Exception as exc:
logger.warning(f"Failed to inspect image {sample_image}: {exc}")
return False
def _find_first_image(self, directory: Path) -> Optional[Path]:
if not directory.exists():
return None
for path in directory.rglob("*"):
if path.is_file() and path.suffix.lower() in self.allowed_extensions:
return path
return None
def _build_rgb_dataset(self, cache_root: Path, dataset_info: Dict[str, Any]):
    """Create an RGB copy of a dataset under ``cache_root``.

    Any existing ``cache_root`` is removed first. For each of the ``train``,
    ``val`` and ``test`` splits whose image directory exists, the images are
    converted with ``_convert_images_to_rgb`` into
    ``<cache_root>/<split>/images``. When the label directory returned by
    ``_infer_labels_dir`` exists, the labels are copied with ``_copy_labels``
    into ``<cache_root>/<split>/labels``. Finally a ``data.yaml`` describing
    the converted dataset is written to ``cache_root``.

    Args:
        cache_root: Destination directory of the RGB dataset.
        dataset_info: Dataset description. The keys read here are ``splits``
            (split name -> entry with a ``path`` to the image directory),
            ``class_names`` and ``num_classes``.
    """
    split_names = ("train", "val", "test")

    # Always rebuild from scratch so stale files never survive.
    if cache_root.exists():
        shutil.rmtree(cache_root)
    cache_root.mkdir(parents=True, exist_ok=True)

    splits = dataset_info.get("splits") or {}
    for split_name in split_names:
        split_entry = splits.get(split_name)
        if not split_entry:
            continue

        # A missing or empty path must be skipped explicitly: Path("") is
        # the current working directory, which exists, so it would otherwise
        # be converted wholesale; Path(None) would raise TypeError.
        raw_path = split_entry.get("path")
        if not raw_path:
            continue

        images_src = Path(raw_path)
        if not images_src.exists():
            continue

        split_root = cache_root / split_name
        self._convert_images_to_rgb(images_src, split_root / "images")

        labels_src = self._infer_labels_dir(images_src)
        if labels_src.exists():
            self._copy_labels(labels_src, split_root / "labels")

    class_names = dataset_info.get("class_names") or []
    yaml_payload: Dict[str, Any] = {
        "path": cache_root.as_posix(),
        "names": dict(enumerate(class_names)),
        "nc": dataset_info.get("num_classes") or len(class_names),
    }
    # Only advertise the splits that actually produced an images directory.
    for split_name in split_names:
        if (cache_root / split_name / "images").exists():
            yaml_payload[split_name] = f"{split_name}/images"

    with open(cache_root / "data.yaml", "w", encoding="utf-8") as handle:
        yaml.safe_dump(yaml_payload, handle, sort_keys=False)
def _convert_images_to_rgb(self, src_dir: Path, dst_dir: Path):
for src in src_dir.rglob("*"):
if not src.is_file() or src.suffix.lower() not in self.allowed_extensions:
continue
relative = src.relative_to(src_dir)
dst = dst_dir / relative
dst.parent.mkdir(parents=True, exist_ok=True)
try:
img_obj = Image(src)
pil_img = img_obj.pil_image
if len(pil_img.getbands()) == 1:
rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img)
else:
rgb_img = pil_img.convert("RGB")
rgb_img.save(dst)
except Exception as exc:
logger.warning(f"Failed to convert {src} to RGB: {exc}")
def _copy_labels(self, labels_src: Path, labels_dst: Path):
label_files = list(labels_src.rglob("*.txt"))
for label_file in label_files:
@@ -1471,10 +1350,6 @@ class TrainingTab(QWidget):
self._export_labels_from_database(dataset_info)
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
if dataset_to_use != dataset_path:
self._append_training_log(
f"Using RGB-converted dataset at {dataset_to_use.parent}"
)
params = self._collect_training_params()
stage_plan = self._compose_stage_plan(params)

View File

@@ -192,12 +192,15 @@ class YOLOWrapper:
logger.error(f"Error during inference: {e}")
raise
finally:
if 0: # cleanup_path:
# Clean up temporary files (only for non-16-bit images)
# 16-bit TIFFs return numpy arrays directly, so cleanup_path is None
if cleanup_path:
try:
os.remove(cleanup_path)
logger.debug(f"Cleaned up temporary file: {cleanup_path}")
except OSError as cleanup_error:
logger.warning(
f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
f"Failed to delete temporary file {cleanup_path}: {cleanup_error}"
)
def export(