diff --git a/src/database/db_manager.py b/src/database/db_manager.py index 1a331e8..f9572c3 100644 --- a/src/database/db_manager.py +++ b/src/database/db_manager.py @@ -60,9 +60,7 @@ class DatabaseManager: cursor = conn.cursor() # Check if annotations table exists - cursor.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'" - ) + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'") if not cursor.fetchone(): # Table doesn't exist yet, no migration needed return @@ -242,9 +240,7 @@ class DatabaseManager: return cursor.lastrowid except sqlite3.IntegrityError: # Image already exists, return its ID - cursor.execute( - "SELECT id FROM images WHERE relative_path = ?", (relative_path,) - ) + cursor.execute("SELECT id FROM images WHERE relative_path = ?", (relative_path,)) row = cursor.fetchone() return row["id"] if row else None finally: @@ -255,17 +251,13 @@ class DatabaseManager: conn = self.get_connection() try: cursor = conn.cursor() - cursor.execute( - "SELECT * FROM images WHERE relative_path = ?", (relative_path,) - ) + cursor.execute("SELECT * FROM images WHERE relative_path = ?", (relative_path,)) row = cursor.fetchone() return dict(row) if row else None finally: conn.close() - def get_or_create_image( - self, relative_path: str, filename: str, width: int, height: int - ) -> int: + def get_or_create_image(self, relative_path: str, filename: str, width: int, height: int) -> int: """Get existing image or create new one.""" existing = self.get_image_by_path(relative_path) if existing: @@ -355,16 +347,8 @@ class DatabaseManager: bbox[2], bbox[3], det["confidence"], - ( - json.dumps(det.get("segmentation_mask")) - if det.get("segmentation_mask") - else None - ), - ( - json.dumps(det.get("metadata")) - if det.get("metadata") - else None - ), + (json.dumps(det.get("segmentation_mask")) if det.get("segmentation_mask") else None), + (json.dumps(det.get("metadata")) if det.get("metadata") else None), ), ) conn.commit() @@ -409,15 +393,16 @@ class DatabaseManager: if filters: conditions = [] for key, value in filters.items(): - if ( - key.startswith("d.") - or key.startswith("i.") - or key.startswith("m.") - ): - conditions.append(f"{key} = ?") + if key.startswith("d.") or key.startswith("i.") or key.startswith("m."): + if "like" in value.lower(): + conditions.append(f"{key} LIKE ?") + params.append(value.split(" ")[1]) + else: + conditions.append(f"{key} = ?") + params.append(value) else: conditions.append(f"d.{key} = ?") - params.append(value) + params.append(value) query += " WHERE " + " AND ".join(conditions) query += " ORDER BY d.detected_at DESC" @@ -442,18 +427,14 @@ class DatabaseManager: finally: conn.close() - def get_detections_for_image( - self, image_id: int, model_id: Optional[int] = None - ) -> List[Dict]: + def get_detections_for_image(self, image_id: int, model_id: Optional[int] = None) -> List[Dict]: """Get all detections for a specific image.""" filters = {"image_id": image_id} if model_id: filters["model_id"] = model_id return self.get_detections(filters) - def delete_detections_for_image( - self, image_id: int, model_id: Optional[int] = None - ) -> int: + def delete_detections_for_image(self, image_id: int, model_id: Optional[int] = None) -> int: """Delete detections tied to a specific image and optional model.""" conn = self.get_connection() try: @@ -524,9 +505,7 @@ class DatabaseManager: """, params, ) - class_counts = { - row["class_name"]: row["count"] for row in cursor.fetchall() - } + class_counts = {row["class_name"]: row["count"] for row in cursor.fetchall()} # Average confidence cursor.execute( @@ -583,9 +562,7 @@ class DatabaseManager: # ==================== Export Operations ==================== - def export_detections_to_csv( - self, output_path: str, filters: Optional[Dict] = None - ) -> bool: + def export_detections_to_csv(self, output_path: str, filters: Optional[Dict] = None) -> bool: """Export detections to CSV file.""" try: detections = self.get_detections(filters) @@ -614,9 +591,7 @@ class DatabaseManager: for det in detections: row = {k: det[k] for k in fieldnames if k in det} # Convert segmentation mask list to JSON string for CSV - if row.get("segmentation_mask") and isinstance( - row["segmentation_mask"], list - ): + if row.get("segmentation_mask") and isinstance(row["segmentation_mask"], list): row["segmentation_mask"] = json.dumps(row["segmentation_mask"]) writer.writerow(row) @@ -625,9 +600,7 @@ class DatabaseManager: print(f"Error exporting to CSV: {e}") return False - def export_detections_to_json( - self, output_path: str, filters: Optional[Dict] = None - ) -> bool: + def export_detections_to_json(self, output_path: str, filters: Optional[Dict] = None) -> bool: """Export detections to JSON file.""" try: detections = self.get_detections(filters) @@ -785,17 +758,13 @@ class DatabaseManager: conn = self.get_connection() try: cursor = conn.cursor() - cursor.execute( - "SELECT * FROM object_classes WHERE class_name = ?", (class_name,) - ) + cursor.execute("SELECT * FROM object_classes WHERE class_name = ?", (class_name,)) row = cursor.fetchone() return dict(row) if row else None finally: conn.close() - def add_object_class( - self, class_name: str, color: str, description: Optional[str] = None - ) -> int: + def add_object_class(self, class_name: str, color: str, description: Optional[str] = None) -> int: """ Add a new object class. @@ -928,8 +897,7 @@ class DatabaseManager: if not split_map[required]: raise ValueError( "Unable to determine %s image directory under %s. Provide it " - "explicitly via the 'splits' argument." - % (required, dataset_root_path) + "explicitly via the 'splits' argument." % (required, dataset_root_path) ) yaml_splits: Dict[str, str] = {} @@ -955,11 +923,7 @@ class DatabaseManager: if yaml_splits.get("test"): payload["test"] = yaml_splits["test"] - output_path_obj = ( - Path(output_path).expanduser() - if output_path - else dataset_root_path / "data.yaml" - ) + output_path_obj = Path(output_path).expanduser() if output_path else dataset_root_path / "data.yaml" output_path_obj.parent.mkdir(parents=True, exist_ok=True) with open(output_path_obj, "w", encoding="utf-8") as handle: @@ -1019,15 +983,9 @@ class DatabaseManager: for split_name, options in patterns.items(): for relative in options: candidate = (dataset_root / relative).resolve() - if ( - candidate.exists() - and candidate.is_dir() - and self._directory_has_images(candidate) - ): + if candidate.exists() and candidate.is_dir() and self._directory_has_images(candidate): try: - inferred[split_name] = candidate.relative_to( - dataset_root - ).as_posix() + inferred[split_name] = candidate.relative_to(dataset_root).as_posix() except ValueError: inferred[split_name] = candidate.as_posix() break