From edcd448a614908bea75b2cc0fd27238129ec9b64 Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Sat, 13 Dec 2025 01:06:40 +0200 Subject: [PATCH] Update, cleanup --- src/gui/tabs/training_tab.py | 137 ++--------------------------------- src/model/yolo_wrapper.py | 7 +- 2 files changed, 11 insertions(+), 133 deletions(-) diff --git a/src/gui/tabs/training_tab.py b/src/gui/tabs/training_tab.py index 5d86fe4..67855cb 100644 --- a/src/gui/tabs/training_tab.py +++ b/src/gui/tabs/training_tab.py @@ -3,7 +3,6 @@ Training tab for the microscopy object detection application. Handles model training with YOLO. """ -import hashlib import shutil from datetime import datetime from pathlib import Path @@ -34,7 +33,7 @@ from PySide6.QtWidgets import ( from src.database.db_manager import DatabaseManager from src.model.yolo_wrapper import YOLOWrapper from src.utils.config_manager import ConfigManager -from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range +from src.utils.image import Image from src.utils.logger import get_logger @@ -947,9 +946,6 @@ class TrainingTab(QWidget): for msg in split_messages: self._append_training_log(msg) - if dataset_yaml: - self._clear_rgb_cache_for_dataset(dataset_yaml) - def _export_labels_for_split( self, split_name: str, @@ -1167,38 +1163,12 @@ class TrainingTab(QWidget): def _prepare_dataset_for_training( self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None ) -> Path: - dataset_info = dataset_info or ( - self.selected_dataset - if self.selected_dataset - and self.selected_dataset.get("yaml_path") == str(dataset_yaml) - else self._parse_dataset_yaml(dataset_yaml) - ) + """Prepare dataset for training. - train_split = dataset_info.get("splits", {}).get("train") or {} - images_path_str = train_split.get("path") - if not images_path_str: - return dataset_yaml - - images_path = Path(images_path_str) - if not images_path.exists(): - return dataset_yaml - - if not self._dataset_requires_rgb_conversion(images_path): - return dataset_yaml - - cache_root = self._get_rgb_cache_root(dataset_yaml) - rgb_yaml = cache_root / "data.yaml" - if rgb_yaml.exists(): - self._append_training_log( - f"Detected grayscale dataset; reusing RGB cache at {cache_root}" - ) - return rgb_yaml - - self._append_training_log( - f"Detected grayscale dataset; creating RGB cache at {cache_root}" - ) - self._build_rgb_dataset(cache_root, dataset_info) - return rgb_yaml + Note: With proper 16-bit TIFF support in YOLOWrapper, we no longer need + to create RGB-converted copies of the dataset. Images are handled directly. + """ + return dataset_yaml def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]: two_stage = params.get("two_stage") or {} @@ -1284,97 +1254,6 @@ class TrainingTab(QWidget): f" • {stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}" ) - def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path: - cache_base = Path("data/datasets/_rgb_cache") - cache_base.mkdir(parents=True, exist_ok=True) - key = hashlib.md5(str(dataset_yaml.parent.resolve()).encode()).hexdigest()[:8] - return cache_base / f"{dataset_yaml.parent.name}_{key}" - - def _clear_rgb_cache_for_dataset(self, dataset_yaml: Path): - cache_root = self._get_rgb_cache_root(dataset_yaml) - if cache_root.exists(): - try: - shutil.rmtree(cache_root) - logger.debug(f"Removed RGB cache at {cache_root}") - except OSError as exc: - logger.warning(f"Failed to remove RGB cache {cache_root}: {exc}") - - def _dataset_requires_rgb_conversion(self, images_dir: Path) -> bool: - sample_image = self._find_first_image(images_dir) - if not sample_image: - return False - try: - img = Image(sample_image) - return img.pil_image.mode.upper() != "RGB" - except Exception as exc: - logger.warning(f"Failed to inspect image {sample_image}: {exc}") - return False - - def _find_first_image(self, directory: Path) -> Optional[Path]: - if not directory.exists(): - return None - for path in directory.rglob("*"): - if path.is_file() and path.suffix.lower() in self.allowed_extensions: - return path - return None - - def _build_rgb_dataset(self, cache_root: Path, dataset_info: Dict[str, Any]): - if cache_root.exists(): - shutil.rmtree(cache_root) - cache_root.mkdir(parents=True, exist_ok=True) - - splits = dataset_info.get("splits", {}) - for split_name in ("train", "val", "test"): - split_entry = splits.get(split_name) - if not split_entry: - continue - images_src = Path(split_entry.get("path", "")) - if not images_src.exists(): - continue - images_dst = cache_root / split_name / "images" - self._convert_images_to_rgb(images_src, images_dst) - - labels_src = self._infer_labels_dir(images_src) - if labels_src.exists(): - labels_dst = cache_root / split_name / "labels" - self._copy_labels(labels_src, labels_dst) - - class_names = dataset_info.get("class_names") or [] - names_map = {idx: name for idx, name in enumerate(class_names)} - num_classes = dataset_info.get("num_classes") or len(class_names) - - yaml_payload: Dict[str, Any] = { - "path": cache_root.as_posix(), - "names": names_map, - "nc": num_classes, - } - - for split_name in ("train", "val", "test"): - images_dir = cache_root / split_name / "images" - if images_dir.exists(): - yaml_payload[split_name] = f"{split_name}/images" - - with open(cache_root / "data.yaml", "w", encoding="utf-8") as handle: - yaml.safe_dump(yaml_payload, handle, sort_keys=False) - - def _convert_images_to_rgb(self, src_dir: Path, dst_dir: Path): - for src in src_dir.rglob("*"): - if not src.is_file() or src.suffix.lower() not in self.allowed_extensions: - continue - relative = src.relative_to(src_dir) - dst = dst_dir / relative - dst.parent.mkdir(parents=True, exist_ok=True) - try: - img_obj = Image(src) - pil_img = img_obj.pil_image - if len(pil_img.getbands()) == 1: - rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img) - else: - rgb_img = pil_img.convert("RGB") - rgb_img.save(dst) - except Exception as exc: - logger.warning(f"Failed to convert {src} to RGB: {exc}") - def _copy_labels(self, labels_src: Path, labels_dst: Path): label_files = list(labels_src.rglob("*.txt")) for label_file in label_files: @@ -1471,10 +1350,6 @@ class TrainingTab(QWidget): self._export_labels_from_database(dataset_info) dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info) - if dataset_to_use != dataset_path: - self._append_training_log( - f"Using RGB-converted dataset at {dataset_to_use.parent}" - ) params = self._collect_training_params() stage_plan = self._compose_stage_plan(params) diff --git a/src/model/yolo_wrapper.py b/src/model/yolo_wrapper.py index fbe71eb..0b690cd 100644 --- a/src/model/yolo_wrapper.py +++ b/src/model/yolo_wrapper.py @@ -192,12 +192,15 @@ class YOLOWrapper: logger.error(f"Error during inference: {e}") raise finally: - if 0: # cleanup_path: + # Clean up temporary files (only for non-16-bit images) + # 16-bit TIFFs return numpy arrays directly, so cleanup_path is None + if cleanup_path: try: os.remove(cleanup_path) + logger.debug(f"Cleaned up temporary file: {cleanup_path}") except OSError as cleanup_error: logger.warning( - f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}" + f"Failed to delete temporary file {cleanup_path}: {cleanup_error}" ) def export(