Defining image extensions only in one place

This commit is contained in:
2025-12-11 15:50:14 +02:00
parent 8eb1cc8c86
commit 9ba44043ef
7 changed files with 14 additions and 24 deletions

View File

@@ -13,8 +13,9 @@ import hashlib
import yaml import yaml
from src.utils.logger import get_logger from src.utils.logger import get_logger
from src.utils.image import Image
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp") IMAGE_EXTENSIONS = tuple(Image.SUPPORTED_EXTENSIONS)
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -168,7 +168,7 @@ class AnnotationTab(QWidget):
self, self,
"Select Image", "Select Image",
start_dir, start_dir,
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)", "Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
) )
if not file_path: if not file_path:

View File

@@ -27,6 +27,7 @@ from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger from src.utils.logger import get_logger
from src.utils.file_utils import get_image_files from src.utils.file_utils import get_image_files
from src.model.inference import InferenceEngine from src.model.inference import InferenceEngine
from src.utils.image import Image
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -236,7 +237,7 @@ class DetectionTab(QWidget):
self, self,
"Select Image", "Select Image",
start_dir, start_dir,
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)", "Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
) )
if not file_path: if not file_path:

View File

@@ -34,20 +34,13 @@ from PySide6.QtWidgets import (
from src.database.db_manager import DatabaseManager from src.database.db_manager import DatabaseManager
from src.model.yolo_wrapper import YOLOWrapper from src.model.yolo_wrapper import YOLOWrapper
from src.utils.config_manager import ConfigManager from src.utils.config_manager import ConfigManager
from src.utils.image import convert_grayscale_to_rgb_preserve_range from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range
from src.utils.logger import get_logger from src.utils.logger import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
DEFAULT_IMAGE_EXTENSIONS = { DEFAULT_IMAGE_EXTENSIONS = set(Image.SUPPORTED_EXTENSIONS)
".jpg",
".jpeg",
".png",
".tif",
".tiff",
".bmp",
}
class TrainingWorker(QThread): class TrainingWorker(QThread):

View File

@@ -7,6 +7,7 @@ import yaml
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from src.utils.logger import get_logger from src.utils.logger import get_logger
from src.utils.image import Image
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -46,14 +47,7 @@ class ConfigManager:
"database": {"path": "data/detections.db"}, "database": {"path": "data/detections.db"},
"image_repository": { "image_repository": {
"base_path": "", "base_path": "",
"allowed_extensions": [ "allowed_extensions": Image.SUPPORTED_EXTENSIONS,
".jpg",
".jpeg",
".png",
".tif",
".tiff",
".bmp",
],
}, },
"models": { "models": {
"default_base_model": "yolov8s-seg.pt", "default_base_model": "yolov8s-seg.pt",
@@ -232,5 +226,5 @@ class ConfigManager:
def get_allowed_extensions(self) -> list: def get_allowed_extensions(self) -> list:
"""Get list of allowed image file extensions.""" """Get list of allowed image file extensions."""
return self.get( return self.get(
"image_repository.allowed_extensions", [".jpg", ".jpeg", ".png"] "image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS
) )

View File

@@ -6,6 +6,7 @@ import os
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import List, Optional
from src.utils.logger import get_logger from src.utils.logger import get_logger
from src.utils.image import Image
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -28,7 +29,7 @@ def get_image_files(
List of absolute paths to image files List of absolute paths to image files
""" """
if allowed_extensions is None: if allowed_extensions is None:
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"] allowed_extensions = Image.SUPPORTED_EXTENSIONS
# Normalize extensions to lowercase # Normalize extensions to lowercase
allowed_extensions = [ext.lower() for ext in allowed_extensions] allowed_extensions = [ext.lower() for ext in allowed_extensions]
@@ -204,7 +205,7 @@ def is_image_file(
True if file is an image True if file is an image
""" """
if allowed_extensions is None: if allowed_extensions is None:
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"] allowed_extensions = Image.SUPPORTED_EXTENSIONS
extension = Path(file_path).suffix.lower() extension = Path(file_path).suffix.lower()
return extension in [ext.lower() for ext in allowed_extensions] return extension in [ext.lower() for ext in allowed_extensions]

View File

@@ -27,7 +27,7 @@ class TestImage:
def test_supported_extensions(self): def test_supported_extensions(self):
"""Test that supported extensions are correctly defined.""" """Test that supported extensions are correctly defined."""
expected_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"] expected_extensions = Image.SUPPORTED_EXTENSIONS
assert Image.SUPPORTED_EXTENSIONS == expected_extensions assert Image.SUPPORTED_EXTENSIONS == expected_extensions
def test_image_properties(self, tmp_path): def test_image_properties(self, tmp_path):