Formatting

This commit is contained in:
2026-01-16 10:27:15 +02:00
parent 69cde09e53
commit 89e47591db

View File

@@ -96,9 +96,7 @@ class YOLOWrapper:
try: try:
logger.info(f"Starting training: {name}") logger.info(f"Starting training: {name}")
logger.info( logger.info(f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}")
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
)
# Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255. # Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255.
# Users can override by passing explicit kwargs. # Users can override by passing explicit kwargs.
@@ -149,9 +147,7 @@ class YOLOWrapper:
try: try:
logger.info(f"Starting validation on {split} split") logger.info(f"Starting validation on {split} split")
results = self.model.val( results = self.model.val(data=data_yaml, split=split, device=self.device, **kwargs)
data=data_yaml, split=split, device=self.device, **kwargs
)
logger.info("Validation completed successfully") logger.info("Validation completed successfully")
return self._format_validation_results(results) return self._format_validation_results(results)
@@ -190,11 +186,9 @@ class YOLOWrapper:
raise RuntimeError(f"Failed to load model from {self.model_path}") raise RuntimeError(f"Failed to load model from {self.model_path}")
prepared_source, cleanup_path = self._prepare_source(source) prepared_source, cleanup_path = self._prepare_source(source)
imgsz = 1088
try: try:
logger.info( logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
f"Running inference on {source} -> prepared_source {prepared_source}"
)
results = self.model.predict( results = self.model.predict(
source=source, source=source,
conf=conf, conf=conf,
@@ -203,6 +197,7 @@ class YOLOWrapper:
save_txt=save_txt, save_txt=save_txt,
save_conf=save_conf, save_conf=save_conf,
device=self.device, device=self.device,
imgsz=imgsz,
**kwargs, **kwargs,
) )
@@ -218,13 +213,9 @@ class YOLOWrapper:
try: try:
os.remove(cleanup_path) os.remove(cleanup_path)
except OSError as cleanup_error: except OSError as cleanup_error:
logger.warning( logger.warning(f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}")
f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
)
def export( def export(self, format: str = "onnx", output_path: Optional[str] = None, **kwargs) -> str:
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
) -> str:
""" """
Export model to different format. Export model to different format.
@@ -265,9 +256,7 @@ class YOLOWrapper:
tmp.close() tmp.close()
img_obj.save(tmp_path) img_obj.save(tmp_path)
cleanup_path = tmp_path cleanup_path = tmp_path
logger.info( logger.info(f"Converted image {source_path} to RGB for inference at {tmp_path}")
f"Converted image {source_path} to RGB for inference at {tmp_path}"
)
return tmp_path, cleanup_path return tmp_path, cleanup_path
except Exception as convert_error: except Exception as convert_error:
logger.warning( logger.warning(
@@ -280,9 +269,7 @@ class YOLOWrapper:
"""Format training results into dictionary.""" """Format training results into dictionary."""
try: try:
# Get the results dict # Get the results dict
results_dict = ( results_dict = results.results_dict if hasattr(results, "results_dict") else {}
results.results_dict if hasattr(results, "results_dict") else {}
)
formatted = { formatted = {
"success": True, "success": True,
@@ -315,9 +302,7 @@ class YOLOWrapper:
"mAP50-95": float(box_metrics.map), "mAP50-95": float(box_metrics.map),
"precision": float(box_metrics.mp), "precision": float(box_metrics.mp),
"recall": float(box_metrics.mr), "recall": float(box_metrics.mr),
"fitness": ( "fitness": (float(results.fitness) if hasattr(results, "fitness") else 0.0),
float(results.fitness) if hasattr(results, "fitness") else 0.0
),
} }
# Add per-class metrics if available # Add per-class metrics if available
@@ -327,11 +312,7 @@ class YOLOWrapper:
if idx < len(box_metrics.ap): if idx < len(box_metrics.ap):
class_metrics[name] = { class_metrics[name] = {
"ap": float(box_metrics.ap[idx]), "ap": float(box_metrics.ap[idx]),
"ap50": ( "ap50": (float(box_metrics.ap50[idx]) if hasattr(box_metrics, "ap50") else 0.0),
float(box_metrics.ap50[idx])
if hasattr(box_metrics, "ap50")
else 0.0
),
} }
formatted["class_metrics"] = class_metrics formatted["class_metrics"] = class_metrics
@@ -364,21 +345,15 @@ class YOLOWrapper:
"class_id": int(boxes.cls[i]), "class_id": int(boxes.cls[i]),
"class_name": result.names[int(boxes.cls[i])], "class_name": result.names[int(boxes.cls[i])],
"confidence": float(boxes.conf[i]), "confidence": float(boxes.conf[i]),
"bbox_normalized": [ "bbox_normalized": [float(v) for v in xyxyn], # [x_min, y_min, x_max, y_max]
float(v) for v in xyxyn "bbox_absolute": [float(v) for v in boxes.xyxy[i].cpu().numpy()], # Absolute pixels
], # [x_min, y_min, x_max, y_max]
"bbox_absolute": [
float(v) for v in boxes.xyxy[i].cpu().numpy()
], # Absolute pixels
} }
# Extract segmentation mask if available # Extract segmentation mask if available
if has_masks: if has_masks:
try: try:
# Get the mask for this detection # Get the mask for this detection
mask_data = result.masks.xy[ mask_data = result.masks.xy[i] # Polygon coordinates in absolute pixels
i
] # Polygon coordinates in absolute pixels
# Convert to normalized coordinates # Convert to normalized coordinates
if len(mask_data) > 0: if len(mask_data) > 0:
@@ -391,9 +366,7 @@ class YOLOWrapper:
else: else:
detection["segmentation_mask"] = None detection["segmentation_mask"] = None
except Exception as mask_error: except Exception as mask_error:
logger.warning( logger.warning(f"Error extracting mask for detection {i}: {mask_error}")
f"Error extracting mask for detection {i}: {mask_error}"
)
detection["segmentation_mask"] = None detection["segmentation_mask"] = None
else: else:
detection["segmentation_mask"] = None detection["segmentation_mask"] = None
@@ -407,9 +380,7 @@ class YOLOWrapper:
return [] return []
@staticmethod @staticmethod
def convert_bbox_format( def convert_bbox_format(bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy") -> List[float]:
bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
) -> List[float]:
""" """
Convert bounding box between formats. Convert bounding box between formats.