Formatting
This commit is contained in:
@@ -96,9 +96,7 @@ class YOLOWrapper:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Starting training: {name}")
|
logger.info(f"Starting training: {name}")
|
||||||
logger.info(
|
logger.info(f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}")
|
||||||
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255.
|
# Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255.
|
||||||
# Users can override by passing explicit kwargs.
|
# Users can override by passing explicit kwargs.
|
||||||
@@ -149,9 +147,7 @@ class YOLOWrapper:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Starting validation on {split} split")
|
logger.info(f"Starting validation on {split} split")
|
||||||
results = self.model.val(
|
results = self.model.val(data=data_yaml, split=split, device=self.device, **kwargs)
|
||||||
data=data_yaml, split=split, device=self.device, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Validation completed successfully")
|
logger.info("Validation completed successfully")
|
||||||
return self._format_validation_results(results)
|
return self._format_validation_results(results)
|
||||||
@@ -190,11 +186,9 @@ class YOLOWrapper:
|
|||||||
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
||||||
|
|
||||||
prepared_source, cleanup_path = self._prepare_source(source)
|
prepared_source, cleanup_path = self._prepare_source(source)
|
||||||
|
imgsz = 1088
|
||||||
try:
|
try:
|
||||||
logger.info(
|
logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
|
||||||
f"Running inference on {source} -> prepared_source {prepared_source}"
|
|
||||||
)
|
|
||||||
results = self.model.predict(
|
results = self.model.predict(
|
||||||
source=source,
|
source=source,
|
||||||
conf=conf,
|
conf=conf,
|
||||||
@@ -203,6 +197,7 @@ class YOLOWrapper:
|
|||||||
save_txt=save_txt,
|
save_txt=save_txt,
|
||||||
save_conf=save_conf,
|
save_conf=save_conf,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
|
imgsz=imgsz,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -218,13 +213,9 @@ class YOLOWrapper:
|
|||||||
try:
|
try:
|
||||||
os.remove(cleanup_path)
|
os.remove(cleanup_path)
|
||||||
except OSError as cleanup_error:
|
except OSError as cleanup_error:
|
||||||
logger.warning(
|
logger.warning(f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}")
|
||||||
f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def export(
|
def export(self, format: str = "onnx", output_path: Optional[str] = None, **kwargs) -> str:
|
||||||
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
|
|
||||||
) -> str:
|
|
||||||
"""
|
"""
|
||||||
Export model to different format.
|
Export model to different format.
|
||||||
|
|
||||||
@@ -265,9 +256,7 @@ class YOLOWrapper:
|
|||||||
tmp.close()
|
tmp.close()
|
||||||
img_obj.save(tmp_path)
|
img_obj.save(tmp_path)
|
||||||
cleanup_path = tmp_path
|
cleanup_path = tmp_path
|
||||||
logger.info(
|
logger.info(f"Converted image {source_path} to RGB for inference at {tmp_path}")
|
||||||
f"Converted image {source_path} to RGB for inference at {tmp_path}"
|
|
||||||
)
|
|
||||||
return tmp_path, cleanup_path
|
return tmp_path, cleanup_path
|
||||||
except Exception as convert_error:
|
except Exception as convert_error:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -280,9 +269,7 @@ class YOLOWrapper:
|
|||||||
"""Format training results into dictionary."""
|
"""Format training results into dictionary."""
|
||||||
try:
|
try:
|
||||||
# Get the results dict
|
# Get the results dict
|
||||||
results_dict = (
|
results_dict = results.results_dict if hasattr(results, "results_dict") else {}
|
||||||
results.results_dict if hasattr(results, "results_dict") else {}
|
|
||||||
)
|
|
||||||
|
|
||||||
formatted = {
|
formatted = {
|
||||||
"success": True,
|
"success": True,
|
||||||
@@ -315,9 +302,7 @@ class YOLOWrapper:
|
|||||||
"mAP50-95": float(box_metrics.map),
|
"mAP50-95": float(box_metrics.map),
|
||||||
"precision": float(box_metrics.mp),
|
"precision": float(box_metrics.mp),
|
||||||
"recall": float(box_metrics.mr),
|
"recall": float(box_metrics.mr),
|
||||||
"fitness": (
|
"fitness": (float(results.fitness) if hasattr(results, "fitness") else 0.0),
|
||||||
float(results.fitness) if hasattr(results, "fitness") else 0.0
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add per-class metrics if available
|
# Add per-class metrics if available
|
||||||
@@ -327,11 +312,7 @@ class YOLOWrapper:
|
|||||||
if idx < len(box_metrics.ap):
|
if idx < len(box_metrics.ap):
|
||||||
class_metrics[name] = {
|
class_metrics[name] = {
|
||||||
"ap": float(box_metrics.ap[idx]),
|
"ap": float(box_metrics.ap[idx]),
|
||||||
"ap50": (
|
"ap50": (float(box_metrics.ap50[idx]) if hasattr(box_metrics, "ap50") else 0.0),
|
||||||
float(box_metrics.ap50[idx])
|
|
||||||
if hasattr(box_metrics, "ap50")
|
|
||||||
else 0.0
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
formatted["class_metrics"] = class_metrics
|
formatted["class_metrics"] = class_metrics
|
||||||
|
|
||||||
@@ -364,21 +345,15 @@ class YOLOWrapper:
|
|||||||
"class_id": int(boxes.cls[i]),
|
"class_id": int(boxes.cls[i]),
|
||||||
"class_name": result.names[int(boxes.cls[i])],
|
"class_name": result.names[int(boxes.cls[i])],
|
||||||
"confidence": float(boxes.conf[i]),
|
"confidence": float(boxes.conf[i]),
|
||||||
"bbox_normalized": [
|
"bbox_normalized": [float(v) for v in xyxyn], # [x_min, y_min, x_max, y_max]
|
||||||
float(v) for v in xyxyn
|
"bbox_absolute": [float(v) for v in boxes.xyxy[i].cpu().numpy()], # Absolute pixels
|
||||||
], # [x_min, y_min, x_max, y_max]
|
|
||||||
"bbox_absolute": [
|
|
||||||
float(v) for v in boxes.xyxy[i].cpu().numpy()
|
|
||||||
], # Absolute pixels
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Extract segmentation mask if available
|
# Extract segmentation mask if available
|
||||||
if has_masks:
|
if has_masks:
|
||||||
try:
|
try:
|
||||||
# Get the mask for this detection
|
# Get the mask for this detection
|
||||||
mask_data = result.masks.xy[
|
mask_data = result.masks.xy[i] # Polygon coordinates in absolute pixels
|
||||||
i
|
|
||||||
] # Polygon coordinates in absolute pixels
|
|
||||||
|
|
||||||
# Convert to normalized coordinates
|
# Convert to normalized coordinates
|
||||||
if len(mask_data) > 0:
|
if len(mask_data) > 0:
|
||||||
@@ -391,9 +366,7 @@ class YOLOWrapper:
|
|||||||
else:
|
else:
|
||||||
detection["segmentation_mask"] = None
|
detection["segmentation_mask"] = None
|
||||||
except Exception as mask_error:
|
except Exception as mask_error:
|
||||||
logger.warning(
|
logger.warning(f"Error extracting mask for detection {i}: {mask_error}")
|
||||||
f"Error extracting mask for detection {i}: {mask_error}"
|
|
||||||
)
|
|
||||||
detection["segmentation_mask"] = None
|
detection["segmentation_mask"] = None
|
||||||
else:
|
else:
|
||||||
detection["segmentation_mask"] = None
|
detection["segmentation_mask"] = None
|
||||||
@@ -407,9 +380,7 @@ class YOLOWrapper:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_bbox_format(
|
def convert_bbox_format(bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy") -> List[float]:
|
||||||
bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
|
|
||||||
) -> List[float]:
|
|
||||||
"""
|
"""
|
||||||
Convert bounding box between formats.
|
Convert bounding box between formats.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user