diff --git a/src/gui/tabs/training_tab.py b/src/gui/tabs/training_tab.py index 2fce300..a3af7e2 100644 --- a/src/gui/tabs/training_tab.py +++ b/src/gui/tabs/training_tab.py @@ -10,7 +10,6 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import yaml -from PIL import Image as PILImage from PySide6.QtCore import Qt, QThread, Signal from PySide6.QtWidgets import ( QWidget, @@ -1293,8 +1292,8 @@ class TrainingTab(QWidget): if not sample_image: return False try: - with PILImage.open(sample_image) as img: - return img.mode.upper() != "RGB" + img = Image(sample_image) + return img.pil_image.mode.upper() != "RGB" except Exception as exc: logger.warning(f"Failed to inspect image {sample_image}: {exc}") return False @@ -1354,12 +1353,13 @@ class TrainingTab(QWidget): dst = dst_dir / relative dst.parent.mkdir(parents=True, exist_ok=True) try: - with PILImage.open(src) as img: - if len(img.getbands()) == 1: - rgb_img = convert_grayscale_to_rgb_preserve_range(img) - else: - rgb_img = img.convert("RGB") - rgb_img.save(dst) + img_obj = Image(src) + pil_img = img_obj.pil_image + if len(pil_img.getbands()) == 1: + rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img) + else: + rgb_img = pil_img.convert("RGB") + rgb_img.save(dst) except Exception as exc: logger.warning(f"Failed to convert {src} to RGB: {exc}") diff --git a/src/model/inference.py b/src/model/inference.py index 99a2de7..0e4aea3 100644 --- a/src/model/inference.py +++ b/src/model/inference.py @@ -5,12 +5,12 @@ Handles detection inference and result storage. from typing import List, Dict, Optional, Callable from pathlib import Path -from PIL import Image import cv2 import numpy as np from src.model.yolo_wrapper import YOLOWrapper from src.database.db_manager import DatabaseManager +from src.utils.image import Image from src.utils.logger import get_logger from src.utils.file_utils import get_relative_path @@ -64,9 +64,9 @@ class InferenceEngine: stored_relative_path = str(Path(image_path).resolve()) # Get image dimensions - img = Image.open(image_path) - width, height = img.size - img.close() + img = Image(image_path) + width = img.width + height = img.height # Perform detection detections = self.yolo.predict(image_path, conf=conf) diff --git a/src/model/yolo_wrapper.py b/src/model/yolo_wrapper.py index f165c7b..db95a2c 100644 --- a/src/model/yolo_wrapper.py +++ b/src/model/yolo_wrapper.py @@ -7,10 +7,9 @@ from ultralytics import YOLO from pathlib import Path from typing import Optional, List, Dict, Callable, Any import torch -from PIL import Image import tempfile import os -from src.utils.image import convert_grayscale_to_rgb_preserve_range +from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range from src.utils.logger import get_logger @@ -232,22 +231,23 @@ class YOLOWrapper: source_path = Path(source) if source_path.is_file(): try: - with Image.open(source_path) as img: - if len(img.getbands()) == 1: - rgb_img = convert_grayscale_to_rgb_preserve_range(img) - else: - rgb_img = img.convert("RGB") + img_obj = Image(source_path) + pil_img = img_obj.pil_image + if len(pil_img.getbands()) == 1: + rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img) + else: + rgb_img = pil_img.convert("RGB") - suffix = source_path.suffix or ".png" - tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) - tmp_path = tmp.name - tmp.close() - rgb_img.save(tmp_path) - cleanup_path = tmp_path - logger.info( - f"Converted image {source_path} to RGB for inference at {tmp_path}" - ) - return tmp_path, cleanup_path + suffix = source_path.suffix or ".png" + tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) + tmp_path = tmp.name + tmp.close() + rgb_img.save(tmp_path) + cleanup_path = tmp_path + logger.info( + f"Converted image {source_path} to RGB for inference at {tmp_path}" + ) + return tmp_path, cleanup_path except Exception as convert_error: logger.warning( f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"