Adding python files
This commit is contained in:
364
src/model/yolo_wrapper.py
Normal file
364
src/model/yolo_wrapper.py
Normal file
@@ -0,0 +1,364 @@
|
||||
"""
|
||||
YOLO model wrapper for the microscopy object detection application.
|
||||
Provides a clean interface to YOLOv8 for training, validation, and inference.
|
||||
"""
|
||||
|
||||
from ultralytics import YOLO
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Callable, Any
|
||||
import torch
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class YOLOWrapper:
|
||||
"""Wrapper for YOLOv8 model operations."""
|
||||
|
||||
def __init__(self, model_path: str = "yolov8s.pt"):
|
||||
"""
|
||||
Initialize YOLO model.
|
||||
|
||||
Args:
|
||||
model_path: Path to model weights (.pt file)
|
||||
"""
|
||||
self.model_path = model_path
|
||||
self.model = None
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"YOLOWrapper initialized with device: {self.device}")
|
||||
|
||||
def load_model(self) -> bool:
|
||||
"""
|
||||
Load YOLO model from path.
|
||||
|
||||
Returns:
|
||||
True if loaded successfully, False otherwise
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Loading YOLO model from {self.model_path}")
|
||||
self.model = YOLO(self.model_path)
|
||||
self.model.to(self.device)
|
||||
logger.info("Model loaded successfully")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading model: {e}")
|
||||
return False
|
||||
|
||||
def train(
|
||||
self,
|
||||
data_yaml: str,
|
||||
epochs: int = 100,
|
||||
imgsz: int = 640,
|
||||
batch: int = 16,
|
||||
patience: int = 50,
|
||||
save_dir: str = "data/models",
|
||||
name: str = "custom_model",
|
||||
resume: bool = False,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Train the YOLO model.
|
||||
|
||||
Args:
|
||||
data_yaml: Path to data.yaml configuration file
|
||||
epochs: Number of training epochs
|
||||
imgsz: Input image size
|
||||
batch: Batch size
|
||||
patience: Early stopping patience
|
||||
save_dir: Directory to save trained model
|
||||
name: Name for the training run
|
||||
resume: Resume training from last checkpoint
|
||||
**kwargs: Additional training arguments
|
||||
|
||||
Returns:
|
||||
Dictionary with training results
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
|
||||
try:
|
||||
logger.info(f"Starting training: {name}")
|
||||
logger.info(
|
||||
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
|
||||
)
|
||||
|
||||
# Train the model
|
||||
results = self.model.train(
|
||||
data=data_yaml,
|
||||
epochs=epochs,
|
||||
imgsz=imgsz,
|
||||
batch=batch,
|
||||
patience=patience,
|
||||
project=save_dir,
|
||||
name=name,
|
||||
device=self.device,
|
||||
resume=resume,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
logger.info("Training completed successfully")
|
||||
return self._format_training_results(results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during training: {e}")
|
||||
raise
|
||||
|
||||
def validate(self, data_yaml: str, split: str = "val", **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate the model.
|
||||
|
||||
Args:
|
||||
data_yaml: Path to data.yaml configuration file
|
||||
split: Dataset split to validate on ('val' or 'test')
|
||||
**kwargs: Additional validation arguments
|
||||
|
||||
Returns:
|
||||
Dictionary with validation metrics
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
|
||||
try:
|
||||
logger.info(f"Starting validation on {split} split")
|
||||
results = self.model.val(
|
||||
data=data_yaml, split=split, device=self.device, **kwargs
|
||||
)
|
||||
|
||||
logger.info("Validation completed successfully")
|
||||
return self._format_validation_results(results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during validation: {e}")
|
||||
raise
|
||||
|
||||
def predict(
|
||||
self,
|
||||
source: str,
|
||||
conf: float = 0.25,
|
||||
iou: float = 0.45,
|
||||
save: bool = False,
|
||||
save_txt: bool = False,
|
||||
save_conf: bool = False,
|
||||
**kwargs,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Perform inference on image(s).
|
||||
|
||||
Args:
|
||||
source: Path to image or directory
|
||||
conf: Confidence threshold
|
||||
iou: IoU threshold for NMS
|
||||
save: Whether to save annotated images
|
||||
save_txt: Whether to save labels to .txt files
|
||||
save_conf: Whether to save confidence in labels
|
||||
**kwargs: Additional prediction arguments
|
||||
|
||||
Returns:
|
||||
List of detection dictionaries
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
|
||||
try:
|
||||
logger.info(f"Running inference on {source}")
|
||||
results = self.model.predict(
|
||||
source=source,
|
||||
conf=conf,
|
||||
iou=iou,
|
||||
save=save,
|
||||
save_txt=save_txt,
|
||||
save_conf=save_conf,
|
||||
device=self.device,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
detections = self._format_prediction_results(results)
|
||||
logger.info(f"Inference complete: {len(detections)} detections")
|
||||
return detections
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during inference: {e}")
|
||||
raise
|
||||
|
||||
def export(
|
||||
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
|
||||
) -> str:
|
||||
"""
|
||||
Export model to different format.
|
||||
|
||||
Args:
|
||||
format: Export format (onnx, torchscript, tflite, etc.)
|
||||
output_path: Path for exported model
|
||||
**kwargs: Additional export arguments
|
||||
|
||||
Returns:
|
||||
Path to exported model
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
|
||||
try:
|
||||
logger.info(f"Exporting model to {format} format")
|
||||
export_path = self.model.export(format=format, **kwargs)
|
||||
logger.info(f"Model exported to {export_path}")
|
||||
return str(export_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error exporting model: {e}")
|
||||
raise
|
||||
|
||||
def _format_training_results(self, results) -> Dict[str, Any]:
|
||||
"""Format training results into dictionary."""
|
||||
try:
|
||||
# Get the results dict
|
||||
results_dict = (
|
||||
results.results_dict if hasattr(results, "results_dict") else {}
|
||||
)
|
||||
|
||||
formatted = {
|
||||
"success": True,
|
||||
"final_epoch": getattr(results, "epoch", 0),
|
||||
"metrics": {
|
||||
"mAP50": float(results_dict.get("metrics/mAP50(B)", 0)),
|
||||
"mAP50-95": float(results_dict.get("metrics/mAP50-95(B)", 0)),
|
||||
"precision": float(results_dict.get("metrics/precision(B)", 0)),
|
||||
"recall": float(results_dict.get("metrics/recall(B)", 0)),
|
||||
},
|
||||
"best_model_path": str(Path(results.save_dir) / "weights" / "best.pt"),
|
||||
"last_model_path": str(Path(results.save_dir) / "weights" / "last.pt"),
|
||||
"save_dir": str(results.save_dir),
|
||||
}
|
||||
|
||||
return formatted
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting training results: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
def _format_validation_results(self, results) -> Dict[str, Any]:
|
||||
"""Format validation results into dictionary."""
|
||||
try:
|
||||
box_metrics = results.box
|
||||
|
||||
formatted = {
|
||||
"success": True,
|
||||
"mAP50": float(box_metrics.map50),
|
||||
"mAP50-95": float(box_metrics.map),
|
||||
"precision": float(box_metrics.mp),
|
||||
"recall": float(box_metrics.mr),
|
||||
"fitness": (
|
||||
float(results.fitness) if hasattr(results, "fitness") else 0.0
|
||||
),
|
||||
}
|
||||
|
||||
# Add per-class metrics if available
|
||||
if hasattr(box_metrics, "ap") and hasattr(results, "names"):
|
||||
class_metrics = {}
|
||||
for idx, name in results.names.items():
|
||||
if idx < len(box_metrics.ap):
|
||||
class_metrics[name] = {
|
||||
"ap": float(box_metrics.ap[idx]),
|
||||
"ap50": (
|
||||
float(box_metrics.ap50[idx])
|
||||
if hasattr(box_metrics, "ap50")
|
||||
else 0.0
|
||||
),
|
||||
}
|
||||
formatted["class_metrics"] = class_metrics
|
||||
|
||||
return formatted
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting validation results: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
def _format_prediction_results(self, results) -> List[Dict]:
|
||||
"""Format prediction results into list of dictionaries."""
|
||||
detections = []
|
||||
|
||||
try:
|
||||
for result in results:
|
||||
boxes = result.boxes
|
||||
image_path = str(result.path)
|
||||
orig_shape = result.orig_shape # (height, width)
|
||||
|
||||
for i in range(len(boxes)):
|
||||
# Get normalized coordinates
|
||||
xyxyn = boxes.xyxyn[i].cpu().numpy() # Normalized [x1, y1, x2, y2]
|
||||
|
||||
detection = {
|
||||
"image_path": image_path,
|
||||
"class_id": int(boxes.cls[i]),
|
||||
"class_name": result.names[int(boxes.cls[i])],
|
||||
"confidence": float(boxes.conf[i]),
|
||||
"bbox_normalized": [
|
||||
float(v) for v in xyxyn
|
||||
], # [x_min, y_min, x_max, y_max]
|
||||
"bbox_absolute": [
|
||||
float(v) for v in boxes.xyxy[i].cpu().numpy()
|
||||
], # Absolute pixels
|
||||
}
|
||||
detections.append(detection)
|
||||
|
||||
return detections
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting prediction results: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def convert_bbox_format(
|
||||
bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
|
||||
) -> List[float]:
|
||||
"""
|
||||
Convert bounding box between formats.
|
||||
|
||||
Formats:
|
||||
- xywh: [x_center, y_center, width, height]
|
||||
- xyxy: [x_min, y_min, x_max, y_max]
|
||||
|
||||
Args:
|
||||
bbox: Bounding box coordinates
|
||||
format_from: Source format
|
||||
format_to: Target format
|
||||
|
||||
Returns:
|
||||
Converted bounding box
|
||||
"""
|
||||
if format_from == "xywh" and format_to == "xyxy":
|
||||
x, y, w, h = bbox
|
||||
return [x - w / 2, y - h / 2, x + w / 2, y + h / 2]
|
||||
elif format_from == "xyxy" and format_to == "xywh":
|
||||
x1, y1, x2, y2 = bbox
|
||||
return [(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1]
|
||||
else:
|
||||
return bbox
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about the loaded model.
|
||||
|
||||
Returns:
|
||||
Dictionary with model information
|
||||
"""
|
||||
if self.model is None:
|
||||
return {"error": "Model not loaded"}
|
||||
|
||||
try:
|
||||
info = {
|
||||
"model_path": self.model_path,
|
||||
"device": self.device,
|
||||
"task": getattr(self.model, "task", "unknown"),
|
||||
}
|
||||
|
||||
# Try to get class names
|
||||
if hasattr(self.model, "names"):
|
||||
info["classes"] = self.model.names
|
||||
info["num_classes"] = len(self.model.names)
|
||||
|
||||
return info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting model info: {e}")
|
||||
return {"error": str(e)}
|
||||
Reference in New Issue
Block a user