Adding result shower
This commit is contained in:
@@ -7,6 +7,9 @@ from ultralytics import YOLO
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Callable, Any
|
||||
import torch
|
||||
from PIL import Image
|
||||
import tempfile
|
||||
import os
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
|
||||
@@ -162,10 +165,12 @@ class YOLOWrapper:
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
|
||||
prepared_source, cleanup_path = self._prepare_source(source)
|
||||
|
||||
try:
|
||||
logger.info(f"Running inference on {source}")
|
||||
results = self.model.predict(
|
||||
source=source,
|
||||
source=prepared_source,
|
||||
conf=conf,
|
||||
iou=iou,
|
||||
save=save,
|
||||
@@ -182,6 +187,14 @@ class YOLOWrapper:
|
||||
except Exception as e:
|
||||
logger.error(f"Error during inference: {e}")
|
||||
raise
|
||||
finally:
|
||||
if cleanup_path:
|
||||
try:
|
||||
os.remove(cleanup_path)
|
||||
except OSError as cleanup_error:
|
||||
logger.warning(
|
||||
f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
|
||||
)
|
||||
|
||||
def export(
|
||||
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
|
||||
@@ -210,6 +223,36 @@ class YOLOWrapper:
|
||||
logger.error(f"Error exporting model: {e}")
|
||||
raise
|
||||
|
||||
def _prepare_source(self, source):
|
||||
"""Convert single-channel images to RGB temporarily for inference."""
|
||||
cleanup_path = None
|
||||
|
||||
if isinstance(source, (str, Path)):
|
||||
source_path = Path(source)
|
||||
if source_path.is_file():
|
||||
try:
|
||||
with Image.open(source_path) as img:
|
||||
if len(img.getbands()) == 1:
|
||||
rgb_img = img.convert("RGB")
|
||||
suffix = source_path.suffix or ".png"
|
||||
tmp = tempfile.NamedTemporaryFile(
|
||||
suffix=suffix, delete=False
|
||||
)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
rgb_img.save(tmp_path)
|
||||
cleanup_path = tmp_path
|
||||
logger.info(
|
||||
f"Converted single-channel image {source_path} to RGB for inference"
|
||||
)
|
||||
return tmp_path, cleanup_path
|
||||
except Exception as convert_error:
|
||||
logger.warning(
|
||||
f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"
|
||||
)
|
||||
|
||||
return source, cleanup_path
|
||||
|
||||
def _format_training_results(self, results) -> Dict[str, Any]:
|
||||
"""Format training results into dictionary."""
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user