diff --git a/src/model/yolo_wrapper.py b/src/model/yolo_wrapper.py
index 751080d..a15a7bd 100644
--- a/src/model/yolo_wrapper.py
+++ b/src/model/yolo_wrapper.py
@@ -96,9 +96,7 @@ class YOLOWrapper:
 
         try:
             logger.info(f"Starting training: {name}")
-            logger.info(
-                f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
-            )
+            logger.info(f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}")
 
             # Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255.
             # Users can override by passing explicit kwargs.
@@ -149,9 +147,7 @@ class YOLOWrapper:
 
         try:
             logger.info(f"Starting validation on {split} split")
-            results = self.model.val(
-                data=data_yaml, split=split, device=self.device, **kwargs
-            )
+            results = self.model.val(data=data_yaml, split=split, device=self.device, **kwargs)
 
             logger.info("Validation completed successfully")
             return self._format_validation_results(results)
@@ -190,11 +186,9 @@ class YOLOWrapper:
             raise RuntimeError(f"Failed to load model from {self.model_path}")
 
         prepared_source, cleanup_path = self._prepare_source(source)
-
+        imgsz = 1088
         try:
-            logger.info(
-                f"Running inference on {source} -> prepared_source {prepared_source}"
-            )
+            logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
             results = self.model.predict(
                 source=source,
                 conf=conf,
@@ -203,6 +197,7 @@ class YOLOWrapper:
                 save_txt=save_txt,
                 save_conf=save_conf,
                 device=self.device,
+                imgsz=imgsz,
                 **kwargs,
             )
 
@@ -218,13 +213,9 @@ class YOLOWrapper:
                 try:
                     os.remove(cleanup_path)
                 except OSError as cleanup_error:
-                    logger.warning(
-                        f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
-                    )
+                    logger.warning(f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}")
 
-    def export(
-        self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
-    ) -> str:
+    def export(self, format: str = "onnx", output_path: Optional[str] = None, **kwargs) -> str:
         """
         Export model to different format.
 
@@ -265,9 +256,7 @@ class YOLOWrapper:
                 tmp.close()
                 img_obj.save(tmp_path)
                 cleanup_path = tmp_path
-                logger.info(
-                    f"Converted image {source_path} to RGB for inference at {tmp_path}"
-                )
+                logger.info(f"Converted image {source_path} to RGB for inference at {tmp_path}")
                 return tmp_path, cleanup_path
         except Exception as convert_error:
             logger.warning(
@@ -280,9 +269,7 @@ class YOLOWrapper:
         """Format training results into dictionary."""
         try:
             # Get the results dict
-            results_dict = (
-                results.results_dict if hasattr(results, "results_dict") else {}
-            )
+            results_dict = results.results_dict if hasattr(results, "results_dict") else {}
 
             formatted = {
                 "success": True,
@@ -315,9 +302,7 @@ class YOLOWrapper:
                 "mAP50-95": float(box_metrics.map),
                 "precision": float(box_metrics.mp),
                 "recall": float(box_metrics.mr),
-                "fitness": (
-                    float(results.fitness) if hasattr(results, "fitness") else 0.0
-                ),
+                "fitness": (float(results.fitness) if hasattr(results, "fitness") else 0.0),
             }
 
             # Add per-class metrics if available
@@ -327,11 +312,7 @@ class YOLOWrapper:
                 if idx < len(box_metrics.ap):
                     class_metrics[name] = {
                         "ap": float(box_metrics.ap[idx]),
-                        "ap50": (
-                            float(box_metrics.ap50[idx])
-                            if hasattr(box_metrics, "ap50")
-                            else 0.0
-                        ),
+                        "ap50": (float(box_metrics.ap50[idx]) if hasattr(box_metrics, "ap50") else 0.0),
                     }
 
             formatted["class_metrics"] = class_metrics
@@ -364,21 +345,15 @@ class YOLOWrapper:
                     "class_id": int(boxes.cls[i]),
                     "class_name": result.names[int(boxes.cls[i])],
                     "confidence": float(boxes.conf[i]),
-                    "bbox_normalized": [
-                        float(v) for v in xyxyn
-                    ],  # [x_min, y_min, x_max, y_max]
-                    "bbox_absolute": [
-                        float(v) for v in boxes.xyxy[i].cpu().numpy()
-                    ],  # Absolute pixels
+                    "bbox_normalized": [float(v) for v in xyxyn],  # [x_min, y_min, x_max, y_max]
+                    "bbox_absolute": [float(v) for v in boxes.xyxy[i].cpu().numpy()],  # Absolute pixels
                 }
 
                 # Extract segmentation mask if available
                 if has_masks:
                     try:
                         # Get the mask for this detection
-                        mask_data = result.masks.xy[
-                            i
-                        ]  # Polygon coordinates in absolute pixels
+                        mask_data = result.masks.xy[i]  # Polygon coordinates in absolute pixels
 
                         # Convert to normalized coordinates
                         if len(mask_data) > 0:
@@ -391,9 +366,7 @@ class YOLOWrapper:
                         else:
                             detection["segmentation_mask"] = None
                     except Exception as mask_error:
-                        logger.warning(
-                            f"Error extracting mask for detection {i}: {mask_error}"
-                        )
+                        logger.warning(f"Error extracting mask for detection {i}: {mask_error}")
                         detection["segmentation_mask"] = None
                 else:
                     detection["segmentation_mask"] = None
@@ -407,9 +380,7 @@ class YOLOWrapper:
             return []
 
     @staticmethod
-    def convert_bbox_format(
-        bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
-    ) -> List[float]:
+    def convert_bbox_format(bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy") -> List[float]:
         """
         Convert bounding box between formats.
 
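The functional change in this diff (as opposed to the line-unwrapping) is the fixed inference size passed to model.predict(). Below is a minimal usage sketch of that path, assuming YOLOWrapper is constructed with a weights path and that predict() returns the detection dictionaries built by _format_detection_results(); the import path, weights path, image file, and confidence threshold are illustrative placeholders, not values taken from this repository.

# Sketch only; names below are assumptions, not part of this diff.
from src.model.yolo_wrapper import YOLOWrapper  # assumes "src" is importable as a package

wrapper = YOLOWrapper(model_path="runs/train/example/weights/best.pt")  # assumed constructor argument

# After this change, predict() always forwards imgsz=1088 to the underlying
# Ultralytics model (an explicit imgsz passed via **kwargs would now conflict
# with that hardcoded keyword argument).
detections = wrapper.predict(source="images/sample.tif", conf=0.25, save=False)

# Assuming the formatted detection dicts shown in the hunks above are returned.
for det in detections:
    print(det["class_name"], det["confidence"], det["bbox_normalized"])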