Formatting
This commit is contained in:
@@ -96,9 +96,7 @@ class YOLOWrapper:
|
||||
|
||||
try:
|
||||
logger.info(f"Starting training: {name}")
|
||||
logger.info(
|
||||
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
|
||||
)
|
||||
logger.info(f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}")
|
||||
|
||||
# Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255.
|
||||
# Users can override by passing explicit kwargs.
|
||||
@@ -149,9 +147,7 @@ class YOLOWrapper:
|
||||
|
||||
try:
|
||||
logger.info(f"Starting validation on {split} split")
|
||||
results = self.model.val(
|
||||
data=data_yaml, split=split, device=self.device, **kwargs
|
||||
)
|
||||
results = self.model.val(data=data_yaml, split=split, device=self.device, **kwargs)
|
||||
|
||||
logger.info("Validation completed successfully")
|
||||
return self._format_validation_results(results)
|
||||
@@ -190,11 +186,9 @@ class YOLOWrapper:
|
||||
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
||||
|
||||
prepared_source, cleanup_path = self._prepare_source(source)
|
||||
|
||||
imgsz = 1088
|
||||
try:
|
||||
logger.info(
|
||||
f"Running inference on {source} -> prepared_source {prepared_source}"
|
||||
)
|
||||
logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
|
||||
results = self.model.predict(
|
||||
source=source,
|
||||
conf=conf,
|
||||
@@ -203,6 +197,7 @@ class YOLOWrapper:
|
||||
save_txt=save_txt,
|
||||
save_conf=save_conf,
|
||||
device=self.device,
|
||||
imgsz=imgsz,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -218,13 +213,9 @@ class YOLOWrapper:
|
||||
try:
|
||||
os.remove(cleanup_path)
|
||||
except OSError as cleanup_error:
|
||||
logger.warning(
|
||||
f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
|
||||
)
|
||||
logger.warning(f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}")
|
||||
|
||||
def export(
|
||||
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
|
||||
) -> str:
|
||||
def export(self, format: str = "onnx", output_path: Optional[str] = None, **kwargs) -> str:
|
||||
"""
|
||||
Export model to different format.
|
||||
|
||||
@@ -265,9 +256,7 @@ class YOLOWrapper:
|
||||
tmp.close()
|
||||
img_obj.save(tmp_path)
|
||||
cleanup_path = tmp_path
|
||||
logger.info(
|
||||
f"Converted image {source_path} to RGB for inference at {tmp_path}"
|
||||
)
|
||||
logger.info(f"Converted image {source_path} to RGB for inference at {tmp_path}")
|
||||
return tmp_path, cleanup_path
|
||||
except Exception as convert_error:
|
||||
logger.warning(
|
||||
@@ -280,9 +269,7 @@ class YOLOWrapper:
|
||||
"""Format training results into dictionary."""
|
||||
try:
|
||||
# Get the results dict
|
||||
results_dict = (
|
||||
results.results_dict if hasattr(results, "results_dict") else {}
|
||||
)
|
||||
results_dict = results.results_dict if hasattr(results, "results_dict") else {}
|
||||
|
||||
formatted = {
|
||||
"success": True,
|
||||
@@ -315,9 +302,7 @@ class YOLOWrapper:
|
||||
"mAP50-95": float(box_metrics.map),
|
||||
"precision": float(box_metrics.mp),
|
||||
"recall": float(box_metrics.mr),
|
||||
"fitness": (
|
||||
float(results.fitness) if hasattr(results, "fitness") else 0.0
|
||||
),
|
||||
"fitness": (float(results.fitness) if hasattr(results, "fitness") else 0.0),
|
||||
}
|
||||
|
||||
# Add per-class metrics if available
|
||||
@@ -327,11 +312,7 @@ class YOLOWrapper:
|
||||
if idx < len(box_metrics.ap):
|
||||
class_metrics[name] = {
|
||||
"ap": float(box_metrics.ap[idx]),
|
||||
"ap50": (
|
||||
float(box_metrics.ap50[idx])
|
||||
if hasattr(box_metrics, "ap50")
|
||||
else 0.0
|
||||
),
|
||||
"ap50": (float(box_metrics.ap50[idx]) if hasattr(box_metrics, "ap50") else 0.0),
|
||||
}
|
||||
formatted["class_metrics"] = class_metrics
|
||||
|
||||
@@ -364,21 +345,15 @@ class YOLOWrapper:
|
||||
"class_id": int(boxes.cls[i]),
|
||||
"class_name": result.names[int(boxes.cls[i])],
|
||||
"confidence": float(boxes.conf[i]),
|
||||
"bbox_normalized": [
|
||||
float(v) for v in xyxyn
|
||||
], # [x_min, y_min, x_max, y_max]
|
||||
"bbox_absolute": [
|
||||
float(v) for v in boxes.xyxy[i].cpu().numpy()
|
||||
], # Absolute pixels
|
||||
"bbox_normalized": [float(v) for v in xyxyn], # [x_min, y_min, x_max, y_max]
|
||||
"bbox_absolute": [float(v) for v in boxes.xyxy[i].cpu().numpy()], # Absolute pixels
|
||||
}
|
||||
|
||||
# Extract segmentation mask if available
|
||||
if has_masks:
|
||||
try:
|
||||
# Get the mask for this detection
|
||||
mask_data = result.masks.xy[
|
||||
i
|
||||
] # Polygon coordinates in absolute pixels
|
||||
mask_data = result.masks.xy[i] # Polygon coordinates in absolute pixels
|
||||
|
||||
# Convert to normalized coordinates
|
||||
if len(mask_data) > 0:
|
||||
@@ -391,9 +366,7 @@ class YOLOWrapper:
|
||||
else:
|
||||
detection["segmentation_mask"] = None
|
||||
except Exception as mask_error:
|
||||
logger.warning(
|
||||
f"Error extracting mask for detection {i}: {mask_error}"
|
||||
)
|
||||
logger.warning(f"Error extracting mask for detection {i}: {mask_error}")
|
||||
detection["segmentation_mask"] = None
|
||||
else:
|
||||
detection["segmentation_mask"] = None
|
||||
@@ -407,9 +380,7 @@ class YOLOWrapper:
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def convert_bbox_format(
|
||||
bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
|
||||
) -> List[float]:
|
||||
def convert_bbox_format(bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy") -> List[float]:
|
||||
"""
|
||||
Convert bounding box between formats.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user