Update, cleanup
This commit is contained in:
@@ -3,7 +3,6 @@ Training tab for the microscopy object detection application.
|
|||||||
Handles model training with YOLO.
|
Handles model training with YOLO.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import shutil
|
import shutil
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -34,7 +33,7 @@ from PySide6.QtWidgets import (
|
|||||||
from src.database.db_manager import DatabaseManager
|
from src.database.db_manager import DatabaseManager
|
||||||
from src.model.yolo_wrapper import YOLOWrapper
|
from src.model.yolo_wrapper import YOLOWrapper
|
||||||
from src.utils.config_manager import ConfigManager
|
from src.utils.config_manager import ConfigManager
|
||||||
from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range
|
from src.utils.image import Image
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
|
|
||||||
|
|
||||||
@@ -947,9 +946,6 @@ class TrainingTab(QWidget):
|
|||||||
for msg in split_messages:
|
for msg in split_messages:
|
||||||
self._append_training_log(msg)
|
self._append_training_log(msg)
|
||||||
|
|
||||||
if dataset_yaml:
|
|
||||||
self._clear_rgb_cache_for_dataset(dataset_yaml)
|
|
||||||
|
|
||||||
def _export_labels_for_split(
|
def _export_labels_for_split(
|
||||||
self,
|
self,
|
||||||
split_name: str,
|
split_name: str,
|
||||||
@@ -1167,38 +1163,12 @@ class TrainingTab(QWidget):
|
|||||||
def _prepare_dataset_for_training(
|
def _prepare_dataset_for_training(
|
||||||
self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
|
self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
|
||||||
) -> Path:
|
) -> Path:
|
||||||
dataset_info = dataset_info or (
|
"""Prepare dataset for training.
|
||||||
self.selected_dataset
|
|
||||||
if self.selected_dataset
|
|
||||||
and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
|
|
||||||
else self._parse_dataset_yaml(dataset_yaml)
|
|
||||||
)
|
|
||||||
|
|
||||||
train_split = dataset_info.get("splits", {}).get("train") or {}
|
Note: With proper 16-bit TIFF support in YOLOWrapper, we no longer need
|
||||||
images_path_str = train_split.get("path")
|
to create RGB-converted copies of the dataset. Images are handled directly.
|
||||||
if not images_path_str:
|
"""
|
||||||
return dataset_yaml
|
return dataset_yaml
|
||||||
|
|
||||||
images_path = Path(images_path_str)
|
|
||||||
if not images_path.exists():
|
|
||||||
return dataset_yaml
|
|
||||||
|
|
||||||
if not self._dataset_requires_rgb_conversion(images_path):
|
|
||||||
return dataset_yaml
|
|
||||||
|
|
||||||
cache_root = self._get_rgb_cache_root(dataset_yaml)
|
|
||||||
rgb_yaml = cache_root / "data.yaml"
|
|
||||||
if rgb_yaml.exists():
|
|
||||||
self._append_training_log(
|
|
||||||
f"Detected grayscale dataset; reusing RGB cache at {cache_root}"
|
|
||||||
)
|
|
||||||
return rgb_yaml
|
|
||||||
|
|
||||||
self._append_training_log(
|
|
||||||
f"Detected grayscale dataset; creating RGB cache at {cache_root}"
|
|
||||||
)
|
|
||||||
self._build_rgb_dataset(cache_root, dataset_info)
|
|
||||||
return rgb_yaml
|
|
||||||
|
|
||||||
def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
|
def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||||
two_stage = params.get("two_stage") or {}
|
two_stage = params.get("two_stage") or {}
|
||||||
@@ -1284,97 +1254,6 @@ class TrainingTab(QWidget):
|
|||||||
f" • {stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
|
f" • {stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
|
|
||||||
cache_base = Path("data/datasets/_rgb_cache")
|
|
||||||
cache_base.mkdir(parents=True, exist_ok=True)
|
|
||||||
key = hashlib.md5(str(dataset_yaml.parent.resolve()).encode()).hexdigest()[:8]
|
|
||||||
return cache_base / f"{dataset_yaml.parent.name}_{key}"
|
|
||||||
|
|
||||||
def _clear_rgb_cache_for_dataset(self, dataset_yaml: Path):
|
|
||||||
cache_root = self._get_rgb_cache_root(dataset_yaml)
|
|
||||||
if cache_root.exists():
|
|
||||||
try:
|
|
||||||
shutil.rmtree(cache_root)
|
|
||||||
logger.debug(f"Removed RGB cache at {cache_root}")
|
|
||||||
except OSError as exc:
|
|
||||||
logger.warning(f"Failed to remove RGB cache {cache_root}: {exc}")
|
|
||||||
|
|
||||||
def _dataset_requires_rgb_conversion(self, images_dir: Path) -> bool:
|
|
||||||
sample_image = self._find_first_image(images_dir)
|
|
||||||
if not sample_image:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
img = Image(sample_image)
|
|
||||||
return img.pil_image.mode.upper() != "RGB"
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(f"Failed to inspect image {sample_image}: {exc}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _find_first_image(self, directory: Path) -> Optional[Path]:
|
|
||||||
if not directory.exists():
|
|
||||||
return None
|
|
||||||
for path in directory.rglob("*"):
|
|
||||||
if path.is_file() and path.suffix.lower() in self.allowed_extensions:
|
|
||||||
return path
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _build_rgb_dataset(self, cache_root: Path, dataset_info: Dict[str, Any]):
|
|
||||||
if cache_root.exists():
|
|
||||||
shutil.rmtree(cache_root)
|
|
||||||
cache_root.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
splits = dataset_info.get("splits", {})
|
|
||||||
for split_name in ("train", "val", "test"):
|
|
||||||
split_entry = splits.get(split_name)
|
|
||||||
if not split_entry:
|
|
||||||
continue
|
|
||||||
images_src = Path(split_entry.get("path", ""))
|
|
||||||
if not images_src.exists():
|
|
||||||
continue
|
|
||||||
images_dst = cache_root / split_name / "images"
|
|
||||||
self._convert_images_to_rgb(images_src, images_dst)
|
|
||||||
|
|
||||||
labels_src = self._infer_labels_dir(images_src)
|
|
||||||
if labels_src.exists():
|
|
||||||
labels_dst = cache_root / split_name / "labels"
|
|
||||||
self._copy_labels(labels_src, labels_dst)
|
|
||||||
|
|
||||||
class_names = dataset_info.get("class_names") or []
|
|
||||||
names_map = {idx: name for idx, name in enumerate(class_names)}
|
|
||||||
num_classes = dataset_info.get("num_classes") or len(class_names)
|
|
||||||
|
|
||||||
yaml_payload: Dict[str, Any] = {
|
|
||||||
"path": cache_root.as_posix(),
|
|
||||||
"names": names_map,
|
|
||||||
"nc": num_classes,
|
|
||||||
}
|
|
||||||
|
|
||||||
for split_name in ("train", "val", "test"):
|
|
||||||
images_dir = cache_root / split_name / "images"
|
|
||||||
if images_dir.exists():
|
|
||||||
yaml_payload[split_name] = f"{split_name}/images"
|
|
||||||
|
|
||||||
with open(cache_root / "data.yaml", "w", encoding="utf-8") as handle:
|
|
||||||
yaml.safe_dump(yaml_payload, handle, sort_keys=False)
|
|
||||||
|
|
||||||
def _convert_images_to_rgb(self, src_dir: Path, dst_dir: Path):
|
|
||||||
for src in src_dir.rglob("*"):
|
|
||||||
if not src.is_file() or src.suffix.lower() not in self.allowed_extensions:
|
|
||||||
continue
|
|
||||||
relative = src.relative_to(src_dir)
|
|
||||||
dst = dst_dir / relative
|
|
||||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
try:
|
|
||||||
img_obj = Image(src)
|
|
||||||
pil_img = img_obj.pil_image
|
|
||||||
if len(pil_img.getbands()) == 1:
|
|
||||||
rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img)
|
|
||||||
else:
|
|
||||||
rgb_img = pil_img.convert("RGB")
|
|
||||||
rgb_img.save(dst)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(f"Failed to convert {src} to RGB: {exc}")
|
|
||||||
|
|
||||||
def _copy_labels(self, labels_src: Path, labels_dst: Path):
|
def _copy_labels(self, labels_src: Path, labels_dst: Path):
|
||||||
label_files = list(labels_src.rglob("*.txt"))
|
label_files = list(labels_src.rglob("*.txt"))
|
||||||
for label_file in label_files:
|
for label_file in label_files:
|
||||||
@@ -1471,10 +1350,6 @@ class TrainingTab(QWidget):
|
|||||||
self._export_labels_from_database(dataset_info)
|
self._export_labels_from_database(dataset_info)
|
||||||
|
|
||||||
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
|
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
|
||||||
if dataset_to_use != dataset_path:
|
|
||||||
self._append_training_log(
|
|
||||||
f"Using RGB-converted dataset at {dataset_to_use.parent}"
|
|
||||||
)
|
|
||||||
|
|
||||||
params = self._collect_training_params()
|
params = self._collect_training_params()
|
||||||
stage_plan = self._compose_stage_plan(params)
|
stage_plan = self._compose_stage_plan(params)
|
||||||
|
|||||||
@@ -192,12 +192,15 @@ class YOLOWrapper:
|
|||||||
logger.error(f"Error during inference: {e}")
|
logger.error(f"Error during inference: {e}")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
if 0: # cleanup_path:
|
# Clean up temporary files (only for non-16-bit images)
|
||||||
|
# 16-bit TIFFs return numpy arrays directly, so cleanup_path is None
|
||||||
|
if cleanup_path:
|
||||||
try:
|
try:
|
||||||
os.remove(cleanup_path)
|
os.remove(cleanup_path)
|
||||||
|
logger.debug(f"Cleaned up temporary file: {cleanup_path}")
|
||||||
except OSError as cleanup_error:
|
except OSError as cleanup_error:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
|
f"Failed to delete temporary file {cleanup_path}: {cleanup_error}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def export(
|
def export(
|
||||||
|
|||||||
Reference in New Issue
Block a user