Update, cleanup

This commit is contained in:
2025-12-13 01:06:40 +02:00
parent 2411223a14
commit edcd448a61
2 changed files with 11 additions and 133 deletions

View File

@@ -3,7 +3,6 @@ Training tab for the microscopy object detection application.
Handles model training with YOLO.
"""
import hashlib
import shutil
from datetime import datetime
from pathlib import Path
@@ -34,7 +33,7 @@ from PySide6.QtWidgets import (
from src.database.db_manager import DatabaseManager
from src.model.yolo_wrapper import YOLOWrapper
from src.utils.config_manager import ConfigManager
from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range from src.utils.image import Image
from src.utils.logger import get_logger
@@ -947,9 +946,6 @@ class TrainingTab(QWidget):
for msg in split_messages:
self._append_training_log(msg)
if dataset_yaml:
self._clear_rgb_cache_for_dataset(dataset_yaml)
def _export_labels_for_split(
self,
split_name: str,
@@ -1167,38 +1163,12 @@ class TrainingTab(QWidget):
def _prepare_dataset_for_training(
self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
) -> Path:
dataset_info = dataset_info or ( """Prepare dataset for training.
self.selected_dataset
if self.selected_dataset
and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
else self._parse_dataset_yaml(dataset_yaml)
)
train_split = dataset_info.get("splits", {}).get("train") or {} Note: With proper 16-bit TIFF support in YOLOWrapper, we no longer need
images_path_str = train_split.get("path") to create RGB-converted copies of the dataset. Images are handled directly.
if not images_path_str: """
return dataset_yaml return dataset_yaml
images_path = Path(images_path_str)
if not images_path.exists():
return dataset_yaml
if not self._dataset_requires_rgb_conversion(images_path):
return dataset_yaml
cache_root = self._get_rgb_cache_root(dataset_yaml)
rgb_yaml = cache_root / "data.yaml"
if rgb_yaml.exists():
self._append_training_log(
f"Detected grayscale dataset; reusing RGB cache at {cache_root}"
)
return rgb_yaml
self._append_training_log(
f"Detected grayscale dataset; creating RGB cache at {cache_root}"
)
self._build_rgb_dataset(cache_root, dataset_info)
return rgb_yaml
def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
two_stage = params.get("two_stage") or {}
@@ -1284,97 +1254,6 @@ class TrainingTab(QWidget):
f"{stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
)
def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
cache_base = Path("data/datasets/_rgb_cache")
cache_base.mkdir(parents=True, exist_ok=True)
key = hashlib.md5(str(dataset_yaml.parent.resolve()).encode()).hexdigest()[:8]
return cache_base / f"{dataset_yaml.parent.name}_{key}"
def _clear_rgb_cache_for_dataset(self, dataset_yaml: Path):
cache_root = self._get_rgb_cache_root(dataset_yaml)
if cache_root.exists():
try:
shutil.rmtree(cache_root)
logger.debug(f"Removed RGB cache at {cache_root}")
except OSError as exc:
logger.warning(f"Failed to remove RGB cache {cache_root}: {exc}")
def _dataset_requires_rgb_conversion(self, images_dir: Path) -> bool:
sample_image = self._find_first_image(images_dir)
if not sample_image:
return False
try:
img = Image(sample_image)
return img.pil_image.mode.upper() != "RGB"
except Exception as exc:
logger.warning(f"Failed to inspect image {sample_image}: {exc}")
return False
def _find_first_image(self, directory: Path) -> Optional[Path]:
if not directory.exists():
return None
for path in directory.rglob("*"):
if path.is_file() and path.suffix.lower() in self.allowed_extensions:
return path
return None
def _build_rgb_dataset(self, cache_root: Path, dataset_info: Dict[str, Any]):
if cache_root.exists():
shutil.rmtree(cache_root)
cache_root.mkdir(parents=True, exist_ok=True)
splits = dataset_info.get("splits", {})
for split_name in ("train", "val", "test"):
split_entry = splits.get(split_name)
if not split_entry:
continue
images_src = Path(split_entry.get("path", ""))
if not images_src.exists():
continue
images_dst = cache_root / split_name / "images"
self._convert_images_to_rgb(images_src, images_dst)
labels_src = self._infer_labels_dir(images_src)
if labels_src.exists():
labels_dst = cache_root / split_name / "labels"
self._copy_labels(labels_src, labels_dst)
class_names = dataset_info.get("class_names") or []
names_map = {idx: name for idx, name in enumerate(class_names)}
num_classes = dataset_info.get("num_classes") or len(class_names)
yaml_payload: Dict[str, Any] = {
"path": cache_root.as_posix(),
"names": names_map,
"nc": num_classes,
}
for split_name in ("train", "val", "test"):
images_dir = cache_root / split_name / "images"
if images_dir.exists():
yaml_payload[split_name] = f"{split_name}/images"
with open(cache_root / "data.yaml", "w", encoding="utf-8") as handle:
yaml.safe_dump(yaml_payload, handle, sort_keys=False)
def _convert_images_to_rgb(self, src_dir: Path, dst_dir: Path):
for src in src_dir.rglob("*"):
if not src.is_file() or src.suffix.lower() not in self.allowed_extensions:
continue
relative = src.relative_to(src_dir)
dst = dst_dir / relative
dst.parent.mkdir(parents=True, exist_ok=True)
try:
img_obj = Image(src)
pil_img = img_obj.pil_image
if len(pil_img.getbands()) == 1:
rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img)
else:
rgb_img = pil_img.convert("RGB")
rgb_img.save(dst)
except Exception as exc:
logger.warning(f"Failed to convert {src} to RGB: {exc}")
def _copy_labels(self, labels_src: Path, labels_dst: Path):
label_files = list(labels_src.rglob("*.txt"))
for label_file in label_files:
@@ -1471,10 +1350,6 @@ class TrainingTab(QWidget):
self._export_labels_from_database(dataset_info)
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
if dataset_to_use != dataset_path:
self._append_training_log(
f"Using RGB-converted dataset at {dataset_to_use.parent}"
)
params = self._collect_training_params()
stage_plan = self._compose_stage_plan(params)

View File

@@ -192,12 +192,15 @@ class YOLOWrapper:
logger.error(f"Error during inference: {e}")
raise
finally:
if 0: # cleanup_path: # Clean up temporary files (only for non-16-bit images)
# 16-bit TIFFs return numpy arrays directly, so cleanup_path is None
if cleanup_path:
try:
os.remove(cleanup_path)
logger.debug(f"Cleaned up temporary file: {cleanup_path}")
except OSError as cleanup_error:
logger.warning(
f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}" f"Failed to delete temporary file {cleanup_path}: {cleanup_error}"
)
def export(