Defining image extensions only in one place
This commit is contained in:
@@ -13,8 +13,9 @@ import hashlib
|
||||
import yaml
|
||||
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.image import Image
|
||||
|
||||
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp")
|
||||
IMAGE_EXTENSIONS = tuple(Image.SUPPORTED_EXTENSIONS)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -168,7 +168,7 @@ class AnnotationTab(QWidget):
|
||||
self,
|
||||
"Select Image",
|
||||
start_dir,
|
||||
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
|
||||
"Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
|
||||
)
|
||||
|
||||
if not file_path:
|
||||
|
||||
@@ -27,6 +27,7 @@ from src.utils.config_manager import ConfigManager
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.file_utils import get_image_files
|
||||
from src.model.inference import InferenceEngine
|
||||
from src.utils.image import Image
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -236,7 +237,7 @@ class DetectionTab(QWidget):
|
||||
self,
|
||||
"Select Image",
|
||||
start_dir,
|
||||
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
|
||||
"Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
|
||||
)
|
||||
|
||||
if not file_path:
|
||||
|
||||
@@ -34,20 +34,13 @@ from PySide6.QtWidgets import (
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.model.yolo_wrapper import YOLOWrapper
|
||||
from src.utils.config_manager import ConfigManager
|
||||
from src.utils.image import convert_grayscale_to_rgb_preserve_range
|
||||
from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
DEFAULT_IMAGE_EXTENSIONS = {
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".png",
|
||||
".tif",
|
||||
".tiff",
|
||||
".bmp",
|
||||
}
|
||||
DEFAULT_IMAGE_EXTENSIONS = set(Image.SUPPORTED_EXTENSIONS)
|
||||
|
||||
|
||||
class TrainingWorker(QThread):
|
||||
|
||||
@@ -7,6 +7,7 @@ import yaml
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.image import Image
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -46,14 +47,7 @@ class ConfigManager:
|
||||
"database": {"path": "data/detections.db"},
|
||||
"image_repository": {
|
||||
"base_path": "",
|
||||
"allowed_extensions": [
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".png",
|
||||
".tif",
|
||||
".tiff",
|
||||
".bmp",
|
||||
],
|
||||
"allowed_extensions": Image.SUPPORTED_EXTENSIONS,
|
||||
},
|
||||
"models": {
|
||||
"default_base_model": "yolov8s-seg.pt",
|
||||
@@ -232,5 +226,5 @@ class ConfigManager:
|
||||
def get_allowed_extensions(self) -> list:
|
||||
"""Get list of allowed image file extensions."""
|
||||
return self.get(
|
||||
"image_repository.allowed_extensions", [".jpg", ".jpeg", ".png"]
|
||||
"image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.image import Image
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -28,7 +29,7 @@ def get_image_files(
|
||||
List of absolute paths to image files
|
||||
"""
|
||||
if allowed_extensions is None:
|
||||
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
|
||||
allowed_extensions = Image.SUPPORTED_EXTENSIONS
|
||||
|
||||
# Normalize extensions to lowercase
|
||||
allowed_extensions = [ext.lower() for ext in allowed_extensions]
|
||||
@@ -204,7 +205,7 @@ def is_image_file(
|
||||
True if file is an image
|
||||
"""
|
||||
if allowed_extensions is None:
|
||||
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
|
||||
allowed_extensions = Image.SUPPORTED_EXTENSIONS
|
||||
|
||||
extension = Path(file_path).suffix.lower()
|
||||
return extension in [ext.lower() for ext in allowed_extensions]
|
||||
|
||||
@@ -27,7 +27,7 @@ class TestImage:
|
||||
|
||||
def test_supported_extensions(self):
|
||||
"""Test that supported extensions are correctly defined."""
|
||||
expected_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
|
||||
expected_extensions = Image.SUPPORTED_EXTENSIONS
|
||||
assert Image.SUPPORTED_EXTENSIONS == expected_extensions
|
||||
|
||||
def test_image_properties(self, tmp_path):
|
||||
|
||||
Reference in New Issue
Block a user