diff --git a/src/model/yolo_wrapper.py b/src/model/yolo_wrapper.py index b2c90e5..751080d 100644 --- a/src/model/yolo_wrapper.py +++ b/src/model/yolo_wrapper.py @@ -196,7 +196,7 @@ class YOLOWrapper: f"Running inference on {source} -> prepared_source {prepared_source}" ) results = self.model.predict( - source=prepared_source, + source=source, conf=conf, iou=iou, save=save, diff --git a/src/utils/image.py b/src/utils/image.py index c51590f..f7403cb 100644 --- a/src/utils/image.py +++ b/src/utils/image.py @@ -129,8 +129,13 @@ class Image: # Extract metadata print(self._data.shape) - self._height, self._width = self._data.shape[:2] - self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1 + if len(self._data.shape) == 2: + self._height, self._width = self._data.shape[:2] + self._channels = 1 + else: + self._height, self._width = self._data.shape[1:] + self._channels = self._data.shape[0] + # self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1 self._format = self.path.suffix.lower().lstrip(".") self._size_bytes = self.path.stat().st_size self._dtype = self._data.dtype @@ -317,6 +322,7 @@ class Image: if self.channels == 1: if pseudo_rgb: img = get_pseudo_rgb(self.data) + print("Image.save", img.shape) else: img = np.repeat(self.data, 3, axis=2)