""" Database manager for the microscopy object detection application. Handles all database operations including CRUD operations, queries, and exports. """ import sqlite3 import json from datetime import datetime from typing import List, Dict, Optional, Tuple, Any, Union from pathlib import Path import csv import hashlib import yaml from src.utils.logger import get_logger IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp") logger = get_logger(__name__) class DatabaseManager: """Manages all database operations for the application.""" def __init__(self, db_path: str = "data/detections.db"): """ Initialize database manager. Args: db_path: Path to SQLite database file """ self.db_path = db_path self._ensure_database_exists() def _ensure_database_exists(self) -> None: """Create database and tables if they don't exist.""" # Create directory if it doesn't exist Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) conn = self.get_connection() try: # Check if annotations table needs migration self._migrate_annotations_table(conn) # Read schema file and execute schema_path = Path(__file__).parent / "schema.sql" with open(schema_path, "r") as f: schema_sql = f.read() conn.executescript(schema_sql) conn.commit() finally: conn.close() def _migrate_annotations_table(self, conn: sqlite3.Connection) -> None: """ Migrate annotations table from old schema (class_name) to new schema (class_id). """ cursor = conn.cursor() # Check if annotations table exists cursor.execute( "SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'" ) if not cursor.fetchone(): # Table doesn't exist yet, no migration needed return # Check if table has old schema (class_name column) cursor.execute("PRAGMA table_info(annotations)") columns = {row[1]: row for row in cursor.fetchall()} if "class_name" in columns and "class_id" not in columns: # Old schema detected, need to migrate print("Migrating annotations table to new schema with class_id...") # Drop old annotations table (assuming no critical data since this is a new feature) cursor.execute("DROP TABLE IF EXISTS annotations") conn.commit() print("Old annotations table dropped, will be recreated with new schema") def get_connection(self) -> sqlite3.Connection: """Get database connection with proper settings.""" conn = sqlite3.connect(self.db_path) conn.row_factory = sqlite3.Row # Enable column access by name conn.execute("PRAGMA foreign_keys = ON") # Enable foreign keys return conn # ==================== Model Operations ==================== def add_model( self, model_name: str, model_version: str, model_path: str, base_model: str = "yolov8s-seg.pt", training_params: Optional[Dict] = None, metrics: Optional[Dict] = None, ) -> int: """ Add a new model to the database. Args: model_name: Name of the model model_version: Version string model_path: Path to model weights file base_model: Base model used for training training_params: Dictionary of training parameters metrics: Dictionary of validation metrics Returns: ID of the inserted model """ conn = self.get_connection() try: cursor = conn.cursor() cursor.execute( """ INSERT INTO models (model_name, model_version, model_path, base_model, training_params, metrics) VALUES (?, ?, ?, ?, ?, ?) """, ( model_name, model_version, model_path, base_model, json.dumps(training_params) if training_params else None, json.dumps(metrics) if metrics else None, ), ) conn.commit() return cursor.lastrowid finally: conn.close() def get_models(self, filters: Optional[Dict] = None) -> List[Dict]: """ Retrieve models from database. Args: filters: Optional filters (e.g., {'model_name': 'my_model'}) Returns: List of model dictionaries """ conn = self.get_connection() try: query = "SELECT * FROM models" params = [] if filters: conditions = [] for key, value in filters.items(): conditions.append(f"{key} = ?") params.append(value) query += " WHERE " + " AND ".join(conditions) query += " ORDER BY created_at DESC" cursor = conn.cursor() cursor.execute(query, params) models = [] for row in cursor.fetchall(): model = dict(row) # Parse JSON fields if model["training_params"]: model["training_params"] = json.loads(model["training_params"]) if model["metrics"]: model["metrics"] = json.loads(model["metrics"]) models.append(model) return models finally: conn.close() def get_model_by_id(self, model_id: int) -> Optional[Dict]: """Get model by ID.""" models = self.get_models({"id": model_id}) return models[0] if models else None def update_model(self, model_id: int, updates: Dict) -> bool: """Update model fields.""" conn = self.get_connection() try: # Build update query set_clauses = [] params = [] for key, value in updates.items(): if key in ["training_params", "metrics"] and isinstance(value, dict): value = json.dumps(value) set_clauses.append(f"{key} = ?") params.append(value) params.append(model_id) query = f"UPDATE models SET {', '.join(set_clauses)} WHERE id = ?" cursor = conn.cursor() cursor.execute(query, params) conn.commit() return cursor.rowcount > 0 finally: conn.close() # ==================== Image Operations ==================== def add_image( self, relative_path: str, filename: str, width: int, height: int, captured_at: Optional[datetime] = None, checksum: Optional[str] = None, ) -> int: """ Add a new image to the database. Args: relative_path: Path relative to image repository filename: Image filename width: Image width in pixels height: Image height in pixels captured_at: When image was captured (if known) checksum: MD5 checksum of image file Returns: ID of the inserted image """ conn = self.get_connection() try: cursor = conn.cursor() cursor.execute( """ INSERT INTO images (relative_path, filename, width, height, captured_at, checksum) VALUES (?, ?, ?, ?, ?, ?) """, (relative_path, filename, width, height, captured_at, checksum), ) conn.commit() return cursor.lastrowid except sqlite3.IntegrityError: # Image already exists, return its ID cursor.execute( "SELECT id FROM images WHERE relative_path = ?", (relative_path,) ) row = cursor.fetchone() return row["id"] if row else None finally: conn.close() def get_image_by_path(self, relative_path: str) -> Optional[Dict]: """Get image by relative path.""" conn = self.get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT * FROM images WHERE relative_path = ?", (relative_path,) ) row = cursor.fetchone() return dict(row) if row else None finally: conn.close() def get_or_create_image( self, relative_path: str, filename: str, width: int, height: int ) -> int: """Get existing image or create new one.""" existing = self.get_image_by_path(relative_path) if existing: return existing["id"] return self.add_image(relative_path, filename, width, height) # ==================== Detection Operations ==================== def add_detection( self, image_id: int, model_id: int, class_name: str, bbox: Tuple[float, float, float, float], # (x_min, y_min, x_max, y_max) confidence: float, segmentation_mask: Optional[List[List[float]]] = None, metadata: Optional[Dict] = None, ) -> int: """ Add a new detection to the database. Args: image_id: ID of the image model_id: ID of the model used class_name: Detected object class bbox: Bounding box coordinates (normalized 0-1) confidence: Detection confidence score segmentation_mask: Polygon coordinates for segmentation [[x1,y1], [x2,y2], ...] metadata: Additional metadata Returns: ID of the inserted detection """ conn = self.get_connection() try: cursor = conn.cursor() x_min, y_min, x_max, y_max = bbox cursor.execute( """ INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, segmentation_mask, metadata) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, json.dumps(segmentation_mask) if segmentation_mask else None, json.dumps(metadata) if metadata else None, ), ) conn.commit() return cursor.lastrowid finally: conn.close() def add_detections_batch(self, detections: List[Dict]) -> int: """ Add multiple detections in a single transaction. Args: detections: List of detection dictionaries Returns: Number of detections inserted """ conn = self.get_connection() try: cursor = conn.cursor() for det in detections: bbox = det["bbox"] cursor.execute( """ INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, segmentation_mask, metadata) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( det["image_id"], det["model_id"], det["class_name"], bbox[0], bbox[1], bbox[2], bbox[3], det["confidence"], ( json.dumps(det.get("segmentation_mask")) if det.get("segmentation_mask") else None ), ( json.dumps(det.get("metadata")) if det.get("metadata") else None ), ), ) conn.commit() return len(detections) finally: conn.close() def get_detections( self, filters: Optional[Dict] = None, limit: Optional[int] = None, offset: int = 0, ) -> List[Dict]: """ Retrieve detections from database. Args: filters: Optional filters for querying limit: Maximum number of results offset: Number of results to skip Returns: List of detection dictionaries with joined data """ conn = self.get_connection() try: query = """ SELECT d.*, i.relative_path as image_path, i.filename as image_filename, i.width as image_width, i.height as image_height, m.model_name, m.model_version FROM detections d JOIN images i ON d.image_id = i.id JOIN models m ON d.model_id = m.id """ params = [] if filters: conditions = [] for key, value in filters.items(): if ( key.startswith("d.") or key.startswith("i.") or key.startswith("m.") ): conditions.append(f"{key} = ?") else: conditions.append(f"d.{key} = ?") params.append(value) query += " WHERE " + " AND ".join(conditions) query += " ORDER BY d.detected_at DESC" if limit: query += f" LIMIT {limit} OFFSET {offset}" cursor = conn.cursor() cursor.execute(query, params) detections = [] for row in cursor.fetchall(): det = dict(row) # Parse JSON fields if det.get("metadata"): det["metadata"] = json.loads(det["metadata"]) if det.get("segmentation_mask"): det["segmentation_mask"] = json.loads(det["segmentation_mask"]) detections.append(det) return detections finally: conn.close() def get_detections_for_image( self, image_id: int, model_id: Optional[int] = None ) -> List[Dict]: """Get all detections for a specific image.""" filters = {"image_id": image_id} if model_id: filters["model_id"] = model_id return self.get_detections(filters) def delete_detections_for_model(self, model_id: int) -> int: """Delete all detections for a specific model.""" conn = self.get_connection() try: cursor = conn.cursor() cursor.execute("DELETE FROM detections WHERE model_id = ?", (model_id,)) conn.commit() return cursor.rowcount finally: conn.close() # ==================== Statistics Operations ==================== def get_detection_statistics( self, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None ) -> Dict: """ Get detection statistics for a date range. Returns: Dictionary with statistics (count by class, confidence distribution, etc.) """ conn = self.get_connection() try: cursor = conn.cursor() # Build date filter date_filter = "" params = [] if start_date: date_filter += " AND detected_at >= ?" params.append(start_date) if end_date: date_filter += " AND detected_at <= ?" params.append(end_date) # Total detections cursor.execute( f"SELECT COUNT(*) as count FROM detections WHERE 1=1{date_filter}", params, ) total_count = cursor.fetchone()["count"] # Count by class cursor.execute( f""" SELECT class_name, COUNT(*) as count FROM detections WHERE 1=1{date_filter} GROUP BY class_name ORDER BY count DESC """, params, ) class_counts = { row["class_name"]: row["count"] for row in cursor.fetchall() } # Average confidence cursor.execute( f"SELECT AVG(confidence) as avg_conf FROM detections WHERE 1=1{date_filter}", params, ) avg_confidence = cursor.fetchone()["avg_conf"] or 0 # Confidence distribution cursor.execute( f""" SELECT CASE WHEN confidence < 0.3 THEN 'low' WHEN confidence < 0.7 THEN 'medium' ELSE 'high' END as conf_level, COUNT(*) as count FROM detections WHERE 1=1{date_filter} GROUP BY conf_level """, params, ) conf_dist = {row["conf_level"]: row["count"] for row in cursor.fetchall()} return { "total_detections": total_count, "class_counts": class_counts, "average_confidence": avg_confidence, "confidence_distribution": conf_dist, } finally: conn.close() def get_class_distribution(self, model_id: Optional[int] = None) -> Dict[str, int]: """Get count of detections per class.""" conn = self.get_connection() try: cursor = conn.cursor() query = "SELECT class_name, COUNT(*) as count FROM detections" params = [] if model_id: query += " WHERE model_id = ?" params.append(model_id) query += " GROUP BY class_name ORDER BY count DESC" cursor.execute(query, params) return {row["class_name"]: row["count"] for row in cursor.fetchall()} finally: conn.close() # ==================== Export Operations ==================== def export_detections_to_csv( self, output_path: str, filters: Optional[Dict] = None ) -> bool: """Export detections to CSV file.""" try: detections = self.get_detections(filters) with open(output_path, "w", newline="") as csvfile: if not detections: return True fieldnames = [ "id", "image_path", "model_name", "model_version", "class_name", "x_min", "y_min", "x_max", "y_max", "confidence", "segmentation_mask", "detected_at", ] writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() for det in detections: row = {k: det[k] for k in fieldnames if k in det} # Convert segmentation mask list to JSON string for CSV if row.get("segmentation_mask") and isinstance( row["segmentation_mask"], list ): row["segmentation_mask"] = json.dumps(row["segmentation_mask"]) writer.writerow(row) return True except Exception as e: print(f"Error exporting to CSV: {e}") return False def export_detections_to_json( self, output_path: str, filters: Optional[Dict] = None ) -> bool: """Export detections to JSON file.""" try: detections = self.get_detections(filters) # Convert datetime objects to strings for det in detections: if isinstance(det.get("detected_at"), datetime): det["detected_at"] = det["detected_at"].isoformat() with open(output_path, "w") as jsonfile: json.dump(detections, jsonfile, indent=2) return True except Exception as e: print(f"Error exporting to JSON: {e}") return False # ==================== Annotation Operations ==================== def add_annotation( self, image_id: int, class_id: int, bbox: Tuple[float, float, float, float], annotator: str, segmentation_mask: Optional[List[List[float]]] = None, verified: bool = False, ) -> int: """ Add manual annotation. Args: image_id: ID of the image class_id: ID of the object class (foreign key to object_classes) bbox: Bounding box coordinates (normalized 0-1) annotator: Name of person/tool creating annotation segmentation_mask: Polygon coordinates for segmentation verified: Whether annotation has been verified Returns: ID of the inserted annotation """ conn = self.get_connection() try: cursor = conn.cursor() x_min, y_min, x_max, y_max = bbox cursor.execute( """ INSERT INTO annotations (image_id, class_id, x_min, y_min, x_max, y_max, segmentation_mask, annotator, verified) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( image_id, class_id, x_min, y_min, x_max, y_max, json.dumps(segmentation_mask) if segmentation_mask else None, annotator, verified, ), ) conn.commit() return cursor.lastrowid finally: conn.close() def get_annotations_for_image(self, image_id: int) -> List[Dict]: """ Get all annotations for an image with class information. Args: image_id: ID of the image Returns: List of annotation dictionaries with joined class information """ conn = self.get_connection() try: cursor = conn.cursor() cursor.execute( """ SELECT a.*, c.class_name, c.color as class_color, c.description as class_description FROM annotations a JOIN object_classes c ON a.class_id = c.id WHERE a.image_id = ? ORDER BY a.created_at DESC """, (image_id,), ) annotations = [] for row in cursor.fetchall(): ann = dict(row) if ann.get("segmentation_mask"): ann["segmentation_mask"] = json.loads(ann["segmentation_mask"]) annotations.append(ann) return annotations finally: conn.close() def delete_annotation(self, annotation_id: int) -> bool: """ Delete a manual annotation by ID. Args: annotation_id: ID of the annotation to delete Returns: True if an annotation was deleted, False otherwise. """ conn = self.get_connection() try: cursor = conn.cursor() cursor.execute("DELETE FROM annotations WHERE id = ?", (annotation_id,)) conn.commit() return cursor.rowcount > 0 finally: conn.close() # ==================== Object Class Operations ==================== def get_object_classes(self) -> List[Dict]: """ Get all object classes. Returns: List of object class dictionaries """ conn = self.get_connection() try: cursor = conn.cursor() cursor.execute("SELECT * FROM object_classes ORDER BY class_name") return [dict(row) for row in cursor.fetchall()] finally: conn.close() def get_object_class_by_id(self, class_id: int) -> Optional[Dict]: """Get object class by ID.""" conn = self.get_connection() try: cursor = conn.cursor() cursor.execute("SELECT * FROM object_classes WHERE id = ?", (class_id,)) row = cursor.fetchone() return dict(row) if row else None finally: conn.close() def get_object_class_by_name(self, class_name: str) -> Optional[Dict]: """Get object class by name.""" conn = self.get_connection() try: cursor = conn.cursor() cursor.execute( "SELECT * FROM object_classes WHERE class_name = ?", (class_name,) ) row = cursor.fetchone() return dict(row) if row else None finally: conn.close() def add_object_class( self, class_name: str, color: str, description: Optional[str] = None ) -> int: """ Add a new object class. Args: class_name: Name of the object class color: Hex color code (e.g., '#FF0000') description: Optional description Returns: ID of the inserted object class """ conn = self.get_connection() try: cursor = conn.cursor() cursor.execute( """ INSERT INTO object_classes (class_name, color, description) VALUES (?, ?, ?) """, (class_name, color, description), ) conn.commit() return cursor.lastrowid except sqlite3.IntegrityError: # Class already exists existing = self.get_object_class_by_name(class_name) return existing["id"] if existing else None finally: conn.close() def update_object_class( self, class_id: int, class_name: Optional[str] = None, color: Optional[str] = None, description: Optional[str] = None, ) -> bool: """ Update an object class. Args: class_id: ID of the class to update class_name: New class name (optional) color: New color (optional) description: New description (optional) Returns: True if updated, False otherwise """ conn = self.get_connection() try: updates = {} if class_name is not None: updates["class_name"] = class_name if color is not None: updates["color"] = color if description is not None: updates["description"] = description if not updates: return False set_clauses = [f"{key} = ?" for key in updates.keys()] params = list(updates.values()) + [class_id] query = f"UPDATE object_classes SET {', '.join(set_clauses)} WHERE id = ?" cursor = conn.cursor() cursor.execute(query, params) conn.commit() return cursor.rowcount > 0 finally: conn.close() def delete_object_class(self, class_id: int) -> bool: """ Delete an object class. Args: class_id: ID of the class to delete Returns: True if deleted, False otherwise """ conn = self.get_connection() try: cursor = conn.cursor() cursor.execute("DELETE FROM object_classes WHERE id = ?", (class_id,)) conn.commit() return cursor.rowcount > 0 finally: conn.close() # ==================== Dataset Utilities ==================== def compose_data_yaml( self, dataset_root: str, output_path: Optional[str] = None, splits: Optional[Dict[str, str]] = None, ) -> str: """ Compose a YOLO data.yaml file based on dataset folders and database metadata. Args: dataset_root: Base directory containing the dataset structure. output_path: Optional output path; defaults to /data.yaml. splits: Optional mapping overriding train/val/test image directories (relative to dataset_root or absolute paths). Returns: Path to the generated YAML file. """ dataset_root_path = Path(dataset_root).expanduser() if not dataset_root_path.exists(): raise ValueError(f"Dataset root does not exist: {dataset_root_path}") dataset_root_path = dataset_root_path.resolve() split_map: Dict[str, str] = {key: "" for key in ("train", "val", "test")} if splits: for key, value in splits.items(): if key in split_map and value: split_map[key] = value inferred = self._infer_split_dirs(dataset_root_path) for key in split_map: if not split_map[key]: split_map[key] = inferred.get(key, "") for required in ("train", "val"): if not split_map[required]: raise ValueError( "Unable to determine %s image directory under %s. Provide it " "explicitly via the 'splits' argument." % (required, dataset_root_path) ) yaml_splits: Dict[str, str] = {} for key, value in split_map.items(): if not value: continue yaml_splits[key] = self._normalize_split_value(value, dataset_root_path) class_names = self._fetch_annotation_class_names() if not class_names: class_names = [cls["class_name"] for cls in self.get_object_classes()] if not class_names: raise ValueError("No object classes available to populate data.yaml") names_map = {idx: name for idx, name in enumerate(class_names)} payload: Dict[str, Any] = { "path": dataset_root_path.as_posix(), "train": yaml_splits["train"], "val": yaml_splits["val"], "names": names_map, "nc": len(class_names), } if yaml_splits.get("test"): payload["test"] = yaml_splits["test"] output_path_obj = ( Path(output_path).expanduser() if output_path else dataset_root_path / "data.yaml" ) output_path_obj.parent.mkdir(parents=True, exist_ok=True) with open(output_path_obj, "w", encoding="utf-8") as handle: yaml.safe_dump(payload, handle, sort_keys=False) logger.info(f"Generated data.yaml at {output_path_obj}") return output_path_obj.as_posix() def _fetch_annotation_class_names(self) -> List[str]: """Return class names referenced by annotations (ordered by class ID).""" conn = self.get_connection() try: cursor = conn.cursor() cursor.execute( """ SELECT DISTINCT c.id, c.class_name FROM annotations a JOIN object_classes c ON a.class_id = c.id ORDER BY c.id """ ) rows = cursor.fetchall() return [row["class_name"] for row in rows] finally: conn.close() def _infer_split_dirs(self, dataset_root: Path) -> Dict[str, str]: """Infer train/val/test image directories relative to dataset_root.""" patterns = { "train": [ "train/images", "training/images", "images/train", "images/training", "train", "training", ], "val": [ "val/images", "validation/images", "images/val", "images/validation", "val", "validation", ], "test": [ "test/images", "testing/images", "images/test", "images/testing", "test", "testing", ], } inferred: Dict[str, str] = {key: "" for key in patterns} for split_name, options in patterns.items(): for relative in options: candidate = (dataset_root / relative).resolve() if ( candidate.exists() and candidate.is_dir() and self._directory_has_images(candidate) ): try: inferred[split_name] = candidate.relative_to( dataset_root ).as_posix() except ValueError: inferred[split_name] = candidate.as_posix() break return inferred def _normalize_split_value(self, split_value: str, dataset_root: Path) -> str: """Validate and normalize a split directory to a YAML-friendly string.""" split_path = Path(split_value).expanduser() if not split_path.is_absolute(): split_path = (dataset_root / split_path).resolve() else: split_path = split_path.resolve() if not split_path.exists() or not split_path.is_dir(): raise ValueError(f"Split directory not found: {split_path}") if not self._directory_has_images(split_path): raise ValueError(f"No images found under {split_path}") try: return split_path.relative_to(dataset_root).as_posix() except ValueError: return split_path.as_posix() @staticmethod def _directory_has_images(directory: Path, max_checks: int = 2000) -> bool: """Return True if directory tree contains at least one image file.""" checked = 0 try: for file_path in directory.rglob("*"): if not file_path.is_file(): continue if file_path.suffix.lower() in IMAGE_EXTENSIONS: return True checked += 1 if checked >= max_checks: break except Exception: return False return False @staticmethod def calculate_checksum(file_path: str) -> str: """Calculate MD5 checksum of a file.""" md5_hash = hashlib.md5() with open(file_path, "rb") as f: for chunk in iter(lambda: f.read(4096), b""): md5_hash.update(chunk) return md5_hash.hexdigest()