Making image manipulations thru one class
This commit is contained in:
@@ -10,7 +10,6 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from PIL import Image as PILImage
|
|
||||||
from PySide6.QtCore import Qt, QThread, Signal
|
from PySide6.QtCore import Qt, QThread, Signal
|
||||||
from PySide6.QtWidgets import (
|
from PySide6.QtWidgets import (
|
||||||
QWidget,
|
QWidget,
|
||||||
@@ -1293,8 +1292,8 @@ class TrainingTab(QWidget):
|
|||||||
if not sample_image:
|
if not sample_image:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
with PILImage.open(sample_image) as img:
|
img = Image(sample_image)
|
||||||
return img.mode.upper() != "RGB"
|
return img.pil_image.mode.upper() != "RGB"
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(f"Failed to inspect image {sample_image}: {exc}")
|
logger.warning(f"Failed to inspect image {sample_image}: {exc}")
|
||||||
return False
|
return False
|
||||||
@@ -1354,12 +1353,13 @@ class TrainingTab(QWidget):
|
|||||||
dst = dst_dir / relative
|
dst = dst_dir / relative
|
||||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||||
try:
|
try:
|
||||||
with PILImage.open(src) as img:
|
img_obj = Image(src)
|
||||||
if len(img.getbands()) == 1:
|
pil_img = img_obj.pil_image
|
||||||
rgb_img = convert_grayscale_to_rgb_preserve_range(img)
|
if len(pil_img.getbands()) == 1:
|
||||||
else:
|
rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img)
|
||||||
rgb_img = img.convert("RGB")
|
else:
|
||||||
rgb_img.save(dst)
|
rgb_img = pil_img.convert("RGB")
|
||||||
|
rgb_img.save(dst)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(f"Failed to convert {src} to RGB: {exc}")
|
logger.warning(f"Failed to convert {src} to RGB: {exc}")
|
||||||
|
|
||||||
|
|||||||
@@ -5,12 +5,12 @@ Handles detection inference and result storage.
|
|||||||
|
|
||||||
from typing import List, Dict, Optional, Callable
|
from typing import List, Dict, Optional, Callable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from PIL import Image
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from src.model.yolo_wrapper import YOLOWrapper
|
from src.model.yolo_wrapper import YOLOWrapper
|
||||||
from src.database.db_manager import DatabaseManager
|
from src.database.db_manager import DatabaseManager
|
||||||
|
from src.utils.image import Image
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
from src.utils.file_utils import get_relative_path
|
from src.utils.file_utils import get_relative_path
|
||||||
|
|
||||||
@@ -64,9 +64,9 @@ class InferenceEngine:
|
|||||||
stored_relative_path = str(Path(image_path).resolve())
|
stored_relative_path = str(Path(image_path).resolve())
|
||||||
|
|
||||||
# Get image dimensions
|
# Get image dimensions
|
||||||
img = Image.open(image_path)
|
img = Image(image_path)
|
||||||
width, height = img.size
|
width = img.width
|
||||||
img.close()
|
height = img.height
|
||||||
|
|
||||||
# Perform detection
|
# Perform detection
|
||||||
detections = self.yolo.predict(image_path, conf=conf)
|
detections = self.yolo.predict(image_path, conf=conf)
|
||||||
|
|||||||
@@ -7,10 +7,9 @@ from ultralytics import YOLO
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, List, Dict, Callable, Any
|
from typing import Optional, List, Dict, Callable, Any
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import os
|
import os
|
||||||
from src.utils.image import convert_grayscale_to_rgb_preserve_range
|
from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
|
|
||||||
|
|
||||||
@@ -232,22 +231,23 @@ class YOLOWrapper:
|
|||||||
source_path = Path(source)
|
source_path = Path(source)
|
||||||
if source_path.is_file():
|
if source_path.is_file():
|
||||||
try:
|
try:
|
||||||
with Image.open(source_path) as img:
|
img_obj = Image(source_path)
|
||||||
if len(img.getbands()) == 1:
|
pil_img = img_obj.pil_image
|
||||||
rgb_img = convert_grayscale_to_rgb_preserve_range(img)
|
if len(pil_img.getbands()) == 1:
|
||||||
else:
|
rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img)
|
||||||
rgb_img = img.convert("RGB")
|
else:
|
||||||
|
rgb_img = pil_img.convert("RGB")
|
||||||
|
|
||||||
suffix = source_path.suffix or ".png"
|
suffix = source_path.suffix or ".png"
|
||||||
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
|
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
|
||||||
tmp_path = tmp.name
|
tmp_path = tmp.name
|
||||||
tmp.close()
|
tmp.close()
|
||||||
rgb_img.save(tmp_path)
|
rgb_img.save(tmp_path)
|
||||||
cleanup_path = tmp_path
|
cleanup_path = tmp_path
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Converted image {source_path} to RGB for inference at {tmp_path}"
|
f"Converted image {source_path} to RGB for inference at {tmp_path}"
|
||||||
)
|
)
|
||||||
return tmp_path, cleanup_path
|
return tmp_path, cleanup_path
|
||||||
except Exception as convert_error:
|
except Exception as convert_error:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"
|
f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"
|
||||||
|
|||||||
Reference in New Issue
Block a user