Update, cleanup
@@ -3,7 +3,6 @@ Training tab for the microscopy object detection application.
Handles model training with YOLO.
"""

import hashlib
import shutil
from datetime import datetime
from pathlib import Path

@@ -34,7 +33,7 @@ from PySide6.QtWidgets import (
from src.database.db_manager import DatabaseManager
from src.model.yolo_wrapper import YOLOWrapper
from src.utils.config_manager import ConfigManager
from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range
from src.utils.image import Image
from src.utils.logger import get_logger


@@ -947,9 +946,6 @@ class TrainingTab(QWidget):
        for msg in split_messages:
            self._append_training_log(msg)

        if dataset_yaml:
            self._clear_rgb_cache_for_dataset(dataset_yaml)

    def _export_labels_for_split(
        self,
        split_name: str,

@@ -1167,38 +1163,12 @@ class TrainingTab(QWidget):
    def _prepare_dataset_for_training(
        self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
    ) -> Path:
        dataset_info = dataset_info or (
            self.selected_dataset
            if self.selected_dataset
            and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
            else self._parse_dataset_yaml(dataset_yaml)
        )
        """Prepare dataset for training.

        train_split = dataset_info.get("splits", {}).get("train") or {}
        images_path_str = train_split.get("path")
        if not images_path_str:
            return dataset_yaml

        images_path = Path(images_path_str)
        if not images_path.exists():
            return dataset_yaml

        if not self._dataset_requires_rgb_conversion(images_path):
            return dataset_yaml

        cache_root = self._get_rgb_cache_root(dataset_yaml)
        rgb_yaml = cache_root / "data.yaml"
        if rgb_yaml.exists():
            self._append_training_log(
                f"Detected grayscale dataset; reusing RGB cache at {cache_root}"
            )
            return rgb_yaml

        self._append_training_log(
            f"Detected grayscale dataset; creating RGB cache at {cache_root}"
        )
        self._build_rgb_dataset(cache_root, dataset_info)
        return rgb_yaml
        Note: With proper 16-bit TIFF support in YOLOWrapper, we no longer need
        to create RGB-converted copies of the dataset. Images are handled directly.
        """
        return dataset_yaml

    def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
        two_stage = params.get("two_stage") or {}

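The removed and added lines are interleaved in the hunk above, so for readability this is roughly how the simplified method reads once the commit is applied, assembled from the lines shown in the hunk (the Path, Optional, Dict, and Any annotations rely on imports already present in the module):

    def _prepare_dataset_for_training(
        self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
    ) -> Path:
        """Prepare dataset for training.

        Note: With proper 16-bit TIFF support in YOLOWrapper, we no longer need
        to create RGB-converted copies of the dataset. Images are handled directly.
        """
        return dataset_yaml
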
@@ -1284,97 +1254,6 @@ class TrainingTab(QWidget):
            f" • {stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
        )

    def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
        cache_base = Path("data/datasets/_rgb_cache")
        cache_base.mkdir(parents=True, exist_ok=True)
        key = hashlib.md5(str(dataset_yaml.parent.resolve()).encode()).hexdigest()[:8]
        return cache_base / f"{dataset_yaml.parent.name}_{key}"

    def _clear_rgb_cache_for_dataset(self, dataset_yaml: Path):
        cache_root = self._get_rgb_cache_root(dataset_yaml)
        if cache_root.exists():
            try:
                shutil.rmtree(cache_root)
                logger.debug(f"Removed RGB cache at {cache_root}")
            except OSError as exc:
                logger.warning(f"Failed to remove RGB cache {cache_root}: {exc}")

    def _dataset_requires_rgb_conversion(self, images_dir: Path) -> bool:
        sample_image = self._find_first_image(images_dir)
        if not sample_image:
            return False
        try:
            img = Image(sample_image)
            return img.pil_image.mode.upper() != "RGB"
        except Exception as exc:
            logger.warning(f"Failed to inspect image {sample_image}: {exc}")
            return False

    def _find_first_image(self, directory: Path) -> Optional[Path]:
        if not directory.exists():
            return None
        for path in directory.rglob("*"):
            if path.is_file() and path.suffix.lower() in self.allowed_extensions:
                return path
        return None

    def _build_rgb_dataset(self, cache_root: Path, dataset_info: Dict[str, Any]):
        if cache_root.exists():
            shutil.rmtree(cache_root)
        cache_root.mkdir(parents=True, exist_ok=True)

        splits = dataset_info.get("splits", {})
        for split_name in ("train", "val", "test"):
            split_entry = splits.get(split_name)
            if not split_entry:
                continue
            images_src = Path(split_entry.get("path", ""))
            if not images_src.exists():
                continue
            images_dst = cache_root / split_name / "images"
            self._convert_images_to_rgb(images_src, images_dst)

            labels_src = self._infer_labels_dir(images_src)
            if labels_src.exists():
                labels_dst = cache_root / split_name / "labels"
                self._copy_labels(labels_src, labels_dst)

        class_names = dataset_info.get("class_names") or []
        names_map = {idx: name for idx, name in enumerate(class_names)}
        num_classes = dataset_info.get("num_classes") or len(class_names)

        yaml_payload: Dict[str, Any] = {
            "path": cache_root.as_posix(),
            "names": names_map,
            "nc": num_classes,
        }

        for split_name in ("train", "val", "test"):
            images_dir = cache_root / split_name / "images"
            if images_dir.exists():
                yaml_payload[split_name] = f"{split_name}/images"

        with open(cache_root / "data.yaml", "w", encoding="utf-8") as handle:
            yaml.safe_dump(yaml_payload, handle, sort_keys=False)

    def _convert_images_to_rgb(self, src_dir: Path, dst_dir: Path):
        for src in src_dir.rglob("*"):
            if not src.is_file() or src.suffix.lower() not in self.allowed_extensions:
                continue
            relative = src.relative_to(src_dir)
            dst = dst_dir / relative
            dst.parent.mkdir(parents=True, exist_ok=True)
            try:
                img_obj = Image(src)
                pil_img = img_obj.pil_image
                if len(pil_img.getbands()) == 1:
                    rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img)
                else:
                    rgb_img = pil_img.convert("RGB")
                rgb_img.save(dst)
            except Exception as exc:
                logger.warning(f"Failed to convert {src} to RGB: {exc}")

    def _copy_labels(self, labels_src: Path, labels_dst: Path):
        label_files = list(labels_src.rglob("*.txt"))
        for label_file in label_files:

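For reference, the data.yaml written by the now-removed _build_rgb_dataset helper looked roughly like the sketch below. Only the keys (path, names, nc, and the per-split image paths) come from the code above; the dataset name, cache suffix, and class names are made-up placeholders:

    import yaml

    # Placeholder payload mirroring the removed _build_rgb_dataset logic.
    # The cache path and class names are hypothetical examples.
    yaml_payload = {
        "path": "data/datasets/_rgb_cache/my_dataset_1a2b3c4d",  # cache_root.as_posix()
        "names": {0: "cell", 1: "debris"},                       # names_map
        "nc": 2,                                                 # num_classes
        "train": "train/images",  # a split key is added only if its images dir exists
        "val": "val/images",
    }

    print(yaml.safe_dump(yaml_payload, sort_keys=False))
    # path: data/datasets/_rgb_cache/my_dataset_1a2b3c4d
    # names:
    #   0: cell
    #   1: debris
    # nc: 2
    # train: train/images
    # val: val/images
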
@@ -1471,10 +1350,6 @@ class TrainingTab(QWidget):
        self._export_labels_from_database(dataset_info)

        dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
        if dataset_to_use != dataset_path:
            self._append_training_log(
                f"Using RGB-converted dataset at {dataset_to_use.parent}"
            )

        params = self._collect_training_params()
        stage_plan = self._compose_stage_plan(params)

@@ -192,12 +192,15 @@ class YOLOWrapper:
            logger.error(f"Error during inference: {e}")
            raise
        finally:
            if 0:  # cleanup_path:
            # Clean up temporary files (only for non-16-bit images)
            # 16-bit TIFFs return numpy arrays directly, so cleanup_path is None
            if cleanup_path:
                try:
                    os.remove(cleanup_path)
                    logger.debug(f"Cleaned up temporary file: {cleanup_path}")
                except OSError as cleanup_error:
                    logger.warning(
                        f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
                        f"Failed to delete temporary file {cleanup_path}: {cleanup_error}"
                    )

    def export(

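Assembled from the hunk above, the cleanup block reads roughly as follows once the change is applied (the enclosing inference method, the os import, and the module-level logger are assumed from the rest of YOLOWrapper):

        finally:
            # Clean up temporary files (only for non-16-bit images)
            # 16-bit TIFFs return numpy arrays directly, so cleanup_path is None
            if cleanup_path:
                try:
                    os.remove(cleanup_path)
                    logger.debug(f"Cleaned up temporary file: {cleanup_path}")
                except OSError as cleanup_error:
                    logger.warning(
                        f"Failed to delete temporary file {cleanup_path}: {cleanup_error}"
                    )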