diff --git a/src/database/db_manager.py b/src/database/db_manager.py index 039786d..1a331e8 100644 --- a/src/database/db_manager.py +++ b/src/database/db_manager.py @@ -13,8 +13,9 @@ import hashlib import yaml from src.utils.logger import get_logger +from src.utils.image import Image -IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp") +IMAGE_EXTENSIONS = tuple(Image.SUPPORTED_EXTENSIONS) logger = get_logger(__name__) diff --git a/src/gui/tabs/annotation_tab.py b/src/gui/tabs/annotation_tab.py index 6927aba..83ca894 100644 --- a/src/gui/tabs/annotation_tab.py +++ b/src/gui/tabs/annotation_tab.py @@ -168,7 +168,7 @@ class AnnotationTab(QWidget): self, "Select Image", start_dir, - "Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)", + "Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")", ) if not file_path: diff --git a/src/gui/tabs/detection_tab.py b/src/gui/tabs/detection_tab.py index 364783f..7dbef1c 100644 --- a/src/gui/tabs/detection_tab.py +++ b/src/gui/tabs/detection_tab.py @@ -27,6 +27,7 @@ from src.utils.config_manager import ConfigManager from src.utils.logger import get_logger from src.utils.file_utils import get_image_files from src.model.inference import InferenceEngine +from src.utils.image import Image logger = get_logger(__name__) @@ -236,7 +237,7 @@ class DetectionTab(QWidget): self, "Select Image", start_dir, - "Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)", + "Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")", ) if not file_path: diff --git a/src/gui/tabs/training_tab.py b/src/gui/tabs/training_tab.py index 257d3c2..2fce300 100644 --- a/src/gui/tabs/training_tab.py +++ b/src/gui/tabs/training_tab.py @@ -34,20 +34,13 @@ from PySide6.QtWidgets import ( from src.database.db_manager import DatabaseManager from src.model.yolo_wrapper import YOLOWrapper from src.utils.config_manager import ConfigManager -from src.utils.image import convert_grayscale_to_rgb_preserve_range +from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range from src.utils.logger import get_logger logger = get_logger(__name__) -DEFAULT_IMAGE_EXTENSIONS = { - ".jpg", - ".jpeg", - ".png", - ".tif", - ".tiff", - ".bmp", -} +DEFAULT_IMAGE_EXTENSIONS = set(Image.SUPPORTED_EXTENSIONS) class TrainingWorker(QThread): diff --git a/src/utils/config_manager.py b/src/utils/config_manager.py index 5b909ff..c6d1979 100644 --- a/src/utils/config_manager.py +++ b/src/utils/config_manager.py @@ -7,6 +7,7 @@ import yaml from pathlib import Path from typing import Any, Dict, Optional from src.utils.logger import get_logger +from src.utils.image import Image logger = get_logger(__name__) @@ -46,14 +47,7 @@ class ConfigManager: "database": {"path": "data/detections.db"}, "image_repository": { "base_path": "", - "allowed_extensions": [ - ".jpg", - ".jpeg", - ".png", - ".tif", - ".tiff", - ".bmp", - ], + "allowed_extensions": Image.SUPPORTED_EXTENSIONS, }, "models": { "default_base_model": "yolov8s-seg.pt", @@ -232,5 +226,5 @@ class ConfigManager: def get_allowed_extensions(self) -> list: """Get list of allowed image file extensions.""" return self.get( - "image_repository.allowed_extensions", [".jpg", ".jpeg", ".png"] + "image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS ) diff --git a/src/utils/file_utils.py b/src/utils/file_utils.py index 019852e..c7730dd 100644 --- a/src/utils/file_utils.py +++ b/src/utils/file_utils.py @@ -6,6 +6,7 @@ import os from pathlib import Path from typing import List, Optional from src.utils.logger import get_logger +from src.utils.image import Image logger = get_logger(__name__) @@ -28,7 +29,7 @@ def get_image_files( List of absolute paths to image files """ if allowed_extensions is None: - allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"] + allowed_extensions = Image.SUPPORTED_EXTENSIONS # Normalize extensions to lowercase allowed_extensions = [ext.lower() for ext in allowed_extensions] @@ -204,7 +205,7 @@ def is_image_file( True if file is an image """ if allowed_extensions is None: - allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"] + allowed_extensions = Image.SUPPORTED_EXTENSIONS extension = Path(file_path).suffix.lower() return extension in [ext.lower() for ext in allowed_extensions] diff --git a/tests/test_image.py b/tests/test_image.py index 88b617f..1ddb206 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -27,7 +27,7 @@ class TestImage: def test_supported_extensions(self): """Test that supported extensions are correctly defined.""" - expected_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"] + expected_extensions = Image.SUPPORTED_EXTENSIONS assert Image.SUPPORTED_EXTENSIONS == expected_extensions def test_image_properties(self, tmp_path):