"""
Database manager for the microscopy object detection application.
Handles all database operations including CRUD operations, queries, and exports.
"""

import sqlite3
import json
from datetime import datetime
from typing import List, Dict, Optional, Tuple, Any, Union
from pathlib import Path
import csv
import hashlib


class DatabaseManager:
    """Manages all database operations for the application.

    Every public method opens its own connection via :meth:`get_connection`
    and closes it before returning, so an instance holds no open handle
    between calls.
    """

    # Shared by add_detection and add_detections_batch so the column list and
    # the placeholder count cannot drift apart between the two insert paths.
    _INSERT_DETECTION_SQL = """
        INSERT INTO detections (image_id, model_id, class_name,
                                x_min, y_min, x_max, y_max,
                                confidence, segmentation_mask, metadata)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    """

    # Table-alias prefixes that get_detections accepts verbatim in filter keys;
    # any other key is assumed to be a column of the detections table ("d.").
    _DETECTION_FILTER_PREFIXES = ("d.", "i.", "m.")

    def __init__(self, db_path: str = "data/detections.db"):
        """
        Initialize database manager.

        Args:
            db_path: Path to SQLite database file. Its parent directory is
                created if missing, and ``schema.sql`` (located next to this
                module) is executed against the database.
        """
        self.db_path = db_path
        self._ensure_database_exists()

    def _ensure_database_exists(self) -> None:
        """Create the database directory and apply ``schema.sql`` to it."""
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

        schema_path = Path(__file__).parent / "schema.sql"
        schema_sql = schema_path.read_text(encoding="utf-8")

        conn = self.get_connection()
        try:
            conn.executescript(schema_sql)
            conn.commit()
        finally:
            conn.close()

    def get_connection(self) -> sqlite3.Connection:
        """Open a new connection with name-addressable rows and FKs enforced.

        The caller owns the returned connection and must close it.
        """
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row  # Enable column access by name
        conn.execute("PRAGMA foreign_keys = ON")  # Enforce foreign keys
        return conn

    # ==================== Internal Helpers ====================

    @staticmethod
    def _check_column(name: str) -> str:
        """Return *name* unchanged if it is a safe, optionally dotted, identifier.

        Column names cannot be bound as ``?`` parameters, so filter/update
        keys end up in the SQL text. Restricting them to ``ident`` or
        ``alias.ident`` prevents SQL injection through dict keys.

        Raises:
            ValueError: If *name* is not a plain (dotted) identifier.
        """
        if not isinstance(name, str) or not all(
            part.isidentifier() for part in name.split(".")
        ):
            raise ValueError(f"Invalid column name: {name!r}")
        return name

    @staticmethod
    def _encode_json(value: Any) -> Optional[str]:
        """Serialize *value* to JSON text; falsy values (None, {}, []) become NULL."""
        return json.dumps(value) if value else None

    @staticmethod
    def _decode_json_fields(record: Dict, fields: Tuple[str, ...]) -> Dict:
        """Parse the non-empty JSON text stored in *fields* of *record* in place.

        Returns:
            The same *record* dict, for use inside comprehensions.
        """
        for field in fields:
            if record.get(field):
                record[field] = json.loads(record[field])
        return record

    @staticmethod
    def _to_db_value(value: Any) -> Any:
        """Convert a ``datetime`` to the text sqlite3's default adapter produced.

        sqlite3's implicit datetime adapter is deprecated since Python 3.12.
        Converting explicitly with ``isoformat(" ")`` stores and compares the
        exact same text without the DeprecationWarning. Other values are
        returned unchanged.
        """
        if isinstance(value, datetime):
            return value.isoformat(" ")
        return value

    def _detection_row(
        self,
        image_id: int,
        model_id: int,
        class_name: str,
        bbox: Tuple[float, float, float, float],
        confidence: float,
        segmentation_mask: Optional[List[List[float]]],
        metadata: Optional[Dict],
    ) -> Tuple:
        """Build the parameter tuple matching ``_INSERT_DETECTION_SQL``."""
        return (
            image_id,
            model_id,
            class_name,
            bbox[0],
            bbox[1],
            bbox[2],
            bbox[3],
            confidence,
            self._encode_json(segmentation_mask),
            self._encode_json(metadata),
        )

    # ==================== Model Operations ====================

    def add_model(
        self,
        model_name: str,
        model_version: str,
        model_path: str,
        base_model: str = "yolov8s-seg.pt",
        training_params: Optional[Dict] = None,
        metrics: Optional[Dict] = None,
    ) -> int:
        """
        Add a new model to the database.

        Args:
            model_name: Name of the model
            model_version: Version string
            model_path: Path to model weights file
            base_model: Base model used for training
            training_params: Dictionary of training parameters (stored as JSON)
            metrics: Dictionary of validation metrics (stored as JSON)

        Returns:
            ID of the inserted model
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO models (model_name, model_version, model_path,
                                    base_model, training_params, metrics)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    model_name,
                    model_version,
                    model_path,
                    base_model,
                    self._encode_json(training_params),
                    self._encode_json(metrics),
                ),
            )
            conn.commit()
            return cursor.lastrowid
        finally:
            conn.close()

    def get_models(self, filters: Optional[Dict] = None) -> List[Dict]:
        """
        Retrieve models from database, newest first.

        Args:
            filters: Optional equality filters ANDed together
                (e.g., {'model_name': 'my_model'})

        Returns:
            List of model dictionaries with ``training_params`` and
            ``metrics`` decoded from JSON.

        Raises:
            ValueError: If a filter key is not a plain column identifier.
        """
        conn = self.get_connection()
        try:
            query = "SELECT * FROM models"
            params: List[Any] = []

            if filters:
                conditions = []
                for key, value in filters.items():
                    conditions.append(f"{self._check_column(key)} = ?")
                    params.append(value)
                query += " WHERE " + " AND ".join(conditions)

            query += " ORDER BY created_at DESC"

            cursor = conn.cursor()
            cursor.execute(query, params)
            return [
                self._decode_json_fields(dict(row), ("training_params", "metrics"))
                for row in cursor.fetchall()
            ]
        finally:
            conn.close()

    def get_model_by_id(self, model_id: int) -> Optional[Dict]:
        """Get model by ID, or None if it does not exist."""
        models = self.get_models({"id": model_id})
        return models[0] if models else None

    def update_model(self, model_id: int, updates: Dict) -> bool:
        """
        Update model fields.

        Args:
            model_id: ID of the model to update
            updates: Mapping of column name to new value; dict values for
                ``training_params``/``metrics`` are JSON-encoded.

        Returns:
            True if a row was updated; False if no row matched or
            *updates* is empty.

        Raises:
            ValueError: If an update key is not a plain column identifier.
        """
        if not updates:
            # An empty SET clause would be a SQL syntax error.
            return False

        set_clauses = []
        params: List[Any] = []
        for key, value in updates.items():
            if key in ("training_params", "metrics") and isinstance(value, dict):
                value = json.dumps(value)
            set_clauses.append(f"{self._check_column(key)} = ?")
            params.append(value)
        params.append(model_id)
        query = f"UPDATE models SET {', '.join(set_clauses)} WHERE id = ?"

        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(query, params)
            conn.commit()
            return cursor.rowcount > 0
        finally:
            conn.close()

    # ==================== Image Operations ====================

    def add_image(
        self,
        relative_path: str,
        filename: str,
        width: int,
        height: int,
        captured_at: Optional[datetime] = None,
        checksum: Optional[str] = None,
    ) -> Optional[int]:
        """
        Add a new image to the database.

        Args:
            relative_path: Path relative to image repository
            filename: Image filename
            width: Image width in pixels
            height: Image height in pixels
            captured_at: When image was captured (if known)
            checksum: MD5 checksum of image file

        Returns:
            ID of the inserted image. If the insert violates a constraint,
            the ID of an existing image with the same ``relative_path``, or
            None if no such image exists.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO images (relative_path, filename, width, height,
                                    captured_at, checksum)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    relative_path,
                    filename,
                    width,
                    height,
                    self._to_db_value(captured_at),
                    checksum,
                ),
            )
            conn.commit()
            return cursor.lastrowid
        except sqlite3.IntegrityError:
            # NOTE(review): presumably a UNIQUE constraint on relative_path in
            # schema.sql; confirm. Fall back to the already-stored row's ID.
            cursor.execute(
                "SELECT id FROM images WHERE relative_path = ?", (relative_path,)
            )
            row = cursor.fetchone()
            return row["id"] if row else None
        finally:
            conn.close()

    def get_image_by_path(self, relative_path: str) -> Optional[Dict]:
        """Get image by relative path, or None if it is not stored."""
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT * FROM images WHERE relative_path = ?", (relative_path,)
            )
            row = cursor.fetchone()
            return dict(row) if row else None
        finally:
            conn.close()

    def get_or_create_image(
        self, relative_path: str, filename: str, width: int, height: int
    ) -> int:
        """Return the ID of the image at *relative_path*, inserting it if absent."""
        existing = self.get_image_by_path(relative_path)
        if existing:
            return existing["id"]
        return self.add_image(relative_path, filename, width, height)

    # ==================== Detection Operations ====================

    def add_detection(
        self,
        image_id: int,
        model_id: int,
        class_name: str,
        bbox: Tuple[float, float, float, float],  # (x_min, y_min, x_max, y_max)
        confidence: float,
        segmentation_mask: Optional[List[List[float]]] = None,
        metadata: Optional[Dict] = None,
    ) -> int:
        """
        Add a new detection to the database.

        Args:
            image_id: ID of the image
            model_id: ID of the model used
            class_name: Detected object class
            bbox: Bounding box coordinates (normalized 0-1), exactly four values
            confidence: Detection confidence score
            segmentation_mask: Polygon coordinates for segmentation [[x1,y1], [x2,y2], ...]
            metadata: Additional metadata

        Returns:
            ID of the inserted detection
        """
        # Unpacking enforces exactly four coordinates (ValueError otherwise).
        x_min, y_min, x_max, y_max = bbox
        row = self._detection_row(
            image_id,
            model_id,
            class_name,
            (x_min, y_min, x_max, y_max),
            confidence,
            segmentation_mask,
            metadata,
        )

        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(self._INSERT_DETECTION_SQL, row)
            conn.commit()
            return cursor.lastrowid
        finally:
            conn.close()

    def add_detections_batch(self, detections: List[Dict]) -> int:
        """
        Add multiple detections in a single transaction.

        Args:
            detections: List of detection dictionaries with keys ``image_id``,
                ``model_id``, ``class_name``, ``bbox`` and ``confidence``, and
                optional ``segmentation_mask`` and ``metadata``.

        Returns:
            Number of detections inserted. If any row fails, the transaction
            is never committed, so nothing from the batch is stored.
        """
        rows = [
            self._detection_row(
                det["image_id"],
                det["model_id"],
                det["class_name"],
                det["bbox"],
                det["confidence"],
                det.get("segmentation_mask"),
                det.get("metadata"),
            )
            for det in detections
        ]

        conn = self.get_connection()
        try:
            conn.executemany(self._INSERT_DETECTION_SQL, rows)
            conn.commit()
            return len(detections)
        finally:
            conn.close()

    def get_detections(
        self,
        filters: Optional[Dict] = None,
        limit: Optional[int] = None,
        offset: int = 0,
    ) -> List[Dict]:
        """
        Retrieve detections from database, newest first.

        Args:
            filters: Optional equality filters ANDed together. Keys may be
                prefixed with ``d.``, ``i.`` or ``m.`` (detections, images,
                models); unprefixed keys refer to the detections table.
            limit: Maximum number of results; None or 0 means no limit
            offset: Number of results to skip (honored even without a limit)

        Returns:
            List of detection dictionaries with joined image/model data and
            ``metadata``/``segmentation_mask`` decoded from JSON.

        Raises:
            ValueError: If a filter key is not a plain (dotted) identifier.
        """
        conn = self.get_connection()
        try:
            query = """
                SELECT
                    d.*,
                    i.relative_path as image_path,
                    i.filename as image_filename,
                    i.width as image_width,
                    i.height as image_height,
                    m.model_name,
                    m.model_version
                FROM detections d
                JOIN images i ON d.image_id = i.id
                JOIN models m ON d.model_id = m.id
            """
            params: List[Any] = []

            if filters:
                conditions = []
                for key, value in filters.items():
                    column = self._check_column(key)
                    if not column.startswith(self._DETECTION_FILTER_PREFIXES):
                        column = f"d.{column}"
                    conditions.append(f"{column} = ?")
                    params.append(value)
                query += " WHERE " + " AND ".join(conditions)

            query += " ORDER BY d.detected_at DESC"

            if limit or offset:
                # SQLite treats a negative LIMIT as "no upper bound", which lets
                # an offset apply on its own while keeping falsy limit = all rows.
                query += " LIMIT ? OFFSET ?"
                params.extend([int(limit) if limit else -1, int(offset or 0)])

            cursor = conn.cursor()
            cursor.execute(query, params)
            return [
                self._decode_json_fields(dict(row), ("metadata", "segmentation_mask"))
                for row in cursor.fetchall()
            ]
        finally:
            conn.close()

    def get_detections_for_image(
        self, image_id: int, model_id: Optional[int] = None
    ) -> List[Dict]:
        """Get all detections for a specific image, optionally for one model."""
        filters: Dict[str, Any] = {"image_id": image_id}
        if model_id is not None:
            filters["model_id"] = model_id
        return self.get_detections(filters)

    def delete_detections_for_model(self, model_id: int) -> int:
        """Delete all detections for a specific model; return rows deleted."""
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM detections WHERE model_id = ?", (model_id,))
            conn.commit()
            return cursor.rowcount
        finally:
            conn.close()

    # ==================== Statistics Operations ====================

    def get_detection_statistics(
        self, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None
    ) -> Dict:
        """
        Get detection statistics for a date range.

        Args:
            start_date: Inclusive lower bound on ``detected_at`` (optional)
            end_date: Inclusive upper bound on ``detected_at`` (optional)

        Returns:
            Dictionary with ``total_detections``, ``class_counts``,
            ``average_confidence`` (0 when there are no rows) and
            ``confidence_distribution`` bucketed as low (< 0.3),
            medium (< 0.7) and high.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()

            # "1=1" lets every optional bound be appended uniformly with AND.
            date_filter = ""
            params: List[Any] = []
            if start_date:
                date_filter += " AND detected_at >= ?"
                params.append(self._to_db_value(start_date))
            if end_date:
                date_filter += " AND detected_at <= ?"
                params.append(self._to_db_value(end_date))
            where = f"WHERE 1=1{date_filter}"

            cursor.execute(f"SELECT COUNT(*) as count FROM detections {where}", params)
            total_count = cursor.fetchone()["count"]

            cursor.execute(
                f"""
                SELECT class_name, COUNT(*) as count
                FROM detections
                {where}
                GROUP BY class_name
                ORDER BY count DESC
                """,
                params,
            )
            class_counts = {row["class_name"]: row["count"] for row in cursor.fetchall()}

            cursor.execute(
                f"SELECT AVG(confidence) as avg_conf FROM detections {where}", params
            )
            # AVG over zero rows is NULL.
            avg_confidence = cursor.fetchone()["avg_conf"] or 0

            cursor.execute(
                f"""
                SELECT
                    CASE
                        WHEN confidence < 0.3 THEN 'low'
                        WHEN confidence < 0.7 THEN 'medium'
                        ELSE 'high'
                    END as conf_level,
                    COUNT(*) as count
                FROM detections
                {where}
                GROUP BY conf_level
                """,
                params,
            )
            conf_dist = {row["conf_level"]: row["count"] for row in cursor.fetchall()}

            return {
                "total_detections": total_count,
                "class_counts": class_counts,
                "average_confidence": avg_confidence,
                "confidence_distribution": conf_dist,
            }
        finally:
            conn.close()

    def get_class_distribution(self, model_id: Optional[int] = None) -> Dict[str, int]:
        """Get count of detections per class, optionally for one model."""
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            query = "SELECT class_name, COUNT(*) as count FROM detections"
            params: List[Any] = []
            if model_id is not None:
                query += " WHERE model_id = ?"
                params.append(model_id)
            query += " GROUP BY class_name ORDER BY count DESC"

            cursor.execute(query, params)
            return {row["class_name"]: row["count"] for row in cursor.fetchall()}
        finally:
            conn.close()

    # ==================== Export Operations ====================

    def export_detections_to_csv(
        self, output_path: str, filters: Optional[Dict] = None
    ) -> bool:
        """Export detections to a UTF-8 CSV file.

        An empty file (no header) is written when no detections match.

        Returns:
            True on success, False if any error occurred (the error is printed).
        """
        try:
            detections = self.get_detections(filters)

            with open(output_path, "w", newline="", encoding="utf-8") as csvfile:
                if not detections:
                    return True

                fieldnames = [
                    "id",
                    "image_path",
                    "model_name",
                    "model_version",
                    "class_name",
                    "x_min",
                    "y_min",
                    "x_max",
                    "y_max",
                    "confidence",
                    "segmentation_mask",
                    "detected_at",
                ]

                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()

                for det in detections:
                    row = {k: det[k] for k in fieldnames if k in det}
                    # get_detections decodes the mask; re-encode it as JSON text
                    # so the CSV cell holds valid JSON rather than a Python repr.
                    if row.get("segmentation_mask") and isinstance(
                        row["segmentation_mask"], list
                    ):
                        row["segmentation_mask"] = json.dumps(row["segmentation_mask"])
                    writer.writerow(row)

            return True
        except Exception as e:
            print(f"Error exporting to CSV: {e}")
            return False

    def export_detections_to_json(
        self, output_path: str, filters: Optional[Dict] = None
    ) -> bool:
        """Export detections to a UTF-8 JSON file.

        Returns:
            True on success, False if any error occurred (the error is printed).
        """
        try:
            detections = self.get_detections(filters)

            # Defensive: sqlite3 returns TEXT here unless the connection is
            # opened with detect_types, but datetimes are not JSON-serializable.
            for det in detections:
                if isinstance(det.get("detected_at"), datetime):
                    det["detected_at"] = det["detected_at"].isoformat()

            with open(output_path, "w", encoding="utf-8") as jsonfile:
                json.dump(detections, jsonfile, indent=2)

            return True
        except Exception as e:
            print(f"Error exporting to JSON: {e}")
            return False

    # ==================== Annotation Operations ====================

    def add_annotation(
        self,
        image_id: int,
        class_name: str,
        bbox: Tuple[float, float, float, float],
        annotator: str,
        segmentation_mask: Optional[List[List[float]]] = None,
        verified: bool = False,
    ) -> int:
        """Add a manual annotation and return its ID.

        Args:
            image_id: ID of the annotated image
            class_name: Annotated object class
            bbox: (x_min, y_min, x_max, y_max), exactly four values
            annotator: Who created the annotation
            segmentation_mask: Optional polygon, stored as JSON text
            verified: Whether the annotation has been verified
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            x_min, y_min, x_max, y_max = bbox
            cursor.execute(
                """
                INSERT INTO annotations (image_id, class_name, x_min, y_min, x_max, y_max,
                                         segmentation_mask, annotator, verified)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    image_id,
                    class_name,
                    x_min,
                    y_min,
                    x_max,
                    y_max,
                    self._encode_json(segmentation_mask),
                    annotator,
                    verified,
                ),
            )
            conn.commit()
            return cursor.lastrowid
        finally:
            conn.close()

    def get_annotations_for_image(self, image_id: int) -> List[Dict]:
        """Get all annotations for an image.

        NOTE: unlike get_detections, ``segmentation_mask`` is returned as the
        raw stored JSON text, not decoded.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM annotations WHERE image_id = ?", (image_id,))
            return [dict(row) for row in cursor.fetchall()]
        finally:
            conn.close()

    @staticmethod
    def calculate_checksum(file_path: str) -> str:
        """Calculate the MD5 hex digest of a file, reading it in 4 KiB chunks."""
        md5_hash = hashlib.md5()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                md5_hash.update(chunk)
        return md5_hash.hexdigest()