"""
|
|
Database manager for the microscopy object detection application.
|
|
Handles all database operations including CRUD operations, queries, and exports.
|
|
"""
|
|
|
|
import sqlite3
|
|
import json
|
|
from datetime import datetime
|
|
from typing import List, Dict, Optional, Tuple, Any, Union
|
|
from pathlib import Path
|
|
import csv
|
|
import hashlib
|
|
import yaml
|
|
|
|
from src.utils.logger import get_logger
|
|
|
|
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp")
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|


class DatabaseManager:
    """Manages all database operations for the application."""

    def __init__(self, db_path: str = "data/detections.db"):
        """
        Initialize database manager.

        Args:
            db_path: Path to SQLite database file
        """
        self.db_path = db_path
        self._ensure_database_exists()

    def _ensure_database_exists(self) -> None:
        """Create database and tables if they don't exist."""
        # Create directory if it doesn't exist
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

        conn = self.get_connection()
        try:
            # Check if annotations table needs migration
            self._migrate_annotations_table(conn)

            # Read schema file and execute
            schema_path = Path(__file__).parent / "schema.sql"
            with open(schema_path, "r") as f:
                schema_sql = f.read()

            conn.executescript(schema_sql)
            conn.commit()
        finally:
            conn.close()

    def _migrate_annotations_table(self, conn: sqlite3.Connection) -> None:
        """
        Migrate the annotations table from the old schema (class_name) to the new schema (class_id).
        """
        cursor = conn.cursor()

        # Check if annotations table exists
        cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'"
        )
        if not cursor.fetchone():
            # Table doesn't exist yet, no migration needed
            return

        # Check if table has old schema (class_name column)
        cursor.execute("PRAGMA table_info(annotations)")
        columns = {row[1]: row for row in cursor.fetchall()}

        if "class_name" in columns and "class_id" not in columns:
            # Old schema detected, need to migrate
            logger.info("Migrating annotations table to new schema with class_id...")

            # Drop the old annotations table (assumes no critical data, since annotations are a new feature)
            cursor.execute("DROP TABLE IF EXISTS annotations")
            conn.commit()
            logger.info("Old annotations table dropped; it will be recreated with the new schema")

    def get_connection(self) -> sqlite3.Connection:
        """Get database connection with proper settings."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row  # Enable column access by name
        conn.execute("PRAGMA foreign_keys = ON")  # Enable foreign keys
        return conn

    # ==================== Model Operations ====================

    def add_model(
        self,
        model_name: str,
        model_version: str,
        model_path: str,
        base_model: str = "yolov8s-seg.pt",
        training_params: Optional[Dict] = None,
        metrics: Optional[Dict] = None,
    ) -> int:
        """
        Add a new model to the database.

        Args:
            model_name: Name of the model
            model_version: Version string
            model_path: Path to model weights file
            base_model: Base model used for training
            training_params: Dictionary of training parameters
            metrics: Dictionary of validation metrics

        Returns:
            ID of the inserted model
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO models (model_name, model_version, model_path, base_model, training_params, metrics)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    model_name,
                    model_version,
                    model_path,
                    base_model,
                    json.dumps(training_params) if training_params else None,
                    json.dumps(metrics) if metrics else None,
                ),
            )
            conn.commit()
            return cursor.lastrowid
        finally:
            conn.close()
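
    # Illustrative sketch (not part of the class): registering a newly trained model.
    # The training_params and metrics keys shown here (epochs, imgsz, mAP50) are
    # examples only; both dictionaries are stored as opaque JSON, so any
    # JSON-serializable keys can be used.
    #
    #   db = DatabaseManager("data/detections.db")
    #   model_id = db.add_model(
    #       model_name="organoid-detector",
    #       model_version="1.0.0",
    #       model_path="models/organoid_v1.pt",
    #       training_params={"epochs": 100, "imgsz": 640},
    #       metrics={"mAP50": 0.91},
    #   )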

    def get_models(self, filters: Optional[Dict] = None) -> List[Dict]:
        """
        Retrieve models from database.

        Args:
            filters: Optional filters (e.g., {'model_name': 'my_model'})

        Returns:
            List of model dictionaries
        """
        conn = self.get_connection()
        try:
            query = "SELECT * FROM models"
            params = []

            if filters:
                conditions = []
                for key, value in filters.items():
                    conditions.append(f"{key} = ?")
                    params.append(value)
                query += " WHERE " + " AND ".join(conditions)

            query += " ORDER BY created_at DESC"

            cursor = conn.cursor()
            cursor.execute(query, params)

            models = []
            for row in cursor.fetchall():
                model = dict(row)
                # Parse JSON fields
                if model["training_params"]:
                    model["training_params"] = json.loads(model["training_params"])
                if model["metrics"]:
                    model["metrics"] = json.loads(model["metrics"])
                models.append(model)

            return models
        finally:
            conn.close()

    def get_model_by_id(self, model_id: int) -> Optional[Dict]:
        """Get model by ID."""
        models = self.get_models({"id": model_id})
        return models[0] if models else None

    def update_model(self, model_id: int, updates: Dict) -> bool:
        """Update model fields."""
        if not updates:
            # Nothing to update; avoid building a malformed UPDATE statement
            return False

        conn = self.get_connection()
        try:
            # Build update query
            set_clauses = []
            params = []
            for key, value in updates.items():
                if key in ["training_params", "metrics"] and isinstance(value, dict):
                    value = json.dumps(value)
                set_clauses.append(f"{key} = ?")
                params.append(value)

            params.append(model_id)

            query = f"UPDATE models SET {', '.join(set_clauses)} WHERE id = ?"
            cursor = conn.cursor()
            cursor.execute(query, params)
            conn.commit()
            return cursor.rowcount > 0
        finally:
            conn.close()

    # ==================== Image Operations ====================

    def add_image(
        self,
        relative_path: str,
        filename: str,
        width: int,
        height: int,
        captured_at: Optional[datetime] = None,
        checksum: Optional[str] = None,
    ) -> Optional[int]:
        """
        Add a new image to the database.

        Args:
            relative_path: Path relative to image repository
            filename: Image filename
            width: Image width in pixels
            height: Image height in pixels
            captured_at: When image was captured (if known)
            checksum: MD5 checksum of image file

        Returns:
            ID of the inserted image, or the ID of the existing image if the
            relative_path is already present (None if neither can be determined)
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO images (relative_path, filename, width, height, captured_at, checksum)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (relative_path, filename, width, height, captured_at, checksum),
            )
            conn.commit()
            return cursor.lastrowid
        except sqlite3.IntegrityError:
            # Image already exists, return its ID
            cursor.execute(
                "SELECT id FROM images WHERE relative_path = ?", (relative_path,)
            )
            row = cursor.fetchone()
            return row["id"] if row else None
        finally:
            conn.close()

    def get_image_by_path(self, relative_path: str) -> Optional[Dict]:
        """Get image by relative path."""
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT * FROM images WHERE relative_path = ?", (relative_path,)
            )
            row = cursor.fetchone()
            return dict(row) if row else None
        finally:
            conn.close()

    def get_or_create_image(
        self, relative_path: str, filename: str, width: int, height: int
    ) -> int:
        """Get existing image or create new one."""
        existing = self.get_image_by_path(relative_path)
        if existing:
            return existing["id"]
        return self.add_image(relative_path, filename, width, height)

    # ==================== Detection Operations ====================

    def add_detection(
        self,
        image_id: int,
        model_id: int,
        class_name: str,
        bbox: Tuple[float, float, float, float],  # (x_min, y_min, x_max, y_max)
        confidence: float,
        segmentation_mask: Optional[List[List[float]]] = None,
        metadata: Optional[Dict] = None,
    ) -> int:
        """
        Add a new detection to the database.

        Args:
            image_id: ID of the image
            model_id: ID of the model used
            class_name: Detected object class
            bbox: Bounding box coordinates (normalized 0-1)
            confidence: Detection confidence score
            segmentation_mask: Polygon coordinates for segmentation [[x1,y1], [x2,y2], ...]
            metadata: Additional metadata

        Returns:
            ID of the inserted detection
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            x_min, y_min, x_max, y_max = bbox
            cursor.execute(
                """
                INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, segmentation_mask, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    image_id,
                    model_id,
                    class_name,
                    x_min,
                    y_min,
                    x_max,
                    y_max,
                    confidence,
                    json.dumps(segmentation_mask) if segmentation_mask else None,
                    json.dumps(metadata) if metadata else None,
                ),
            )
            conn.commit()
            return cursor.lastrowid
        finally:
            conn.close()

    def add_detections_batch(self, detections: List[Dict]) -> int:
        """
        Add multiple detections in a single transaction.

        Args:
            detections: List of detection dictionaries

        Returns:
            Number of detections inserted
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            for det in detections:
                bbox = det["bbox"]
                cursor.execute(
                    """
                    INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, segmentation_mask, metadata)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        det["image_id"],
                        det["model_id"],
                        det["class_name"],
                        bbox[0],
                        bbox[1],
                        bbox[2],
                        bbox[3],
                        det["confidence"],
                        (
                            json.dumps(det.get("segmentation_mask"))
                            if det.get("segmentation_mask")
                            else None
                        ),
                        (
                            json.dumps(det.get("metadata"))
                            if det.get("metadata")
                            else None
                        ),
                    ),
                )
            conn.commit()
            return len(detections)
        finally:
            conn.close()
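
    # Illustrative sketch (values are examples only): each dict passed to
    # add_detections_batch must provide image_id, model_id, class_name, bbox and
    # confidence; segmentation_mask and metadata are optional.
    #
    #   db.add_detections_batch([
    #       {
    #           "image_id": 1,
    #           "model_id": 1,
    #           "class_name": "cell",
    #           "bbox": (0.10, 0.20, 0.35, 0.40),  # normalized (x_min, y_min, x_max, y_max)
    #           "confidence": 0.87,
    #           "metadata": {"tile": "A3"},
    #       },
    #   ])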

    def get_detections(
        self,
        filters: Optional[Dict] = None,
        limit: Optional[int] = None,
        offset: int = 0,
    ) -> List[Dict]:
        """
        Retrieve detections from database.

        Args:
            filters: Optional filters for querying
            limit: Maximum number of results
            offset: Number of results to skip

        Returns:
            List of detection dictionaries with joined data
        """
        conn = self.get_connection()
        try:
            query = """
                SELECT
                    d.*,
                    i.relative_path as image_path,
                    i.filename as image_filename,
                    i.width as image_width,
                    i.height as image_height,
                    m.model_name,
                    m.model_version
                FROM detections d
                JOIN images i ON d.image_id = i.id
                JOIN models m ON d.model_id = m.id
            """
            params = []

            if filters:
                conditions = []
                for key, value in filters.items():
                    if (
                        key.startswith("d.")
                        or key.startswith("i.")
                        or key.startswith("m.")
                    ):
                        conditions.append(f"{key} = ?")
                    else:
                        conditions.append(f"d.{key} = ?")
                    params.append(value)
                query += " WHERE " + " AND ".join(conditions)

            query += " ORDER BY d.detected_at DESC"

            if limit:
                query += f" LIMIT {limit} OFFSET {offset}"

            cursor = conn.cursor()
            cursor.execute(query, params)

            detections = []
            for row in cursor.fetchall():
                det = dict(row)
                # Parse JSON fields
                if det.get("metadata"):
                    det["metadata"] = json.loads(det["metadata"])
                if det.get("segmentation_mask"):
                    det["segmentation_mask"] = json.loads(det["segmentation_mask"])
                detections.append(det)

            return detections
        finally:
            conn.close()
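
    # Illustrative sketch: unprefixed filter keys are applied to the detections
    # table, while keys prefixed with "d.", "i." or "m." target the joined
    # detections, images or models tables. Values below are examples only.
    #
    #   recent = db.get_detections(
    #       filters={"class_name": "cell", "m.model_name": "organoid-detector"},
    #       limit=50,
    #   )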

    def get_detections_for_image(
        self, image_id: int, model_id: Optional[int] = None
    ) -> List[Dict]:
        """Get all detections for a specific image."""
        filters = {"image_id": image_id}
        if model_id is not None:
            filters["model_id"] = model_id
        return self.get_detections(filters)

    def delete_detections_for_model(self, model_id: int) -> int:
        """Delete all detections for a specific model."""
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM detections WHERE model_id = ?", (model_id,))
            conn.commit()
            return cursor.rowcount
        finally:
            conn.close()

    # ==================== Statistics Operations ====================

    def get_detection_statistics(
        self, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None
    ) -> Dict:
        """
        Get detection statistics for a date range.

        Returns:
            Dictionary with statistics (count by class, confidence distribution, etc.)
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()

            # Build date filter
            date_filter = ""
            params = []
            if start_date:
                date_filter += " AND detected_at >= ?"
                params.append(start_date)
            if end_date:
                date_filter += " AND detected_at <= ?"
                params.append(end_date)

            # Total detections
            cursor.execute(
                f"SELECT COUNT(*) as count FROM detections WHERE 1=1{date_filter}",
                params,
            )
            total_count = cursor.fetchone()["count"]

            # Count by class
            cursor.execute(
                f"""
                SELECT class_name, COUNT(*) as count
                FROM detections
                WHERE 1=1{date_filter}
                GROUP BY class_name
                ORDER BY count DESC
                """,
                params,
            )
            class_counts = {
                row["class_name"]: row["count"] for row in cursor.fetchall()
            }

            # Average confidence
            cursor.execute(
                f"SELECT AVG(confidence) as avg_conf FROM detections WHERE 1=1{date_filter}",
                params,
            )
            avg_confidence = cursor.fetchone()["avg_conf"] or 0

            # Confidence distribution
            cursor.execute(
                f"""
                SELECT
                    CASE
                        WHEN confidence < 0.3 THEN 'low'
                        WHEN confidence < 0.7 THEN 'medium'
                        ELSE 'high'
                    END as conf_level,
                    COUNT(*) as count
                FROM detections
                WHERE 1=1{date_filter}
                GROUP BY conf_level
                """,
                params,
            )
            conf_dist = {row["conf_level"]: row["count"] for row in cursor.fetchall()}

            return {
                "total_detections": total_count,
                "class_counts": class_counts,
                "average_confidence": avg_confidence,
                "confidence_distribution": conf_dist,
            }
        finally:
            conn.close()

    def get_class_distribution(self, model_id: Optional[int] = None) -> Dict[str, int]:
        """Get count of detections per class."""
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            query = "SELECT class_name, COUNT(*) as count FROM detections"
            params = []

            if model_id is not None:
                query += " WHERE model_id = ?"
                params.append(model_id)

            query += " GROUP BY class_name ORDER BY count DESC"

            cursor.execute(query, params)
            return {row["class_name"]: row["count"] for row in cursor.fetchall()}
        finally:
            conn.close()

    # ==================== Export Operations ====================

    def export_detections_to_csv(
        self, output_path: str, filters: Optional[Dict] = None
    ) -> bool:
        """Export detections to CSV file."""
        try:
            detections = self.get_detections(filters)

            with open(output_path, "w", newline="") as csvfile:
                if not detections:
                    return True

                fieldnames = [
                    "id",
                    "image_path",
                    "model_name",
                    "model_version",
                    "class_name",
                    "x_min",
                    "y_min",
                    "x_max",
                    "y_max",
                    "confidence",
                    "segmentation_mask",
                    "detected_at",
                ]
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()

                for det in detections:
                    row = {k: det[k] for k in fieldnames if k in det}
                    # Convert segmentation mask list to JSON string for CSV
                    if row.get("segmentation_mask") and isinstance(
                        row["segmentation_mask"], list
                    ):
                        row["segmentation_mask"] = json.dumps(row["segmentation_mask"])
                    writer.writerow(row)

            return True
        except Exception as e:
            logger.error(f"Error exporting to CSV: {e}")
            return False

    def export_detections_to_json(
        self, output_path: str, filters: Optional[Dict] = None
    ) -> bool:
        """Export detections to JSON file."""
        try:
            detections = self.get_detections(filters)

            # Convert datetime objects to strings
            for det in detections:
                if isinstance(det.get("detected_at"), datetime):
                    det["detected_at"] = det["detected_at"].isoformat()

            with open(output_path, "w") as jsonfile:
                json.dump(detections, jsonfile, indent=2)

            return True
        except Exception as e:
            logger.error(f"Error exporting to JSON: {e}")
            return False

    # ==================== Annotation Operations ====================

    def add_annotation(
        self,
        image_id: int,
        class_id: int,
        bbox: Tuple[float, float, float, float],
        annotator: str,
        segmentation_mask: Optional[List[List[float]]] = None,
        verified: bool = False,
    ) -> int:
        """
        Add manual annotation.

        Args:
            image_id: ID of the image
            class_id: ID of the object class (foreign key to object_classes)
            bbox: Bounding box coordinates (normalized 0-1)
            annotator: Name of person/tool creating annotation
            segmentation_mask: Polygon coordinates for segmentation
            verified: Whether annotation has been verified

        Returns:
            ID of the inserted annotation
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            x_min, y_min, x_max, y_max = bbox
            cursor.execute(
                """
                INSERT INTO annotations (image_id, class_id, x_min, y_min, x_max, y_max, segmentation_mask, annotator, verified)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    image_id,
                    class_id,
                    x_min,
                    y_min,
                    x_max,
                    y_max,
                    json.dumps(segmentation_mask) if segmentation_mask else None,
                    annotator,
                    verified,
                ),
            )
            conn.commit()
            return cursor.lastrowid
        finally:
            conn.close()
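
    # Illustrative sketch: class_id is the primary key of an object_classes row,
    # so a manual annotation is typically added by looking the class up first.
    # The class name "cell" and the annotator string are examples only.
    #
    #   cls = db.get_object_class_by_name("cell")
    #   if cls is not None:
    #       db.add_annotation(
    #           image_id=image_id,
    #           class_id=cls["id"],
    #           bbox=(0.1, 0.1, 0.4, 0.5),
    #           annotator="alice",
    #       )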

    def get_annotations_for_image(self, image_id: int) -> List[Dict]:
        """
        Get all annotations for an image with class information.

        Args:
            image_id: ID of the image

        Returns:
            List of annotation dictionaries with joined class information
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT
                    a.*,
                    c.class_name,
                    c.color as class_color,
                    c.description as class_description
                FROM annotations a
                JOIN object_classes c ON a.class_id = c.id
                WHERE a.image_id = ?
                ORDER BY a.created_at DESC
                """,
                (image_id,),
            )
            annotations = []
            for row in cursor.fetchall():
                ann = dict(row)
                if ann.get("segmentation_mask"):
                    ann["segmentation_mask"] = json.loads(ann["segmentation_mask"])
                annotations.append(ann)
            return annotations
        finally:
            conn.close()

    def delete_annotation(self, annotation_id: int) -> bool:
        """
        Delete a manual annotation by ID.

        Args:
            annotation_id: ID of the annotation to delete

        Returns:
            True if an annotation was deleted, False otherwise.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM annotations WHERE id = ?", (annotation_id,))
            conn.commit()
            return cursor.rowcount > 0
        finally:
            conn.close()

    # ==================== Object Class Operations ====================

    def get_object_classes(self) -> List[Dict]:
        """
        Get all object classes.

        Returns:
            List of object class dictionaries
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM object_classes ORDER BY class_name")
            return [dict(row) for row in cursor.fetchall()]
        finally:
            conn.close()

    def get_object_class_by_id(self, class_id: int) -> Optional[Dict]:
        """Get object class by ID."""
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM object_classes WHERE id = ?", (class_id,))
            row = cursor.fetchone()
            return dict(row) if row else None
        finally:
            conn.close()

    def get_object_class_by_name(self, class_name: str) -> Optional[Dict]:
        """Get object class by name."""
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT * FROM object_classes WHERE class_name = ?", (class_name,)
            )
            row = cursor.fetchone()
            return dict(row) if row else None
        finally:
            conn.close()

    def add_object_class(
        self, class_name: str, color: str, description: Optional[str] = None
    ) -> Optional[int]:
        """
        Add a new object class.

        Args:
            class_name: Name of the object class
            color: Hex color code (e.g., '#FF0000')
            description: Optional description

        Returns:
            ID of the inserted object class, or the ID of the existing class with
            the same name (None if neither can be determined)
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO object_classes (class_name, color, description)
                VALUES (?, ?, ?)
                """,
                (class_name, color, description),
            )
            conn.commit()
            return cursor.lastrowid
        except sqlite3.IntegrityError:
            # Class already exists
            existing = self.get_object_class_by_name(class_name)
            return existing["id"] if existing else None
        finally:
            conn.close()

    def update_object_class(
        self,
        class_id: int,
        class_name: Optional[str] = None,
        color: Optional[str] = None,
        description: Optional[str] = None,
    ) -> bool:
        """
        Update an object class.

        Args:
            class_id: ID of the class to update
            class_name: New class name (optional)
            color: New color (optional)
            description: New description (optional)

        Returns:
            True if updated, False otherwise
        """
        conn = self.get_connection()
        try:
            updates = {}
            if class_name is not None:
                updates["class_name"] = class_name
            if color is not None:
                updates["color"] = color
            if description is not None:
                updates["description"] = description

            if not updates:
                return False

            set_clauses = [f"{key} = ?" for key in updates.keys()]
            params = list(updates.values()) + [class_id]

            query = f"UPDATE object_classes SET {', '.join(set_clauses)} WHERE id = ?"
            cursor = conn.cursor()
            cursor.execute(query, params)
            conn.commit()
            return cursor.rowcount > 0
        finally:
            conn.close()

    def delete_object_class(self, class_id: int) -> bool:
        """
        Delete an object class.

        Args:
            class_id: ID of the class to delete

        Returns:
            True if deleted, False otherwise
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM object_classes WHERE id = ?", (class_id,))
            conn.commit()
            return cursor.rowcount > 0
        finally:
            conn.close()

    # ==================== Dataset Utilities ====================

    def compose_data_yaml(
        self,
        dataset_root: str,
        output_path: Optional[str] = None,
        splits: Optional[Dict[str, str]] = None,
    ) -> str:
        """
        Compose a YOLO data.yaml file based on dataset folders and database metadata.

        Args:
            dataset_root: Base directory containing the dataset structure.
            output_path: Optional output path; defaults to <dataset_root>/data.yaml.
            splits: Optional mapping overriding train/val/test image directories (relative
                to dataset_root or absolute paths).

        Returns:
            Path to the generated YAML file.
        """
        dataset_root_path = Path(dataset_root).expanduser()
        if not dataset_root_path.exists():
            raise ValueError(f"Dataset root does not exist: {dataset_root_path}")
        dataset_root_path = dataset_root_path.resolve()

        split_map: Dict[str, str] = {key: "" for key in ("train", "val", "test")}
        if splits:
            for key, value in splits.items():
                if key in split_map and value:
                    split_map[key] = value

        inferred = self._infer_split_dirs(dataset_root_path)
        for key in split_map:
            if not split_map[key]:
                split_map[key] = inferred.get(key, "")

        for required in ("train", "val"):
            if not split_map[required]:
                raise ValueError(
                    "Unable to determine %s image directory under %s. Provide it "
                    "explicitly via the 'splits' argument."
                    % (required, dataset_root_path)
                )

        yaml_splits: Dict[str, str] = {}
        for key, value in split_map.items():
            if not value:
                continue
            yaml_splits[key] = self._normalize_split_value(value, dataset_root_path)

        class_names = self._fetch_annotation_class_names()
        if not class_names:
            class_names = [cls["class_name"] for cls in self.get_object_classes()]
        if not class_names:
            raise ValueError("No object classes available to populate data.yaml")

        names_map = {idx: name for idx, name in enumerate(class_names)}
        payload: Dict[str, Any] = {
            "path": dataset_root_path.as_posix(),
            "train": yaml_splits["train"],
            "val": yaml_splits["val"],
            "names": names_map,
            "nc": len(class_names),
        }
        if yaml_splits.get("test"):
            payload["test"] = yaml_splits["test"]

        output_path_obj = (
            Path(output_path).expanduser()
            if output_path
            else dataset_root_path / "data.yaml"
        )
        output_path_obj.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path_obj, "w", encoding="utf-8") as handle:
            yaml.safe_dump(payload, handle, sort_keys=False)

        logger.info(f"Generated data.yaml at {output_path_obj}")
        return output_path_obj.as_posix()
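
    # Illustrative sketch of the YAML this produces (paths and class names are
    # examples only; "test" is included only when a test split is found):
    #
    #   path: /data/datasets/organoids
    #   train: train/images
    #   val: val/images
    #   names:
    #     0: cell
    #     1: debris
    #   nc: 2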

    def _fetch_annotation_class_names(self) -> List[str]:
        """Return class names referenced by annotations (ordered by class ID)."""
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT DISTINCT c.id, c.class_name
                FROM annotations a
                JOIN object_classes c ON a.class_id = c.id
                ORDER BY c.id
                """
            )
            rows = cursor.fetchall()
            return [row["class_name"] for row in rows]
        finally:
            conn.close()

    def _infer_split_dirs(self, dataset_root: Path) -> Dict[str, str]:
        """Infer train/val/test image directories relative to dataset_root."""
        patterns = {
            "train": [
                "train/images",
                "training/images",
                "images/train",
                "images/training",
                "train",
                "training",
            ],
            "val": [
                "val/images",
                "validation/images",
                "images/val",
                "images/validation",
                "val",
                "validation",
            ],
            "test": [
                "test/images",
                "testing/images",
                "images/test",
                "images/testing",
                "test",
                "testing",
            ],
        }

        inferred: Dict[str, str] = {key: "" for key in patterns}
        for split_name, options in patterns.items():
            for relative in options:
                candidate = (dataset_root / relative).resolve()
                if (
                    candidate.exists()
                    and candidate.is_dir()
                    and self._directory_has_images(candidate)
                ):
                    try:
                        inferred[split_name] = candidate.relative_to(
                            dataset_root
                        ).as_posix()
                    except ValueError:
                        inferred[split_name] = candidate.as_posix()
                    break
        return inferred

    def _normalize_split_value(self, split_value: str, dataset_root: Path) -> str:
        """Validate and normalize a split directory to a YAML-friendly string."""
        split_path = Path(split_value).expanduser()
        if not split_path.is_absolute():
            split_path = (dataset_root / split_path).resolve()
        else:
            split_path = split_path.resolve()

        if not split_path.exists() or not split_path.is_dir():
            raise ValueError(f"Split directory not found: {split_path}")

        if not self._directory_has_images(split_path):
            raise ValueError(f"No images found under {split_path}")

        try:
            return split_path.relative_to(dataset_root).as_posix()
        except ValueError:
            return split_path.as_posix()

    @staticmethod
    def _directory_has_images(directory: Path, max_checks: int = 2000) -> bool:
        """Return True if directory tree contains at least one image file."""
        checked = 0
        try:
            for file_path in directory.rglob("*"):
                if not file_path.is_file():
                    continue
                if file_path.suffix.lower() in IMAGE_EXTENSIONS:
                    return True
                checked += 1
                if checked >= max_checks:
                    break
        except Exception:
            return False
        return False

    @staticmethod
    def calculate_checksum(file_path: str) -> str:
        """Calculate MD5 checksum of a file."""
        md5_hash = hashlib.md5()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                md5_hash.update(chunk)
        return md5_hash.hexdigest()
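

if __name__ == "__main__":
    # Minimal usage sketch, assuming schema.sql defines the tables referenced
    # above; the paths, names, and values below are illustrative only.
    db = DatabaseManager("data/detections.db")

    model_id = db.add_model(
        model_name="example-model",
        model_version="0.1.0",
        model_path="models/example.pt",
    )
    image_id = db.get_or_create_image("plate1/img_0001.png", "img_0001.png", 1024, 1024)
    db.add_detection(
        image_id=image_id,
        model_id=model_id,
        class_name="cell",
        bbox=(0.1, 0.2, 0.3, 0.4),
        confidence=0.9,
    )
    logger.info(f"Statistics: {db.get_detection_statistics()}")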