diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/database/__init__.py b/src/database/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/database/db_manager.py b/src/database/db_manager.py
new file mode 100644
index 0000000..4329499
--- /dev/null
+++ b/src/database/db_manager.py
@@ -0,0 +1,619 @@
+"""
+Database manager for the microscopy object detection application.
+Handles all database operations including CRUD operations, queries, and exports.
+"""
+
+import sqlite3
+import json
+from datetime import datetime
+from typing import List, Dict, Optional, Tuple, Any
+from pathlib import Path
+import csv
+import hashlib
+
+
+class DatabaseManager:
+ """Manages all database operations for the application."""
+
+ def __init__(self, db_path: str = "data/detections.db"):
+ """
+ Initialize database manager.
+
+ Args:
+ db_path: Path to SQLite database file
+ """
+ self.db_path = db_path
+ self._ensure_database_exists()
+
+ def _ensure_database_exists(self) -> None:
+ """Create database and tables if they don't exist."""
+ # Create directory if it doesn't exist
+ Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
+
+ # Read schema file and execute
+ schema_path = Path(__file__).parent / "schema.sql"
+ with open(schema_path, "r") as f:
+ schema_sql = f.read()
+
+ conn = self.get_connection()
+ try:
+ conn.executescript(schema_sql)
+ conn.commit()
+ finally:
+ conn.close()
+
+ def get_connection(self) -> sqlite3.Connection:
+ """Get database connection with proper settings."""
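+        # A fresh connection is opened on every call; callers are responsible for
+        # closing it (the methods below do so in try/finally blocks).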
+ conn = sqlite3.connect(self.db_path)
+ conn.row_factory = sqlite3.Row # Enable column access by name
+ conn.execute("PRAGMA foreign_keys = ON") # Enable foreign keys
+ return conn
+
+ # ==================== Model Operations ====================
+
+ def add_model(
+ self,
+ model_name: str,
+ model_version: str,
+ model_path: str,
+ base_model: str = "yolov8s.pt",
+ training_params: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ ) -> int:
+ """
+ Add a new model to the database.
+
+ Args:
+ model_name: Name of the model
+ model_version: Version string
+ model_path: Path to model weights file
+ base_model: Base model used for training
+ training_params: Dictionary of training parameters
+ metrics: Dictionary of validation metrics
+
+ Returns:
+ ID of the inserted model
+ """
+ conn = self.get_connection()
+ try:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ INSERT INTO models (model_name, model_version, model_path, base_model, training_params, metrics)
+ VALUES (?, ?, ?, ?, ?, ?)
+ """,
+ (
+ model_name,
+ model_version,
+ model_path,
+ base_model,
+ json.dumps(training_params) if training_params else None,
+ json.dumps(metrics) if metrics else None,
+ ),
+ )
+ conn.commit()
+ return cursor.lastrowid
+ finally:
+ conn.close()
+
+ def get_models(self, filters: Optional[Dict] = None) -> List[Dict]:
+ """
+ Retrieve models from database.
+
+ Args:
+ filters: Optional filters (e.g., {'model_name': 'my_model'})
+
+ Returns:
+ List of model dictionaries
+ """
+ conn = self.get_connection()
+ try:
+ query = "SELECT * FROM models"
+ params = []
+
+ if filters:
+ conditions = []
+ for key, value in filters.items():
+ conditions.append(f"{key} = ?")
+ params.append(value)
+ query += " WHERE " + " AND ".join(conditions)
+
+ query += " ORDER BY created_at DESC"
+
+ cursor = conn.cursor()
+ cursor.execute(query, params)
+
+ models = []
+ for row in cursor.fetchall():
+ model = dict(row)
+ # Parse JSON fields
+ if model["training_params"]:
+ model["training_params"] = json.loads(model["training_params"])
+ if model["metrics"]:
+ model["metrics"] = json.loads(model["metrics"])
+ models.append(model)
+
+ return models
+ finally:
+ conn.close()
+
+ def get_model_by_id(self, model_id: int) -> Optional[Dict]:
+ """Get model by ID."""
+ models = self.get_models({"id": model_id})
+ return models[0] if models else None
+
+ def update_model(self, model_id: int, updates: Dict) -> bool:
+ """Update model fields."""
+ conn = self.get_connection()
+ try:
+ # Build update query
+ set_clauses = []
+ params = []
+ for key, value in updates.items():
+ if key in ["training_params", "metrics"] and isinstance(value, dict):
+ value = json.dumps(value)
+ set_clauses.append(f"{key} = ?")
+ params.append(value)
+
+            if not set_clauses:
+                return False
+
+            params.append(model_id)
+
+ query = f"UPDATE models SET {', '.join(set_clauses)} WHERE id = ?"
+ cursor = conn.cursor()
+ cursor.execute(query, params)
+ conn.commit()
+ return cursor.rowcount > 0
+ finally:
+ conn.close()
+
+ # ==================== Image Operations ====================
+
+ def add_image(
+ self,
+ relative_path: str,
+ filename: str,
+ width: int,
+ height: int,
+ captured_at: Optional[datetime] = None,
+ checksum: Optional[str] = None,
+    ) -> Optional[int]:
+ """
+ Add a new image to the database.
+
+ Args:
+ relative_path: Path relative to image repository
+ filename: Image filename
+ width: Image width in pixels
+ height: Image height in pixels
+ captured_at: When image was captured (if known)
+ checksum: MD5 checksum of image file
+
+ Returns:
+            ID of the inserted image, or the existing image's ID if the path is already registered
+ """
+ conn = self.get_connection()
+ try:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ INSERT INTO images (relative_path, filename, width, height, captured_at, checksum)
+ VALUES (?, ?, ?, ?, ?, ?)
+ """,
+ (relative_path, filename, width, height, captured_at, checksum),
+ )
+ conn.commit()
+ return cursor.lastrowid
+ except sqlite3.IntegrityError:
+ # Image already exists, return its ID
+ cursor.execute(
+ "SELECT id FROM images WHERE relative_path = ?", (relative_path,)
+ )
+ row = cursor.fetchone()
+ return row["id"] if row else None
+ finally:
+ conn.close()
+
+ def get_image_by_path(self, relative_path: str) -> Optional[Dict]:
+ """Get image by relative path."""
+ conn = self.get_connection()
+ try:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM images WHERE relative_path = ?", (relative_path,)
+ )
+ row = cursor.fetchone()
+ return dict(row) if row else None
+ finally:
+ conn.close()
+
+ def get_or_create_image(
+ self, relative_path: str, filename: str, width: int, height: int
+ ) -> int:
+ """Get existing image or create new one."""
+ existing = self.get_image_by_path(relative_path)
+ if existing:
+ return existing["id"]
+ return self.add_image(relative_path, filename, width, height)
+
+ # ==================== Detection Operations ====================
+
+ def add_detection(
+ self,
+ image_id: int,
+ model_id: int,
+ class_name: str,
+ bbox: Tuple[float, float, float, float], # (x_min, y_min, x_max, y_max)
+ confidence: float,
+ metadata: Optional[Dict] = None,
+ ) -> int:
+ """
+ Add a new detection to the database.
+
+ Args:
+ image_id: ID of the image
+ model_id: ID of the model used
+ class_name: Detected object class
+ bbox: Bounding box coordinates (normalized 0-1)
+ confidence: Detection confidence score
+ metadata: Additional metadata
+
+ Returns:
+ ID of the inserted detection
+ """
+ conn = self.get_connection()
+ try:
+ cursor = conn.cursor()
+ x_min, y_min, x_max, y_max = bbox
+ cursor.execute(
+ """
+ INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, metadata)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+ """,
+ (
+ image_id,
+ model_id,
+ class_name,
+ x_min,
+ y_min,
+ x_max,
+ y_max,
+ confidence,
+ json.dumps(metadata) if metadata else None,
+ ),
+ )
+ conn.commit()
+ return cursor.lastrowid
+ finally:
+ conn.close()
+
+ def add_detections_batch(self, detections: List[Dict]) -> int:
+ """
+ Add multiple detections in a single transaction.
+
+ Args:
+ detections: List of detection dictionaries
+
+ Returns:
+ Number of detections inserted
+ """
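+        # Each detection dict is expected to mirror the INSERT below; illustrative example:
+        #   {"image_id": 1, "model_id": 2, "class_name": "mitochondria",
+        #    "bbox": (0.10, 0.20, 0.35, 0.40), "confidence": 0.92,
+        #    "metadata": {"class_id": 0}}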
+ conn = self.get_connection()
+ try:
+ cursor = conn.cursor()
+ for det in detections:
+ bbox = det["bbox"]
+ cursor.execute(
+ """
+ INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, metadata)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+ """,
+ (
+ det["image_id"],
+ det["model_id"],
+ det["class_name"],
+ bbox[0],
+ bbox[1],
+ bbox[2],
+ bbox[3],
+ det["confidence"],
+ (
+ json.dumps(det.get("metadata"))
+ if det.get("metadata")
+ else None
+ ),
+ ),
+ )
+ conn.commit()
+ return len(detections)
+ finally:
+ conn.close()
+
+ def get_detections(
+ self,
+ filters: Optional[Dict] = None,
+ limit: Optional[int] = None,
+ offset: int = 0,
+ ) -> List[Dict]:
+ """
+ Retrieve detections from database.
+
+ Args:
+ filters: Optional filters for querying
+ limit: Maximum number of results
+ offset: Number of results to skip
+
+ Returns:
+ List of detection dictionaries with joined data
+ """
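+        # Filter keys map to columns: unprefixed keys hit the detections table, while
+        # "i." / "m." prefixes address the joined images / models tables. Illustrative:
+        #   filters={"class_name": "mitochondria", "m.model_name": "my_model"}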
+ conn = self.get_connection()
+ try:
+ query = """
+ SELECT
+ d.*,
+ i.relative_path as image_path,
+ i.filename as image_filename,
+ i.width as image_width,
+ i.height as image_height,
+ m.model_name,
+ m.model_version
+ FROM detections d
+ JOIN images i ON d.image_id = i.id
+ JOIN models m ON d.model_id = m.id
+ """
+ params = []
+
+ if filters:
+ conditions = []
+ for key, value in filters.items():
+ if (
+ key.startswith("d.")
+ or key.startswith("i.")
+ or key.startswith("m.")
+ ):
+ conditions.append(f"{key} = ?")
+ else:
+ conditions.append(f"d.{key} = ?")
+ params.append(value)
+ query += " WHERE " + " AND ".join(conditions)
+
+ query += " ORDER BY d.detected_at DESC"
+
+ if limit:
+ query += f" LIMIT {limit} OFFSET {offset}"
+
+ cursor = conn.cursor()
+ cursor.execute(query, params)
+
+ detections = []
+ for row in cursor.fetchall():
+ det = dict(row)
+ # Parse JSON metadata
+ if det.get("metadata"):
+ det["metadata"] = json.loads(det["metadata"])
+ detections.append(det)
+
+ return detections
+ finally:
+ conn.close()
+
+ def get_detections_for_image(
+ self, image_id: int, model_id: Optional[int] = None
+ ) -> List[Dict]:
+ """Get all detections for a specific image."""
+ filters = {"image_id": image_id}
+ if model_id:
+ filters["model_id"] = model_id
+ return self.get_detections(filters)
+
+ def delete_detections_for_model(self, model_id: int) -> int:
+ """Delete all detections for a specific model."""
+ conn = self.get_connection()
+ try:
+ cursor = conn.cursor()
+ cursor.execute("DELETE FROM detections WHERE model_id = ?", (model_id,))
+ conn.commit()
+ return cursor.rowcount
+ finally:
+ conn.close()
+
+ # ==================== Statistics Operations ====================
+
+ def get_detection_statistics(
+ self, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None
+ ) -> Dict:
+ """
+ Get detection statistics for a date range.
+
+ Returns:
+ Dictionary with statistics (count by class, confidence distribution, etc.)
+ """
+ conn = self.get_connection()
+ try:
+ cursor = conn.cursor()
+
+ # Build date filter
+ date_filter = ""
+ params = []
+ if start_date:
+ date_filter += " AND detected_at >= ?"
+ params.append(start_date)
+ if end_date:
+ date_filter += " AND detected_at <= ?"
+ params.append(end_date)
+
+ # Total detections
+ cursor.execute(
+ f"SELECT COUNT(*) as count FROM detections WHERE 1=1{date_filter}",
+ params,
+ )
+ total_count = cursor.fetchone()["count"]
+
+ # Count by class
+ cursor.execute(
+ f"""
+ SELECT class_name, COUNT(*) as count
+ FROM detections
+ WHERE 1=1{date_filter}
+ GROUP BY class_name
+ ORDER BY count DESC
+ """,
+ params,
+ )
+ class_counts = {
+ row["class_name"]: row["count"] for row in cursor.fetchall()
+ }
+
+ # Average confidence
+ cursor.execute(
+ f"SELECT AVG(confidence) as avg_conf FROM detections WHERE 1=1{date_filter}",
+ params,
+ )
+ avg_confidence = cursor.fetchone()["avg_conf"] or 0
+
+ # Confidence distribution
+ cursor.execute(
+ f"""
+ SELECT
+ CASE
+ WHEN confidence < 0.3 THEN 'low'
+ WHEN confidence < 0.7 THEN 'medium'
+ ELSE 'high'
+ END as conf_level,
+ COUNT(*) as count
+ FROM detections
+ WHERE 1=1{date_filter}
+ GROUP BY conf_level
+ """,
+ params,
+ )
+ conf_dist = {row["conf_level"]: row["count"] for row in cursor.fetchall()}
+
+ return {
+ "total_detections": total_count,
+ "class_counts": class_counts,
+ "average_confidence": avg_confidence,
+ "confidence_distribution": conf_dist,
+ }
+ finally:
+ conn.close()
+
+ def get_class_distribution(self, model_id: Optional[int] = None) -> Dict[str, int]:
+ """Get count of detections per class."""
+ conn = self.get_connection()
+ try:
+ cursor = conn.cursor()
+ query = "SELECT class_name, COUNT(*) as count FROM detections"
+ params = []
+
+ if model_id:
+ query += " WHERE model_id = ?"
+ params.append(model_id)
+
+ query += " GROUP BY class_name ORDER BY count DESC"
+
+ cursor.execute(query, params)
+ return {row["class_name"]: row["count"] for row in cursor.fetchall()}
+ finally:
+ conn.close()
+
+ # ==================== Export Operations ====================
+
+ def export_detections_to_csv(
+ self, output_path: str, filters: Optional[Dict] = None
+ ) -> bool:
+ """Export detections to CSV file."""
+ try:
+ detections = self.get_detections(filters)
+
+ with open(output_path, "w", newline="") as csvfile:
+ if not detections:
+ return True
+
+ fieldnames = [
+ "id",
+ "image_path",
+ "model_name",
+ "model_version",
+ "class_name",
+ "x_min",
+ "y_min",
+ "x_max",
+ "y_max",
+ "confidence",
+ "detected_at",
+ ]
+ writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+ writer.writeheader()
+
+ for det in detections:
+ row = {k: det[k] for k in fieldnames if k in det}
+ writer.writerow(row)
+
+ return True
+ except Exception as e:
+ print(f"Error exporting to CSV: {e}")
+ return False
+
+ def export_detections_to_json(
+ self, output_path: str, filters: Optional[Dict] = None
+ ) -> bool:
+ """Export detections to JSON file."""
+ try:
+ detections = self.get_detections(filters)
+
+ # Convert datetime objects to strings
+ for det in detections:
+ if isinstance(det.get("detected_at"), datetime):
+ det["detected_at"] = det["detected_at"].isoformat()
+
+ with open(output_path, "w") as jsonfile:
+ json.dump(detections, jsonfile, indent=2)
+
+ return True
+ except Exception as e:
+ print(f"Error exporting to JSON: {e}")
+ return False
+
+ # ==================== Annotation Operations ====================
+
+ def add_annotation(
+ self,
+ image_id: int,
+ class_name: str,
+ bbox: Tuple[float, float, float, float],
+ annotator: str,
+ verified: bool = False,
+ ) -> int:
+ """Add manual annotation."""
+ conn = self.get_connection()
+ try:
+ cursor = conn.cursor()
+ x_min, y_min, x_max, y_max = bbox
+ cursor.execute(
+ """
+ INSERT INTO annotations (image_id, class_name, x_min, y_min, x_max, y_max, annotator, verified)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+ """,
+ (image_id, class_name, x_min, y_min, x_max, y_max, annotator, verified),
+ )
+ conn.commit()
+ return cursor.lastrowid
+ finally:
+ conn.close()
+
+ def get_annotations_for_image(self, image_id: int) -> List[Dict]:
+ """Get all annotations for an image."""
+ conn = self.get_connection()
+ try:
+ cursor = conn.cursor()
+ cursor.execute("SELECT * FROM annotations WHERE image_id = ?", (image_id,))
+ return [dict(row) for row in cursor.fetchall()]
+ finally:
+ conn.close()
+
+ @staticmethod
+ def calculate_checksum(file_path: str) -> str:
+ """Calculate MD5 checksum of a file."""
+ md5_hash = hashlib.md5()
+ with open(file_path, "rb") as f:
+ for chunk in iter(lambda: f.read(4096), b""):
+ md5_hash.update(chunk)
+ return md5_hash.hexdigest()
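+
+
+# Example usage (a minimal sketch; paths, names, and values are illustrative):
+#
+#   db = DatabaseManager("data/detections.db")
+#   model_id = db.add_model("organelle_detector", "1.0.0", "data/models/best.pt")
+#   image_id = db.get_or_create_image("plate1/img_001.png", "img_001.png", 1024, 768)
+#   db.add_detection(image_id, model_id, "mitochondria", (0.10, 0.20, 0.35, 0.40), 0.87)
+#   stats = db.get_detection_statistics()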
diff --git a/src/database/models.py b/src/database/models.py
new file mode 100644
index 0000000..cdcd237
--- /dev/null
+++ b/src/database/models.py
@@ -0,0 +1,63 @@
+"""
+Data models for the microscopy object detection application.
+These dataclasses represent the database entities.
+"""
+
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Optional, Dict, Tuple
+
+
+@dataclass
+class Model:
+ """Represents a trained model."""
+
+ id: Optional[int]
+ model_name: str
+ model_version: str
+ model_path: str
+ base_model: str
+ created_at: datetime
+ training_params: Optional[Dict]
+ metrics: Optional[Dict]
+
+
+@dataclass
+class Image:
+ """Represents an image in the database."""
+
+ id: Optional[int]
+ relative_path: str
+ filename: str
+ width: int
+ height: int
+ captured_at: Optional[datetime]
+ added_at: datetime
+ checksum: Optional[str]
+
+
+@dataclass
+class Detection:
+ """Represents a detection result."""
+
+ id: Optional[int]
+ image_id: int
+ model_id: int
+ class_name: str
+ bbox: Tuple[float, float, float, float] # (x_min, y_min, x_max, y_max)
+ confidence: float
+ detected_at: datetime
+ metadata: Optional[Dict]
+
+
+@dataclass
+class Annotation:
+ """Represents a manual annotation."""
+
+ id: Optional[int]
+ image_id: int
+ class_name: str
+ bbox: Tuple[float, float, float, float] # (x_min, y_min, x_max, y_max)
+ annotator: str
+ created_at: datetime
+ verified: bool
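+
+
+# Example (illustrative values): building a Detection record in memory.
+#
+#   det = Detection(
+#       id=None, image_id=1, model_id=2, class_name="mitochondria",
+#       bbox=(0.10, 0.20, 0.35, 0.40), confidence=0.92,
+#       detected_at=datetime.now(), metadata=None,
+#   )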
diff --git a/src/database/schema.sql b/src/database/schema.sql
new file mode 100644
index 0000000..f6080f7
--- /dev/null
+++ b/src/database/schema.sql
@@ -0,0 +1,70 @@
+-- Microscopy Object Detection Application - Database Schema
+-- SQLite Database Schema for storing models, images, detections, and annotations
+
+-- Models table: stores trained model information
+CREATE TABLE IF NOT EXISTS models (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ model_name TEXT NOT NULL,
+ model_version TEXT NOT NULL,
+ model_path TEXT NOT NULL,
+ base_model TEXT NOT NULL DEFAULT 'yolov8s.pt',
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ training_params TEXT, -- JSON string of training parameters
+ metrics TEXT, -- JSON string of validation metrics
+ UNIQUE(model_name, model_version)
+);
+
+-- Images table: stores image metadata
+CREATE TABLE IF NOT EXISTS images (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ relative_path TEXT NOT NULL UNIQUE,
+ filename TEXT NOT NULL,
+ width INTEGER,
+ height INTEGER,
+ captured_at TIMESTAMP,
+ added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ checksum TEXT
+);
+
+-- Detections table: stores detection results
+CREATE TABLE IF NOT EXISTS detections (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ image_id INTEGER NOT NULL,
+ model_id INTEGER NOT NULL,
+ class_name TEXT NOT NULL,
+ x_min REAL NOT NULL CHECK(x_min >= 0 AND x_min <= 1),
+ y_min REAL NOT NULL CHECK(y_min >= 0 AND y_min <= 1),
+ x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1),
+ y_max REAL NOT NULL CHECK(y_max >= 0 AND y_max <= 1),
+ confidence REAL NOT NULL CHECK(confidence >= 0 AND confidence <= 1),
+ detected_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ metadata TEXT, -- JSON string for additional metadata
+ FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
+ FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
+);
+
+-- Annotations table: stores manual annotations (future feature)
+CREATE TABLE IF NOT EXISTS annotations (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ image_id INTEGER NOT NULL,
+ class_name TEXT NOT NULL,
+ x_min REAL NOT NULL CHECK(x_min >= 0 AND x_min <= 1),
+ y_min REAL NOT NULL CHECK(y_min >= 0 AND y_min <= 1),
+ x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1),
+ y_max REAL NOT NULL CHECK(y_max >= 0 AND y_max <= 1),
+ annotator TEXT,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ verified BOOLEAN DEFAULT 0,
+ FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE
+);
+
+-- Create indexes for performance optimization
+CREATE INDEX IF NOT EXISTS idx_detections_image_id ON detections(image_id);
+CREATE INDEX IF NOT EXISTS idx_detections_model_id ON detections(model_id);
+CREATE INDEX IF NOT EXISTS idx_detections_class_name ON detections(class_name);
+CREATE INDEX IF NOT EXISTS idx_detections_detected_at ON detections(detected_at);
+CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence);
+CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path);
+CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at);
+CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id);
+CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at);
\ No newline at end of file
diff --git a/src/gui/__init__.py b/src/gui/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/gui/dialogs/__init__.py b/src/gui/dialogs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/gui/dialogs/config_dialog.py b/src/gui/dialogs/config_dialog.py
new file mode 100644
index 0000000..9abfe1b
--- /dev/null
+++ b/src/gui/dialogs/config_dialog.py
@@ -0,0 +1,291 @@
+"""
+Configuration dialog for the microscopy object detection application.
+"""
+
+from PySide6.QtWidgets import (
+ QDialog,
+ QVBoxLayout,
+ QHBoxLayout,
+ QFormLayout,
+ QPushButton,
+ QLineEdit,
+ QSpinBox,
+ QDoubleSpinBox,
+ QFileDialog,
+ QTabWidget,
+ QWidget,
+ QLabel,
+ QGroupBox,
+)
+from PySide6.QtCore import Qt
+
+from src.utils.config_manager import ConfigManager
+from src.utils.logger import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class ConfigDialog(QDialog):
+ """Configuration dialog window."""
+
+ def __init__(self, config_manager: ConfigManager, parent=None):
+ super().__init__(parent)
+
+ self.config_manager = config_manager
+
+ self.setWindowTitle("Settings")
+ self.setMinimumWidth(500)
+ self.setMinimumHeight(400)
+
+ self._setup_ui()
+ self._load_settings()
+
+ def _setup_ui(self):
+ """Setup user interface."""
+ layout = QVBoxLayout()
+
+ # Create tab widget for different setting categories
+ self.tab_widget = QTabWidget()
+
+ # General settings tab
+ general_tab = self._create_general_tab()
+ self.tab_widget.addTab(general_tab, "General")
+
+ # Training settings tab
+ training_tab = self._create_training_tab()
+ self.tab_widget.addTab(training_tab, "Training")
+
+ # Detection settings tab
+ detection_tab = self._create_detection_tab()
+ self.tab_widget.addTab(detection_tab, "Detection")
+
+ layout.addWidget(self.tab_widget)
+
+ # Buttons
+ button_layout = QHBoxLayout()
+ button_layout.addStretch()
+
+ self.save_button = QPushButton("Save")
+ self.save_button.clicked.connect(self.accept)
+ button_layout.addWidget(self.save_button)
+
+ self.cancel_button = QPushButton("Cancel")
+ self.cancel_button.clicked.connect(self.reject)
+ button_layout.addWidget(self.cancel_button)
+
+ layout.addLayout(button_layout)
+
+ self.setLayout(layout)
+
+ def _create_general_tab(self) -> QWidget:
+ """Create general settings tab."""
+ widget = QWidget()
+ layout = QVBoxLayout()
+
+ # Image repository group
+ repo_group = QGroupBox("Image Repository")
+ repo_layout = QFormLayout()
+
+ # Repository path
+ path_layout = QHBoxLayout()
+ self.repo_path_edit = QLineEdit()
+ self.repo_path_edit.setPlaceholderText("Path to image repository")
+ path_layout.addWidget(self.repo_path_edit)
+
+ browse_button = QPushButton("Browse...")
+ browse_button.clicked.connect(self._browse_repository)
+ path_layout.addWidget(browse_button)
+
+ repo_layout.addRow("Base Path:", path_layout)
+ repo_group.setLayout(repo_layout)
+ layout.addWidget(repo_group)
+
+ # Database group
+ db_group = QGroupBox("Database")
+ db_layout = QFormLayout()
+
+ self.db_path_edit = QLineEdit()
+ self.db_path_edit.setPlaceholderText("Path to database file")
+ db_layout.addRow("Database Path:", self.db_path_edit)
+
+ db_group.setLayout(db_layout)
+ layout.addWidget(db_group)
+
+ # Models group
+ models_group = QGroupBox("Models")
+ models_layout = QFormLayout()
+
+ self.models_dir_edit = QLineEdit()
+ self.models_dir_edit.setPlaceholderText("Directory for saved models")
+ models_layout.addRow("Models Directory:", self.models_dir_edit)
+
+ self.base_model_edit = QLineEdit()
+ self.base_model_edit.setPlaceholderText("yolov8s.pt")
+ models_layout.addRow("Default Base Model:", self.base_model_edit)
+
+ models_group.setLayout(models_layout)
+ layout.addWidget(models_group)
+
+ layout.addStretch()
+ widget.setLayout(layout)
+ return widget
+
+ def _create_training_tab(self) -> QWidget:
+ """Create training settings tab."""
+ widget = QWidget()
+ layout = QVBoxLayout()
+
+ form_layout = QFormLayout()
+
+ # Epochs
+ self.epochs_spin = QSpinBox()
+ self.epochs_spin.setRange(1, 1000)
+ self.epochs_spin.setValue(100)
+ form_layout.addRow("Default Epochs:", self.epochs_spin)
+
+ # Batch size
+ self.batch_size_spin = QSpinBox()
+ self.batch_size_spin.setRange(1, 128)
+ self.batch_size_spin.setValue(16)
+ form_layout.addRow("Default Batch Size:", self.batch_size_spin)
+
+ # Image size
+ self.imgsz_spin = QSpinBox()
+ self.imgsz_spin.setRange(320, 1280)
+ self.imgsz_spin.setSingleStep(32)
+ self.imgsz_spin.setValue(640)
+ form_layout.addRow("Default Image Size:", self.imgsz_spin)
+
+ # Patience
+ self.patience_spin = QSpinBox()
+ self.patience_spin.setRange(1, 200)
+ self.patience_spin.setValue(50)
+ form_layout.addRow("Default Patience:", self.patience_spin)
+
+ # Learning rate
+ self.lr_spin = QDoubleSpinBox()
+ self.lr_spin.setRange(0.0001, 0.1)
+ self.lr_spin.setSingleStep(0.001)
+ self.lr_spin.setDecimals(4)
+ self.lr_spin.setValue(0.01)
+ form_layout.addRow("Default Learning Rate:", self.lr_spin)
+
+ layout.addLayout(form_layout)
+ layout.addStretch()
+ widget.setLayout(layout)
+ return widget
+
+ def _create_detection_tab(self) -> QWidget:
+ """Create detection settings tab."""
+ widget = QWidget()
+ layout = QVBoxLayout()
+
+ form_layout = QFormLayout()
+
+ # Confidence threshold
+ self.conf_spin = QDoubleSpinBox()
+ self.conf_spin.setRange(0.0, 1.0)
+ self.conf_spin.setSingleStep(0.05)
+ self.conf_spin.setDecimals(2)
+ self.conf_spin.setValue(0.25)
+ form_layout.addRow("Default Confidence:", self.conf_spin)
+
+ # IoU threshold
+ self.iou_spin = QDoubleSpinBox()
+ self.iou_spin.setRange(0.0, 1.0)
+ self.iou_spin.setSingleStep(0.05)
+ self.iou_spin.setDecimals(2)
+ self.iou_spin.setValue(0.45)
+ form_layout.addRow("Default IoU:", self.iou_spin)
+
+ # Max batch size
+ self.max_batch_spin = QSpinBox()
+ self.max_batch_spin.setRange(1, 1000)
+ self.max_batch_spin.setValue(100)
+ form_layout.addRow("Max Batch Size:", self.max_batch_spin)
+
+ layout.addLayout(form_layout)
+ layout.addStretch()
+ widget.setLayout(layout)
+ return widget
+
+ def _browse_repository(self):
+ """Browse for image repository directory."""
+ directory = QFileDialog.getExistingDirectory(
+ self, "Select Image Repository", self.repo_path_edit.text()
+ )
+
+ if directory:
+ self.repo_path_edit.setText(directory)
+
+ def _load_settings(self):
+ """Load current settings into dialog."""
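+        # ConfigManager.get() is assumed here to resolve dotted key paths (e.g.
+        # "training.default_epochs") and to return the supplied default when unset.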
+ # General settings
+ self.repo_path_edit.setText(
+ self.config_manager.get("image_repository.base_path", "")
+ )
+ self.db_path_edit.setText(
+ self.config_manager.get("database.path", "data/detections.db")
+ )
+ self.models_dir_edit.setText(
+ self.config_manager.get("models.models_directory", "data/models")
+ )
+ self.base_model_edit.setText(
+ self.config_manager.get("models.default_base_model", "yolov8s.pt")
+ )
+
+ # Training settings
+ self.epochs_spin.setValue(
+ self.config_manager.get("training.default_epochs", 100)
+ )
+ self.batch_size_spin.setValue(
+ self.config_manager.get("training.default_batch_size", 16)
+ )
+ self.imgsz_spin.setValue(self.config_manager.get("training.default_imgsz", 640))
+ self.patience_spin.setValue(
+ self.config_manager.get("training.default_patience", 50)
+ )
+ self.lr_spin.setValue(self.config_manager.get("training.default_lr0", 0.01))
+
+ # Detection settings
+ self.conf_spin.setValue(
+ self.config_manager.get("detection.default_confidence", 0.25)
+ )
+ self.iou_spin.setValue(self.config_manager.get("detection.default_iou", 0.45))
+ self.max_batch_spin.setValue(
+ self.config_manager.get("detection.max_batch_size", 100)
+ )
+
+ def accept(self):
+ """Save settings and close dialog."""
+ logger.info("Saving configuration")
+
+ # Save general settings
+ self.config_manager.set(
+ "image_repository.base_path", self.repo_path_edit.text()
+ )
+ self.config_manager.set("database.path", self.db_path_edit.text())
+ self.config_manager.set("models.models_directory", self.models_dir_edit.text())
+ self.config_manager.set(
+ "models.default_base_model", self.base_model_edit.text()
+ )
+
+ # Save training settings
+ self.config_manager.set("training.default_epochs", self.epochs_spin.value())
+ self.config_manager.set(
+ "training.default_batch_size", self.batch_size_spin.value()
+ )
+ self.config_manager.set("training.default_imgsz", self.imgsz_spin.value())
+ self.config_manager.set("training.default_patience", self.patience_spin.value())
+ self.config_manager.set("training.default_lr0", self.lr_spin.value())
+
+ # Save detection settings
+ self.config_manager.set("detection.default_confidence", self.conf_spin.value())
+ self.config_manager.set("detection.default_iou", self.iou_spin.value())
+ self.config_manager.set("detection.max_batch_size", self.max_batch_spin.value())
+
+ # Save to file
+ self.config_manager.save_config()
+
+ super().accept()
diff --git a/src/gui/main_window.py b/src/gui/main_window.py
new file mode 100644
index 0000000..71e322b
--- /dev/null
+++ b/src/gui/main_window.py
@@ -0,0 +1,282 @@
+"""
+Main window for the microscopy object detection application.
+"""
+
+from PySide6.QtWidgets import (
+ QMainWindow,
+ QTabWidget,
+ QMenuBar,
+ QMenu,
+ QStatusBar,
+ QMessageBox,
+ QWidget,
+ QVBoxLayout,
+ QLabel,
+)
+from PySide6.QtCore import Qt, QTimer
+from PySide6.QtGui import QAction, QKeySequence
+
+from src.database.db_manager import DatabaseManager
+from src.utils.config_manager import ConfigManager
+from src.utils.logger import get_logger
+from src.gui.dialogs.config_dialog import ConfigDialog
+from src.gui.tabs.detection_tab import DetectionTab
+from src.gui.tabs.training_tab import TrainingTab
+from src.gui.tabs.validation_tab import ValidationTab
+from src.gui.tabs.results_tab import ResultsTab
+from src.gui.tabs.annotation_tab import AnnotationTab
+
+
+logger = get_logger(__name__)
+
+
+class MainWindow(QMainWindow):
+ """Main application window."""
+
+ def __init__(self):
+ super().__init__()
+
+ # Initialize managers
+ self.config_manager = ConfigManager()
+
+ db_path = self.config_manager.get_database_path()
+ self.db_manager = DatabaseManager(db_path)
+
+ logger.info("Main window initializing")
+
+ # Setup UI
+ self.setWindowTitle("Microscopy Object Detection")
+ self.setMinimumSize(1200, 800)
+
+ self._create_menu_bar()
+ self._create_tab_widget()
+ self._create_status_bar()
+
+ # Center window on screen
+ self._center_window()
+
+ logger.info("Main window initialized")
+
+ def _create_menu_bar(self):
+ """Create application menu bar."""
+ menubar = self.menuBar()
+
+ # File menu
+ file_menu = menubar.addMenu("&File")
+
+ settings_action = QAction("&Settings", self)
+ settings_action.setShortcut(QKeySequence("Ctrl+,"))
+ settings_action.triggered.connect(self._show_settings)
+ file_menu.addAction(settings_action)
+
+ file_menu.addSeparator()
+
+ exit_action = QAction("E&xit", self)
+ exit_action.setShortcut(QKeySequence("Ctrl+Q"))
+ exit_action.triggered.connect(self.close)
+ file_menu.addAction(exit_action)
+
+ # View menu
+ view_menu = menubar.addMenu("&View")
+
+ refresh_action = QAction("&Refresh", self)
+ refresh_action.setShortcut(QKeySequence("F5"))
+ refresh_action.triggered.connect(self._refresh_current_tab)
+ view_menu.addAction(refresh_action)
+
+ # Tools menu
+ tools_menu = menubar.addMenu("&Tools")
+
+ db_stats_action = QAction("Database &Statistics", self)
+ db_stats_action.triggered.connect(self._show_database_stats)
+ tools_menu.addAction(db_stats_action)
+
+ # Help menu
+ help_menu = menubar.addMenu("&Help")
+
+ about_action = QAction("&About", self)
+ about_action.triggered.connect(self._show_about)
+ help_menu.addAction(about_action)
+
+ docs_action = QAction("&Documentation", self)
+ docs_action.triggered.connect(self._show_documentation)
+ help_menu.addAction(docs_action)
+
+ def _create_tab_widget(self):
+ """Create main tab widget with all tabs."""
+ self.tab_widget = QTabWidget()
+ self.tab_widget.setTabPosition(QTabWidget.North)
+
+ # Create tabs
+ try:
+ self.detection_tab = DetectionTab(self.db_manager, self.config_manager)
+ self.training_tab = TrainingTab(self.db_manager, self.config_manager)
+ self.validation_tab = ValidationTab(self.db_manager, self.config_manager)
+ self.results_tab = ResultsTab(self.db_manager, self.config_manager)
+ self.annotation_tab = AnnotationTab(self.db_manager, self.config_manager)
+
+ # Add tabs to widget
+ self.tab_widget.addTab(self.detection_tab, "Detection")
+ self.tab_widget.addTab(self.training_tab, "Training")
+ self.tab_widget.addTab(self.validation_tab, "Validation")
+ self.tab_widget.addTab(self.results_tab, "Results")
+ self.tab_widget.addTab(self.annotation_tab, "Annotation (Future)")
+
+ # Connect tab change signal
+ self.tab_widget.currentChanged.connect(self._on_tab_changed)
+
+ except Exception as e:
+ logger.error(f"Error creating tabs: {e}")
+ # Create placeholder
+ placeholder = QWidget()
+ layout = QVBoxLayout()
+ layout.addWidget(QLabel(f"Error creating tabs: {e}"))
+ placeholder.setLayout(layout)
+ self.tab_widget.addTab(placeholder, "Error")
+
+ self.setCentralWidget(self.tab_widget)
+
+ def _create_status_bar(self):
+ """Create status bar."""
+ self.status_bar = QStatusBar()
+ self.setStatusBar(self.status_bar)
+
+ # Add permanent widgets to status bar
+ self.status_label = QLabel("Ready")
+ self.status_bar.addWidget(self.status_label)
+
+ # Initial status message
+ self._update_status("Ready")
+
+ def _center_window(self):
+ """Center window on screen."""
+ screen = self.screen().geometry()
+ size = self.geometry()
+ self.move(
+ (screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2
+ )
+
+ def _show_settings(self):
+ """Show settings dialog."""
+ logger.info("Opening settings dialog")
+ dialog = ConfigDialog(self.config_manager, self)
+ if dialog.exec():
+ self._apply_settings()
+ self._update_status("Settings saved")
+
+ def _apply_settings(self):
+ """Apply changed settings."""
+ logger.info("Applying settings changes")
+ # Reload configuration in all tabs if needed
+ try:
+ if hasattr(self, "detection_tab"):
+ self.detection_tab.refresh()
+ if hasattr(self, "training_tab"):
+ self.training_tab.refresh()
+ if hasattr(self, "results_tab"):
+ self.results_tab.refresh()
+ except Exception as e:
+ logger.error(f"Error applying settings: {e}")
+
+ def _refresh_current_tab(self):
+ """Refresh the current tab."""
+ current_widget = self.tab_widget.currentWidget()
+ if hasattr(current_widget, "refresh"):
+ current_widget.refresh()
+ self._update_status("Tab refreshed")
+
+ def _on_tab_changed(self, index: int):
+ """Handle tab change event."""
+ tab_name = self.tab_widget.tabText(index)
+ logger.debug(f"Switched to tab: {tab_name}")
+ self._update_status(f"Viewing: {tab_name}")
+
+ def _show_database_stats(self):
+ """Show database statistics dialog."""
+ try:
+ stats = self.db_manager.get_detection_statistics()
+
+            message = (
+                "Database Statistics\n\n"
+                f"Total Detections: {stats.get('total_detections', 0)}\n"
+                f"Average Confidence: {stats.get('average_confidence', 0):.2%}\n\n"
+                "Classes:\n"
+            )
+
+            for class_name, count in stats.get("class_counts", {}).items():
+                message += f"- {class_name}: {count}\n"
+
+ QMessageBox.information(self, "Database Statistics", message)
+
+ except Exception as e:
+ logger.error(f"Error getting database stats: {e}")
+ QMessageBox.warning(
+ self, "Error", f"Failed to get database statistics:\n{str(e)}"
+ )
+
+ def _show_about(self):
+ """Show about dialog."""
+        about_text = """
+        Microscopy Object Detection Application
+        Version: 1.0.0
+
+        A desktop application for detecting organelles and membrane branching
+        structures in microscopy images using YOLOv8.
+
+        Features:
+        - Object detection with YOLOv8
+        - Model training and validation
+        - Detection results storage
+        - Interactive visualization
+        - Export capabilities
+
+        Technologies:
+        - Ultralytics YOLOv8
+        - PySide6
+        - pyqtgraph
+        - SQLite
+        """
+
+ QMessageBox.about(self, "About", about_text)
+
+ def _show_documentation(self):
+ """Show documentation."""
+ QMessageBox.information(
+ self,
+ "Documentation",
+ "Please refer to README.md and ARCHITECTURE.md files in the project directory.",
+ )
+
+ def _update_status(self, message: str, timeout: int = 5000):
+ """
+ Update status bar message.
+
+ Args:
+ message: Status message to display
+ timeout: Time in milliseconds to show message (0 for permanent)
+ """
+ self.status_label.setText(message)
+ if timeout > 0:
+ QTimer.singleShot(timeout, lambda: self.status_label.setText("Ready"))
+
+ def closeEvent(self, event):
+ """Handle window close event."""
+ reply = QMessageBox.question(
+ self,
+ "Confirm Exit",
+ "Are you sure you want to exit?",
+ QMessageBox.Yes | QMessageBox.No,
+ QMessageBox.No,
+ )
+
+ if reply == QMessageBox.Yes:
+ logger.info("Application closing")
+ event.accept()
+ else:
+ event.ignore()
diff --git a/src/gui/tabs/__init__.py b/src/gui/tabs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/gui/tabs/annotation_tab.py b/src/gui/tabs/annotation_tab.py
new file mode 100644
index 0000000..7b2f5fc
--- /dev/null
+++ b/src/gui/tabs/annotation_tab.py
@@ -0,0 +1,48 @@
+"""
+Annotation tab for the microscopy object detection application.
+Future feature for manual annotation.
+"""
+
+from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
+
+from src.database.db_manager import DatabaseManager
+from src.utils.config_manager import ConfigManager
+
+
+class AnnotationTab(QWidget):
+ """Annotation tab placeholder (future feature)."""
+
+ def __init__(
+ self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
+ ):
+ super().__init__(parent)
+ self.db_manager = db_manager
+ self.config_manager = config_manager
+
+ self._setup_ui()
+
+ def _setup_ui(self):
+ """Setup user interface."""
+ layout = QVBoxLayout()
+
+ group = QGroupBox("Annotation Tool (Future Feature)")
+ group_layout = QVBoxLayout()
+ label = QLabel(
+            "Annotation functionality will be implemented in a future version.\n\n"
+ "Planned Features:\n"
+ "- Image browser\n"
+ "- Drawing tools for bounding boxes\n"
+ "- Class label assignment\n"
+ "- Export annotations to YOLO format\n"
+ "- Annotation verification"
+ )
+ group_layout.addWidget(label)
+ group.setLayout(group_layout)
+
+ layout.addWidget(group)
+ layout.addStretch()
+ self.setLayout(layout)
+
+ def refresh(self):
+ """Refresh the tab."""
+ pass
diff --git a/src/gui/tabs/detection_tab.py b/src/gui/tabs/detection_tab.py
new file mode 100644
index 0000000..4fe71ce
--- /dev/null
+++ b/src/gui/tabs/detection_tab.py
@@ -0,0 +1,344 @@
+"""
+Detection tab for the microscopy object detection application.
+Handles single image and batch detection.
+"""
+
+from PySide6.QtWidgets import (
+ QWidget,
+ QVBoxLayout,
+ QHBoxLayout,
+ QPushButton,
+ QLabel,
+ QComboBox,
+ QSlider,
+ QFileDialog,
+ QMessageBox,
+ QProgressBar,
+ QTextEdit,
+ QGroupBox,
+ QFormLayout,
+)
+from PySide6.QtCore import Qt, QThread, Signal
+from pathlib import Path
+
+from src.database.db_manager import DatabaseManager
+from src.utils.config_manager import ConfigManager
+from src.utils.logger import get_logger
+from src.utils.file_utils import get_image_files
+from src.model.inference import InferenceEngine
+
+
+logger = get_logger(__name__)
+
+
+class DetectionWorker(QThread):
+ """Worker thread for running detection."""
+
+ progress = Signal(int, int, str) # current, total, message
+ finished = Signal(list) # results
+ error = Signal(str) # error message
+
+ def __init__(self, engine, image_paths, repo_root, conf):
+ super().__init__()
+ self.engine = engine
+ self.image_paths = image_paths
+ self.repo_root = repo_root
+ self.conf = conf
+
+ def run(self):
+ """Run detection in background thread."""
+ try:
+ results = self.engine.detect_batch(
+ self.image_paths, self.repo_root, self.conf, self.progress.emit
+ )
+ self.finished.emit(results)
+ except Exception as e:
+ logger.error(f"Detection error: {e}")
+ self.error.emit(str(e))
+
+
+class DetectionTab(QWidget):
+ """Detection tab for single image and batch detection."""
+
+ def __init__(
+ self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
+ ):
+ super().__init__(parent)
+ self.db_manager = db_manager
+ self.config_manager = config_manager
+ self.inference_engine = None
+ self.current_model_id = None
+
+ self._setup_ui()
+ self._connect_signals()
+ self._load_models()
+
+ def _setup_ui(self):
+ """Setup user interface."""
+ layout = QVBoxLayout()
+
+ # Model selection group
+ model_group = QGroupBox("Model Selection")
+ model_layout = QFormLayout()
+
+ self.model_combo = QComboBox()
+ self.model_combo.addItem("No models available", None)
+ model_layout.addRow("Model:", self.model_combo)
+
+ model_group.setLayout(model_layout)
+ layout.addWidget(model_group)
+
+ # Detection settings group
+ settings_group = QGroupBox("Detection Settings")
+ settings_layout = QFormLayout()
+
+ # Confidence threshold
+ conf_layout = QHBoxLayout()
+ self.conf_slider = QSlider(Qt.Horizontal)
+ self.conf_slider.setRange(0, 100)
+ self.conf_slider.setValue(25)
+ self.conf_slider.setTickPosition(QSlider.TicksBelow)
+ self.conf_slider.setTickInterval(10)
+ conf_layout.addWidget(self.conf_slider)
+
+ self.conf_label = QLabel("0.25")
+ conf_layout.addWidget(self.conf_label)
+
+ settings_layout.addRow("Confidence:", conf_layout)
+ settings_group.setLayout(settings_layout)
+ layout.addWidget(settings_group)
+
+ # Action buttons
+ button_layout = QHBoxLayout()
+
+ self.single_image_btn = QPushButton("Detect Single Image")
+ self.single_image_btn.clicked.connect(self._detect_single_image)
+ button_layout.addWidget(self.single_image_btn)
+
+ self.batch_btn = QPushButton("Detect Batch (Folder)")
+ self.batch_btn.clicked.connect(self._detect_batch)
+ button_layout.addWidget(self.batch_btn)
+
+ layout.addLayout(button_layout)
+
+ # Progress bar
+ self.progress_bar = QProgressBar()
+ self.progress_bar.setVisible(False)
+ layout.addWidget(self.progress_bar)
+
+ # Results display
+ results_group = QGroupBox("Detection Results")
+ results_layout = QVBoxLayout()
+
+ self.results_text = QTextEdit()
+ self.results_text.setReadOnly(True)
+ self.results_text.setMaximumHeight(200)
+ results_layout.addWidget(self.results_text)
+
+ results_group.setLayout(results_layout)
+ layout.addWidget(results_group)
+
+ layout.addStretch()
+ self.setLayout(layout)
+
+ def _connect_signals(self):
+ """Connect signals and slots."""
+ self.conf_slider.valueChanged.connect(self._update_confidence_label)
+ self.model_combo.currentIndexChanged.connect(self._on_model_changed)
+
+ def _load_models(self):
+ """Load available models from database."""
+ try:
+ models = self.db_manager.get_models()
+ self.model_combo.clear()
+
+ if not models:
+ self.model_combo.addItem("No models available", None)
+ self._set_buttons_enabled(False)
+ return
+
+ # Add base model option
+ base_model = self.config_manager.get(
+ "models.default_base_model", "yolov8s.pt"
+ )
+ self.model_combo.addItem(
+ f"Base Model ({base_model})", {"id": 0, "path": base_model}
+ )
+
+ # Add trained models
+ for model in models:
+ display_name = f"{model['model_name']} v{model['model_version']}"
+ self.model_combo.addItem(display_name, model)
+
+ self._set_buttons_enabled(True)
+
+ except Exception as e:
+ logger.error(f"Error loading models: {e}")
+ QMessageBox.warning(self, "Error", f"Failed to load models:\n{str(e)}")
+
+ def _on_model_changed(self, index: int):
+ """Handle model selection change."""
+ model_data = self.model_combo.itemData(index)
+ if model_data and model_data["id"] != 0:
+ self.current_model_id = model_data["id"]
+ else:
+ self.current_model_id = None
+
+ def _update_confidence_label(self, value: int):
+ """Update confidence label."""
+ conf = value / 100.0
+ self.conf_label.setText(f"{conf:.2f}")
+
+ def _detect_single_image(self):
+ """Detect objects in a single image."""
+ # Get image file
+ repo_path = self.config_manager.get_image_repository_path()
+ start_dir = repo_path if repo_path else ""
+
+ file_path, _ = QFileDialog.getOpenFileName(
+ self,
+ "Select Image",
+ start_dir,
+ "Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
+ )
+
+ if not file_path:
+ return
+
+ # Run detection
+ self._run_detection([file_path])
+
+ def _detect_batch(self):
+ """Detect objects in batch (folder)."""
+ # Get folder
+ repo_path = self.config_manager.get_image_repository_path()
+ start_dir = repo_path if repo_path else ""
+
+ folder_path = QFileDialog.getExistingDirectory(self, "Select Folder", start_dir)
+
+ if not folder_path:
+ return
+
+ # Get all image files
+ allowed_ext = self.config_manager.get_allowed_extensions()
+ image_files = get_image_files(folder_path, allowed_ext, recursive=False)
+
+ if not image_files:
+ QMessageBox.information(
+ self, "No Images", "No image files found in selected folder."
+ )
+ return
+
+ # Confirm batch processing
+ reply = QMessageBox.question(
+ self,
+ "Confirm Batch Detection",
+ f"Process {len(image_files)} images?",
+ QMessageBox.Yes | QMessageBox.No,
+ )
+
+ if reply == QMessageBox.Yes:
+ self._run_detection(image_files)
+
+ def _run_detection(self, image_paths: list):
+ """Run detection on image list."""
+ try:
+ # Get selected model
+ model_data = self.model_combo.currentData()
+ if not model_data:
+ QMessageBox.warning(self, "No Model", "Please select a model first.")
+ return
+
+ model_path = model_data["path"]
+ model_id = model_data["id"]
+
+ # Ensure we have a valid model ID (create entry for base model if needed)
+ if model_id == 0:
+ # Create database entry for base model
+ base_model = self.config_manager.get(
+ "models.default_base_model", "yolov8s.pt"
+ )
+ model_id = self.db_manager.add_model(
+ model_name="Base Model",
+ model_version="pretrained",
+ model_path=base_model,
+ base_model=base_model,
+ )
+
+ # Create inference engine
+ self.inference_engine = InferenceEngine(
+ model_path, self.db_manager, model_id
+ )
+
+ # Get confidence threshold
+ conf = self.conf_slider.value() / 100.0
+
+ # Get repository root
+ repo_root = self.config_manager.get_image_repository_path()
+ if not repo_root:
+ repo_root = str(Path(image_paths[0]).parent)
+
+ # Show progress bar
+ self.progress_bar.setVisible(True)
+ self.progress_bar.setMaximum(len(image_paths))
+ self._set_buttons_enabled(False)
+
+ # Create and start worker thread
+ self.worker = DetectionWorker(
+ self.inference_engine, image_paths, repo_root, conf
+ )
+ self.worker.progress.connect(self._on_progress)
+ self.worker.finished.connect(self._on_detection_finished)
+ self.worker.error.connect(self._on_detection_error)
+ self.worker.start()
+
+ except Exception as e:
+ logger.error(f"Error starting detection: {e}")
+ QMessageBox.critical(self, "Error", f"Failed to start detection:\n{str(e)}")
+ self._set_buttons_enabled(True)
+
+ def _on_progress(self, current: int, total: int, message: str):
+ """Handle progress update."""
+ self.progress_bar.setValue(current)
+ self.results_text.append(f"[{current}/{total}] {message}")
+
+ def _on_detection_finished(self, results: list):
+ """Handle detection completion."""
+ self.progress_bar.setVisible(False)
+ self._set_buttons_enabled(True)
+
+ # Calculate statistics
+ total_detections = sum(r["count"] for r in results)
+ successful = sum(1 for r in results if r.get("success", False))
+
+ summary = f"\n=== Detection Complete ===\n"
+ summary += f"Processed: {len(results)} images\n"
+ summary += f"Successful: {successful}\n"
+ summary += f"Total detections: {total_detections}\n"
+
+ self.results_text.append(summary)
+
+ QMessageBox.information(
+ self,
+ "Detection Complete",
+ f"Processed {len(results)} images\n{total_detections} objects detected",
+ )
+
+ def _on_detection_error(self, error_msg: str):
+ """Handle detection error."""
+ self.progress_bar.setVisible(False)
+ self._set_buttons_enabled(True)
+
+ self.results_text.append(f"\nERROR: {error_msg}")
+ QMessageBox.critical(self, "Detection Error", error_msg)
+
+ def _set_buttons_enabled(self, enabled: bool):
+ """Enable/disable action buttons."""
+ self.single_image_btn.setEnabled(enabled)
+ self.batch_btn.setEnabled(enabled)
+ self.model_combo.setEnabled(enabled)
+
+ def refresh(self):
+ """Refresh the tab."""
+ self._load_models()
+ self.results_text.clear()
diff --git a/src/gui/tabs/results_tab.py b/src/gui/tabs/results_tab.py
new file mode 100644
index 0000000..71e523c
--- /dev/null
+++ b/src/gui/tabs/results_tab.py
@@ -0,0 +1,46 @@
+"""
+Results tab for the microscopy object detection application.
+"""
+
+from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
+
+from src.database.db_manager import DatabaseManager
+from src.utils.config_manager import ConfigManager
+
+
+class ResultsTab(QWidget):
+ """Results tab placeholder."""
+
+ def __init__(
+ self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
+ ):
+ super().__init__(parent)
+ self.db_manager = db_manager
+ self.config_manager = config_manager
+
+ self._setup_ui()
+
+ def _setup_ui(self):
+ """Setup user interface."""
+ layout = QVBoxLayout()
+
+ group = QGroupBox("Results")
+ group_layout = QVBoxLayout()
+ label = QLabel(
+ "Results viewer will be implemented here.\n\n"
+ "Features:\n"
+ "- Detection history browser\n"
+ "- Advanced filtering\n"
+ "- Statistics dashboard\n"
+ "- Export functionality"
+ )
+ group_layout.addWidget(label)
+ group.setLayout(group_layout)
+
+ layout.addWidget(group)
+ layout.addStretch()
+ self.setLayout(layout)
+
+ def refresh(self):
+ """Refresh the tab."""
+ pass
diff --git a/src/gui/tabs/training_tab.py b/src/gui/tabs/training_tab.py
new file mode 100644
index 0000000..8c3e778
--- /dev/null
+++ b/src/gui/tabs/training_tab.py
@@ -0,0 +1,52 @@
+"""
+Training tab for the microscopy object detection application.
+Handles model training with YOLO.
+"""
+
+from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
+
+from src.database.db_manager import DatabaseManager
+from src.utils.config_manager import ConfigManager
+from src.utils.logger import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class TrainingTab(QWidget):
+ """Training tab for model training."""
+
+ def __init__(
+ self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
+ ):
+ super().__init__(parent)
+ self.db_manager = db_manager
+ self.config_manager = config_manager
+
+ self._setup_ui()
+
+ def _setup_ui(self):
+ """Setup user interface."""
+ layout = QVBoxLayout()
+
+ # Placeholder
+ group = QGroupBox("Training")
+ group_layout = QVBoxLayout()
+ label = QLabel(
+ "Training functionality will be implemented here.\n\n"
+ "Features:\n"
+ "- Dataset selection\n"
+ "- Training parameter configuration\n"
+ "- Real-time training progress\n"
+ "- Loss and metric visualization"
+ )
+ group_layout.addWidget(label)
+ group.setLayout(group_layout)
+
+ layout.addWidget(group)
+ layout.addStretch()
+ self.setLayout(layout)
+
+ def refresh(self):
+ """Refresh the tab."""
+ pass
diff --git a/src/gui/tabs/validation_tab.py b/src/gui/tabs/validation_tab.py
new file mode 100644
index 0000000..2e7749a
--- /dev/null
+++ b/src/gui/tabs/validation_tab.py
@@ -0,0 +1,46 @@
+"""
+Validation tab for the microscopy object detection application.
+"""
+
+from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
+
+from src.database.db_manager import DatabaseManager
+from src.utils.config_manager import ConfigManager
+
+
+class ValidationTab(QWidget):
+ """Validation tab placeholder."""
+
+ def __init__(
+ self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
+ ):
+ super().__init__(parent)
+ self.db_manager = db_manager
+ self.config_manager = config_manager
+
+ self._setup_ui()
+
+ def _setup_ui(self):
+ """Setup user interface."""
+ layout = QVBoxLayout()
+
+ group = QGroupBox("Validation")
+ group_layout = QVBoxLayout()
+ label = QLabel(
+ "Validation functionality will be implemented here.\n\n"
+ "Features:\n"
+ "- Model validation\n"
+ "- Metrics visualization\n"
+ "- Confusion matrix\n"
+ "- Precision-Recall curves"
+ )
+ group_layout.addWidget(label)
+ group.setLayout(group_layout)
+
+ layout.addWidget(group)
+ layout.addStretch()
+ self.setLayout(layout)
+
+ def refresh(self):
+ """Refresh the tab."""
+ pass
diff --git a/src/gui/widgets/__init__.py b/src/gui/widgets/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/model/__init__.py b/src/model/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/model/inference.py b/src/model/inference.py
new file mode 100644
index 0000000..1fc5ab8
--- /dev/null
+++ b/src/model/inference.py
@@ -0,0 +1,323 @@
+"""
+Inference engine for the microscopy object detection application.
+Handles detection inference and result storage.
+"""
+
+from typing import List, Dict, Optional, Callable, Any
+from pathlib import Path
+from PIL import Image
+import cv2
+import numpy as np
+
+from src.model.yolo_wrapper import YOLOWrapper
+from src.database.db_manager import DatabaseManager
+from src.utils.logger import get_logger
+from src.utils.file_utils import get_relative_path
+
+
+logger = get_logger(__name__)
+
+
+class InferenceEngine:
+ """Handles detection inference and result storage."""
+
+ def __init__(self, model_path: str, db_manager: DatabaseManager, model_id: int):
+ """
+ Initialize inference engine.
+
+ Args:
+ model_path: Path to YOLO model weights
+ db_manager: Database manager instance
+ model_id: ID of the model in database
+ """
+ self.yolo = YOLOWrapper(model_path)
+ self.yolo.load_model()
+ self.db_manager = db_manager
+ self.model_id = model_id
+ logger.info(f"InferenceEngine initialized with model_id {model_id}")
+
+ def detect_single(
+ self,
+ image_path: str,
+ relative_path: str,
+ conf: float = 0.25,
+ save_to_db: bool = True,
+ ) -> Dict:
+ """
+ Detect objects in a single image.
+
+ Args:
+ image_path: Absolute path to image file
+ relative_path: Relative path from repository root
+ conf: Confidence threshold
+ save_to_db: Whether to save results to database
+
+ Returns:
+ Dictionary with detection results
+ """
+ try:
+ # Get image dimensions
+ img = Image.open(image_path)
+ width, height = img.size
+ img.close()
+
+ # Perform detection
+ detections = self.yolo.predict(image_path, conf=conf)
+
+ # Add/get image in database
+ image_id = self.db_manager.get_or_create_image(
+ relative_path=relative_path,
+ filename=Path(image_path).name,
+ width=width,
+ height=height,
+ )
+
+ # Save detections to database
+ if save_to_db and detections:
+ detection_records = []
+ for det in detections:
+ # Use normalized bbox from detection
+ bbox_normalized = det[
+ "bbox_normalized"
+ ] # [x_min, y_min, x_max, y_max]
+
+ record = {
+ "image_id": image_id,
+ "model_id": self.model_id,
+ "class_name": det["class_name"],
+ "bbox": tuple(bbox_normalized),
+ "confidence": det["confidence"],
+ "metadata": {"class_id": det["class_id"]},
+ }
+ detection_records.append(record)
+
+ self.db_manager.add_detections_batch(detection_records)
+ logger.info(f"Saved {len(detection_records)} detections to database")
+
+ return {
+ "success": True,
+ "image_path": image_path,
+ "image_id": image_id,
+ "detections": detections,
+ "count": len(detections),
+ }
+
+ except Exception as e:
+ logger.error(f"Error detecting objects in {image_path}: {e}")
+ return {
+ "success": False,
+ "image_path": image_path,
+ "error": str(e),
+ "detections": [],
+ "count": 0,
+ }
+
+ def detect_batch(
+ self,
+ image_paths: List[str],
+ repository_root: str,
+ conf: float = 0.25,
+ progress_callback: Optional[Callable[[int, int, str], None]] = None,
+ ) -> List[Dict]:
+ """
+ Detect objects in multiple images.
+
+ Args:
+ image_paths: List of absolute image paths
+ repository_root: Root directory for relative paths
+ conf: Confidence threshold
+ progress_callback: Optional callback(current, total, message)
+
+ Returns:
+ List of detection result dictionaries
+ """
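+        # progress_callback receives (current_index, total, message); the GUI's
+        # DetectionWorker passes its Qt signal's emit method here directly.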
+ results = []
+ total = len(image_paths)
+
+ logger.info(f"Starting batch detection on {total} images")
+
+ for i, image_path in enumerate(image_paths, 1):
+ # Calculate relative path
+ rel_path = get_relative_path(image_path, repository_root)
+
+ # Perform detection
+ result = self.detect_single(image_path, rel_path, conf)
+ results.append(result)
+
+ # Update progress
+ if progress_callback:
+ progress_callback(i, total, f"Processed {rel_path}")
+
+ if i % 10 == 0:
+ logger.info(f"Processed {i}/{total} images")
+
+ logger.info(f"Batch detection complete: {total} images processed")
+ return results
+
+ def detect_with_visualization(
+ self,
+ image_path: str,
+ conf: float = 0.25,
+ bbox_thickness: int = 2,
+ bbox_colors: Optional[Dict[str, str]] = None,
+ ) -> tuple:
+ """
+ Detect objects and return annotated image.
+
+ Args:
+ image_path: Path to image
+ conf: Confidence threshold
+ bbox_thickness: Thickness of bounding boxes
+ bbox_colors: Dictionary mapping class names to hex colors
+
+ Returns:
+ Tuple of (detections, annotated_image_array)
+ """
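+        # Note: the returned image array is in RGB order (suitable for Qt/pyqtgraph
+        # display); `detections` are the raw dicts from YOLOWrapper.predict().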
+ try:
+ detections = self.yolo.predict(image_path, conf=conf)
+
+ # Load image
+ img = cv2.imread(image_path)
+ if img is None:
+ raise ValueError(f"Failed to load image: {image_path}")
+
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ height, width = img.shape[:2]
+
+ # Default colors if not provided
+ if bbox_colors is None:
+ bbox_colors = {}
+
+ # Draw bounding boxes
+ for det in detections:
+ # Get absolute coordinates
+ bbox_abs = det["bbox_absolute"]
+ x1, y1, x2, y2 = [int(v) for v in bbox_abs]
+
+                # Get color for this class; reverse the BGR tuple to RGB because
+                # the image was converted to RGB above
+                class_name = det["class_name"]
+                color_hex = bbox_colors.get(
+                    class_name, bbox_colors.get("default", "#00FF00")
+                )
+                color = self._hex_to_bgr(color_hex)[::-1]
+
+ # Draw box
+ cv2.rectangle(img, (x1, y1), (x2, y2), color, bbox_thickness)
+
+ # Prepare label
+ label = f"{class_name} {det['confidence']:.2f}"
+
+ # Draw label background
+ (label_w, label_h), baseline = cv2.getTextSize(
+ label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
+ )
+ cv2.rectangle(
+ img,
+ (x1, y1 - label_h - baseline - 5),
+ (x1 + label_w, y1),
+ color,
+ -1,
+ )
+
+ # Draw label text
+ cv2.putText(
+ img,
+ label,
+ (x1, y1 - baseline - 5),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 0.5,
+ (255, 255, 255),
+ 1,
+ )
+
+ return detections, img
+
+ except Exception as e:
+ logger.error(f"Error creating visualization: {e}")
+ # Return empty detections and original image if possible
+ try:
+ img = cv2.imread(image_path)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ return [], img
+            except Exception:
+ return [], np.zeros((480, 640, 3), dtype=np.uint8)
+
+    def get_detection_summary(self, detections: List[Dict]) -> Dict[str, Any]:
+ """
+ Generate summary statistics for detections.
+
+ Args:
+ detections: List of detection dictionaries
+
+ Returns:
+ Dictionary with summary statistics
+ """
+ if not detections:
+ return {
+ "total_count": 0,
+ "class_counts": {},
+ "avg_confidence": 0.0,
+ "confidence_range": (0.0, 0.0),
+ }
+
+ # Count by class
+ class_counts = {}
+ confidences = []
+
+ for det in detections:
+ class_name = det["class_name"]
+ class_counts[class_name] = class_counts.get(class_name, 0) + 1
+ confidences.append(det["confidence"])
+
+ return {
+ "total_count": len(detections),
+ "class_counts": class_counts,
+ "avg_confidence": sum(confidences) / len(confidences),
+ "confidence_range": (min(confidences), max(confidences)),
+ }
+
+ @staticmethod
+ def _hex_to_bgr(hex_color: str) -> tuple:
+ """
+ Convert hex color to BGR tuple.
+
+ Args:
+ hex_color: Hex color string (e.g., '#FF0000')
+
+ Returns:
+ BGR tuple (B, G, R)
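+
+        Example:
+            '#FF0000' -> (0, 0, 255)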
+ """
+ hex_color = hex_color.lstrip("#")
+ if len(hex_color) != 6:
+ return (0, 255, 0) # Default green
+
+ try:
+ r = int(hex_color[0:2], 16)
+ g = int(hex_color[2:4], 16)
+ b = int(hex_color[4:6], 16)
+ return (b, g, r) # OpenCV uses BGR
+ except ValueError:
+ return (0, 255, 0) # Default green
+
+ def change_model(self, model_path: str, model_id: int) -> bool:
+ """
+ Change the current model.
+
+ Args:
+ model_path: Path to new model weights
+ model_id: ID of new model in database
+
+ Returns:
+ True if successful, False otherwise
+ """
+ try:
+ self.yolo = YOLOWrapper(model_path)
+ if self.yolo.load_model():
+ self.model_id = model_id
+ logger.info(f"Model changed to {model_path}")
+ return True
+ return False
+ except Exception as e:
+ logger.error(f"Error changing model: {e}")
+ return False
diff --git a/src/model/yolo_wrapper.py b/src/model/yolo_wrapper.py
new file mode 100644
index 0000000..d4a8050
--- /dev/null
+++ b/src/model/yolo_wrapper.py
@@ -0,0 +1,364 @@
+"""
+YOLO model wrapper for the microscopy object detection application.
+Provides a clean interface to YOLOv8 for training, validation, and inference.
+"""
+
+from ultralytics import YOLO
+from pathlib import Path
+from typing import Optional, List, Dict, Callable, Any
+import torch
+from src.utils.logger import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class YOLOWrapper:
+ """Wrapper for YOLOv8 model operations."""
+
+ def __init__(self, model_path: str = "yolov8s.pt"):
+ """
+ Initialize YOLO model.
+
+ Args:
+ model_path: Path to model weights (.pt file)
+ """
+ self.model_path = model_path
+ self.model = None
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ logger.info(f"YOLOWrapper initialized with device: {self.device}")
+
+ def load_model(self) -> bool:
+ """
+ Load YOLO model from path.
+
+ Returns:
+ True if loaded successfully, False otherwise
+ """
+ try:
+ logger.info(f"Loading YOLO model from {self.model_path}")
+ self.model = YOLO(self.model_path)
+ self.model.to(self.device)
+ logger.info("Model loaded successfully")
+ return True
+ except Exception as e:
+ logger.error(f"Error loading model: {e}")
+ return False
+
+ def train(
+ self,
+ data_yaml: str,
+ epochs: int = 100,
+ imgsz: int = 640,
+ batch: int = 16,
+ patience: int = 50,
+ save_dir: str = "data/models",
+ name: str = "custom_model",
+ resume: bool = False,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ """
+ Train the YOLO model.
+
+ Args:
+ data_yaml: Path to data.yaml configuration file
+ epochs: Number of training epochs
+ imgsz: Input image size
+ batch: Batch size
+ patience: Early stopping patience
+ save_dir: Directory to save trained model
+ name: Name for the training run
+ resume: Resume training from last checkpoint
+ **kwargs: Additional training arguments
+
+ Returns:
+ Dictionary with training results
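+
+        Example (illustrative; the data file and run name are placeholders):
+            >>> wrapper = YOLOWrapper("yolov8s.pt")
+            >>> results = wrapper.train("config/data.yaml", epochs=50, batch=8, name="organelle_v1")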
+ """
+        if self.model is None:
+            if not self.load_model():
+                raise RuntimeError(f"Failed to load model from {self.model_path}")
+
+ try:
+ logger.info(f"Starting training: {name}")
+ logger.info(
+ f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
+ )
+
+ # Train the model
+ results = self.model.train(
+ data=data_yaml,
+ epochs=epochs,
+ imgsz=imgsz,
+ batch=batch,
+ patience=patience,
+ project=save_dir,
+ name=name,
+ device=self.device,
+ resume=resume,
+ **kwargs,
+ )
+
+ logger.info("Training completed successfully")
+ return self._format_training_results(results)
+
+ except Exception as e:
+ logger.error(f"Error during training: {e}")
+ raise
+
+ def validate(self, data_yaml: str, split: str = "val", **kwargs) -> Dict[str, Any]:
+ """
+ Validate the model.
+
+ Args:
+ data_yaml: Path to data.yaml configuration file
+ split: Dataset split to validate on ('val' or 'test')
+ **kwargs: Additional validation arguments
+
+ Returns:
+ Dictionary with validation metrics
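+
+        Example (illustrative; the data file is a placeholder):
+            >>> metrics = wrapper.validate("config/data.yaml", split="test")
+            >>> metrics["mAP50"], metrics["mAP50-95"]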
+ """
+        if self.model is None:
+            if not self.load_model():
+                raise RuntimeError(f"Failed to load model from {self.model_path}")
+
+ try:
+ logger.info(f"Starting validation on {split} split")
+ results = self.model.val(
+ data=data_yaml, split=split, device=self.device, **kwargs
+ )
+
+ logger.info("Validation completed successfully")
+ return self._format_validation_results(results)
+
+ except Exception as e:
+ logger.error(f"Error during validation: {e}")
+ raise
+
+ def predict(
+ self,
+ source: str,
+ conf: float = 0.25,
+ iou: float = 0.45,
+ save: bool = False,
+ save_txt: bool = False,
+ save_conf: bool = False,
+ **kwargs,
+ ) -> List[Dict]:
+ """
+ Perform inference on image(s).
+
+ Args:
+ source: Path to image or directory
+ conf: Confidence threshold
+ iou: IoU threshold for NMS
+ save: Whether to save annotated images
+ save_txt: Whether to save labels to .txt files
+ save_conf: Whether to save confidence in labels
+ **kwargs: Additional prediction arguments
+
+ Returns:
+ List of detection dictionaries
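+
+        Example (illustrative path):
+            >>> detections = wrapper.predict("data/images/sample.png", conf=0.3)
+            >>> detections[0]["class_name"], detections[0]["confidence"]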
+ """
+        if self.model is None:
+            if not self.load_model():
+                raise RuntimeError(f"Failed to load model from {self.model_path}")
+
+ try:
+ logger.info(f"Running inference on {source}")
+ results = self.model.predict(
+ source=source,
+ conf=conf,
+ iou=iou,
+ save=save,
+ save_txt=save_txt,
+ save_conf=save_conf,
+ device=self.device,
+ **kwargs,
+ )
+
+ detections = self._format_prediction_results(results)
+ logger.info(f"Inference complete: {len(detections)} detections")
+ return detections
+
+ except Exception as e:
+ logger.error(f"Error during inference: {e}")
+ raise
+
+ def export(
+ self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
+ ) -> str:
+ """
+ Export model to different format.
+
+ Args:
+ format: Export format (onnx, torchscript, tflite, etc.)
+ output_path: Path for exported model
+ **kwargs: Additional export arguments
+
+ Returns:
+ Path to exported model
+ """
+        if self.model is None:
+            if not self.load_model():
+                raise RuntimeError(f"Failed to load model from {self.model_path}")
+
+ try:
+ logger.info(f"Exporting model to {format} format")
+ export_path = self.model.export(format=format, **kwargs)
+ logger.info(f"Model exported to {export_path}")
+ return str(export_path)
+
+ except Exception as e:
+ logger.error(f"Error exporting model: {e}")
+ raise
+
+ def _format_training_results(self, results) -> Dict[str, Any]:
+ """Format training results into dictionary."""
+ try:
+ # Get the results dict
+ results_dict = (
+ results.results_dict if hasattr(results, "results_dict") else {}
+ )
+
+ formatted = {
+ "success": True,
+ "final_epoch": getattr(results, "epoch", 0),
+ "metrics": {
+ "mAP50": float(results_dict.get("metrics/mAP50(B)", 0)),
+ "mAP50-95": float(results_dict.get("metrics/mAP50-95(B)", 0)),
+ "precision": float(results_dict.get("metrics/precision(B)", 0)),
+ "recall": float(results_dict.get("metrics/recall(B)", 0)),
+ },
+ "best_model_path": str(Path(results.save_dir) / "weights" / "best.pt"),
+ "last_model_path": str(Path(results.save_dir) / "weights" / "last.pt"),
+ "save_dir": str(results.save_dir),
+ }
+
+ return formatted
+
+ except Exception as e:
+ logger.error(f"Error formatting training results: {e}")
+ return {"success": False, "error": str(e)}
+
+ def _format_validation_results(self, results) -> Dict[str, Any]:
+ """Format validation results into dictionary."""
+ try:
+ box_metrics = results.box
+
+ formatted = {
+ "success": True,
+ "mAP50": float(box_metrics.map50),
+ "mAP50-95": float(box_metrics.map),
+ "precision": float(box_metrics.mp),
+ "recall": float(box_metrics.mr),
+ "fitness": (
+ float(results.fitness) if hasattr(results, "fitness") else 0.0
+ ),
+ }
+
+ # Add per-class metrics if available
+ if hasattr(box_metrics, "ap") and hasattr(results, "names"):
+ class_metrics = {}
+ for idx, name in results.names.items():
+ if idx < len(box_metrics.ap):
+ class_metrics[name] = {
+ "ap": float(box_metrics.ap[idx]),
+ "ap50": (
+ float(box_metrics.ap50[idx])
+ if hasattr(box_metrics, "ap50")
+ else 0.0
+ ),
+ }
+ formatted["class_metrics"] = class_metrics
+
+ return formatted
+
+ except Exception as e:
+ logger.error(f"Error formatting validation results: {e}")
+ return {"success": False, "error": str(e)}
+
+ def _format_prediction_results(self, results) -> List[Dict]:
+ """Format prediction results into list of dictionaries."""
+ detections = []
+
+ try:
+ for result in results:
+ boxes = result.boxes
+ image_path = str(result.path)
+ orig_shape = result.orig_shape # (height, width)
+
+ for i in range(len(boxes)):
+ # Get normalized coordinates
+ xyxyn = boxes.xyxyn[i].cpu().numpy() # Normalized [x1, y1, x2, y2]
+
+ detection = {
+ "image_path": image_path,
+ "class_id": int(boxes.cls[i]),
+ "class_name": result.names[int(boxes.cls[i])],
+ "confidence": float(boxes.conf[i]),
+ "bbox_normalized": [
+ float(v) for v in xyxyn
+ ], # [x_min, y_min, x_max, y_max]
+ "bbox_absolute": [
+ float(v) for v in boxes.xyxy[i].cpu().numpy()
+ ], # Absolute pixels
+ }
+ detections.append(detection)
+
+ return detections
+
+ except Exception as e:
+ logger.error(f"Error formatting prediction results: {e}")
+ return []
+
+ @staticmethod
+ def convert_bbox_format(
+ bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
+ ) -> List[float]:
+ """
+ Convert bounding box between formats.
+
+ Formats:
+ - xywh: [x_center, y_center, width, height]
+ - xyxy: [x_min, y_min, x_max, y_max]
+
+ Args:
+ bbox: Bounding box coordinates
+ format_from: Source format
+ format_to: Target format
+
+ Returns:
+ Converted bounding box
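+
+        Example:
+            [0.5, 0.5, 0.2, 0.4] (xywh) -> [0.4, 0.3, 0.6, 0.7] (xyxy)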
+ """
+ if format_from == "xywh" and format_to == "xyxy":
+ x, y, w, h = bbox
+ return [x - w / 2, y - h / 2, x + w / 2, y + h / 2]
+ elif format_from == "xyxy" and format_to == "xywh":
+ x1, y1, x2, y2 = bbox
+ return [(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1]
+ else:
+ return bbox
+
+ def get_model_info(self) -> Dict[str, Any]:
+ """
+ Get information about the loaded model.
+
+ Returns:
+ Dictionary with model information
+ """
+ if self.model is None:
+ return {"error": "Model not loaded"}
+
+ try:
+ info = {
+ "model_path": self.model_path,
+ "device": self.device,
+ "task": getattr(self.model, "task", "unknown"),
+ }
+
+ # Try to get class names
+ if hasattr(self.model, "names"):
+ info["classes"] = self.model.names
+ info["num_classes"] = len(self.model.names)
+
+ return info
+
+ except Exception as e:
+ logger.error(f"Error getting model info: {e}")
+ return {"error": str(e)}
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/utils/config_manager.py b/src/utils/config_manager.py
new file mode 100644
index 0000000..385d2a9
--- /dev/null
+++ b/src/utils/config_manager.py
@@ -0,0 +1,218 @@
+"""
+Configuration manager for the microscopy object detection application.
+Handles loading, saving, and accessing application configuration.
+"""
+
+import yaml
+from pathlib import Path
+from typing import Any, Dict, Optional
+from src.utils.logger import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class ConfigManager:
+ """Manages application configuration."""
+
+ def __init__(self, config_path: str = "config/app_config.yaml"):
+ """
+ Initialize configuration manager.
+
+ Args:
+ config_path: Path to configuration file
+ """
+ self.config_path = Path(config_path)
+ self.config: Dict[str, Any] = {}
+ self._load_config()
+
+ def _load_config(self) -> None:
+ """Load configuration from YAML file."""
+ try:
+ if self.config_path.exists():
+ with open(self.config_path, "r") as f:
+ self.config = yaml.safe_load(f) or {}
+ logger.info(f"Configuration loaded from {self.config_path}")
+ else:
+ logger.warning(f"Configuration file not found: {self.config_path}")
+ self._create_default_config()
+ except Exception as e:
+ logger.error(f"Error loading configuration: {e}")
+ self._create_default_config()
+
+ def _create_default_config(self) -> None:
+ """Create default configuration."""
+ self.config = {
+ "database": {"path": "data/detections.db"},
+ "image_repository": {
+ "base_path": "",
+ "allowed_extensions": [
+ ".jpg",
+ ".jpeg",
+ ".png",
+ ".tif",
+ ".tiff",
+ ".bmp",
+ ],
+ },
+ "models": {
+ "default_base_model": "yolov8s.pt",
+ "models_directory": "data/models",
+ },
+ "training": {
+ "default_epochs": 100,
+ "default_batch_size": 16,
+ "default_imgsz": 640,
+ "default_patience": 50,
+ "default_lr0": 0.01,
+ },
+ "detection": {
+ "default_confidence": 0.25,
+ "default_iou": 0.45,
+ "max_batch_size": 100,
+ },
+ "visualization": {
+ "bbox_colors": {
+ "organelle": "#FF6B6B",
+ "membrane_branch": "#4ECDC4",
+ "default": "#00FF00",
+ },
+ "bbox_thickness": 2,
+ "font_size": 12,
+ },
+ "export": {"formats": ["csv", "json", "excel"], "default_format": "csv"},
+ "logging": {
+ "level": "INFO",
+ "file": "logs/app.log",
+ "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ },
+ }
+ self.save_config()
+
+ def save_config(self) -> bool:
+ """
+ Save current configuration to file.
+
+ Returns:
+ True if successful, False otherwise
+ """
+ try:
+ # Create directory if it doesn't exist
+ self.config_path.parent.mkdir(parents=True, exist_ok=True)
+
+ with open(self.config_path, "w") as f:
+ yaml.dump(self.config, f, default_flow_style=False, sort_keys=False)
+
+ logger.info(f"Configuration saved to {self.config_path}")
+ return True
+ except Exception as e:
+ logger.error(f"Error saving configuration: {e}")
+ return False
+
+ def get(self, key: str, default: Any = None) -> Any:
+ """
+ Get configuration value by key.
+
+ Args:
+ key: Configuration key (can use dot notation, e.g., 'database.path')
+ default: Default value if key not found
+
+ Returns:
+ Configuration value or default
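+
+        Example (with the default configuration):
+            >>> config.get("training.default_epochs")
+            100
+            >>> config.get("models.default_base_model")
+            'yolov8s.pt'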
+ """
+ keys = key.split(".")
+ value = self.config
+
+ for k in keys:
+ if isinstance(value, dict) and k in value:
+ value = value[k]
+ else:
+ return default
+
+ return value
+
+ def set(self, key: str, value: Any) -> None:
+ """
+ Set configuration value by key.
+
+ Args:
+ key: Configuration key (can use dot notation)
+ value: Value to set
+ """
+ keys = key.split(".")
+ config = self.config
+
+ # Navigate to the nested dictionary
+ for k in keys[:-1]:
+ if k not in config:
+ config[k] = {}
+ config = config[k]
+
+ # Set the value
+ config[keys[-1]] = value
+ logger.debug(f"Configuration updated: {key} = {value}")
+
+ def get_section(self, section: str) -> Dict[str, Any]:
+ """
+ Get entire configuration section.
+
+ Args:
+ section: Section name (e.g., 'database', 'training')
+
+ Returns:
+ Dictionary with section configuration
+ """
+ return self.config.get(section, {})
+
+ def update_section(self, section: str, values: Dict[str, Any]) -> None:
+ """
+ Update entire configuration section.
+
+ Args:
+ section: Section name
+ values: Dictionary with new values
+ """
+ if section not in self.config:
+ self.config[section] = {}
+
+ self.config[section].update(values)
+ logger.debug(f"Configuration section updated: {section}")
+
+ def reload(self) -> None:
+ """Reload configuration from file."""
+ self._load_config()
+
+ def get_database_path(self) -> str:
+ """Get database path."""
+ return self.get("database.path", "data/detections.db")
+
+ def get_image_repository_path(self) -> str:
+ """Get image repository base path."""
+ return self.get("image_repository.base_path", "")
+
+ def set_image_repository_path(self, path: str) -> None:
+ """Set image repository base path."""
+ self.set("image_repository.base_path", path)
+ self.save_config()
+
+ def get_models_directory(self) -> str:
+ """Get models directory path."""
+ return self.get("models.models_directory", "data/models")
+
+ def get_default_training_params(self) -> Dict[str, Any]:
+ """Get default training parameters."""
+ return self.get_section("training")
+
+ def get_default_detection_params(self) -> Dict[str, Any]:
+ """Get default detection parameters."""
+ return self.get_section("detection")
+
+ def get_bbox_colors(self) -> Dict[str, str]:
+ """Get bounding box colors for different classes."""
+ return self.get("visualization.bbox_colors", {})
+
+ def get_allowed_extensions(self) -> list:
+ """Get list of allowed image file extensions."""
+ return self.get(
+ "image_repository.allowed_extensions", [".jpg", ".jpeg", ".png"]
+ )
diff --git a/src/utils/file_utils.py b/src/utils/file_utils.py
new file mode 100644
index 0000000..019852e
--- /dev/null
+++ b/src/utils/file_utils.py
@@ -0,0 +1,235 @@
+"""
+File utility functions for the microscopy object detection application.
+"""
+
+from pathlib import Path
+from typing import List, Optional
+from src.utils.logger import get_logger
+
+
+logger = get_logger(__name__)
+
+
+def get_image_files(
+ directory: str,
+ allowed_extensions: Optional[List[str]] = None,
+ recursive: bool = False,
+) -> List[str]:
+ """
+ Get all image files in a directory.
+
+ Args:
+ directory: Directory path to search
+ allowed_extensions: List of allowed file extensions (e.g., ['.jpg', '.png'])
+ recursive: Whether to search recursively
+
+ Returns:
+ List of absolute paths to image files
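+
+    Example (illustrative paths):
+        >>> get_image_files("/data/experiment_01", [".png", ".tif"], recursive=True)
+        ['/data/experiment_01/run1/img_0001.png', ...]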
+ """
+ if allowed_extensions is None:
+ allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
+
+ # Normalize extensions to lowercase
+ allowed_extensions = [ext.lower() for ext in allowed_extensions]
+
+ image_files = []
+ directory_path = Path(directory)
+
+ if not directory_path.exists():
+ logger.error(f"Directory does not exist: {directory}")
+ return image_files
+
+ try:
+ if recursive:
+ # Recursive search
+ for ext in allowed_extensions:
+ image_files.extend(directory_path.rglob(f"*{ext}"))
+ # Also search uppercase extensions
+ image_files.extend(directory_path.rglob(f"*{ext.upper()}"))
+ else:
+ # Top-level search only
+ for ext in allowed_extensions:
+ image_files.extend(directory_path.glob(f"*{ext}"))
+ # Also search uppercase extensions
+ image_files.extend(directory_path.glob(f"*{ext.upper()}"))
+
+        # Convert to absolute paths, drop duplicates (case-insensitive filesystems
+        # match both the lowercase and uppercase patterns), and sort
+        image_files = sorted({str(f.absolute()) for f in image_files})
+ logger.info(f"Found {len(image_files)} image files in {directory}")
+
+ except Exception as e:
+ logger.error(f"Error searching for images: {e}")
+
+ return image_files
+
+
+def ensure_directory(directory: str) -> bool:
+ """
+ Ensure a directory exists, create if it doesn't.
+
+ Args:
+ directory: Directory path
+
+ Returns:
+ True if directory exists or was created successfully
+ """
+ try:
+ Path(directory).mkdir(parents=True, exist_ok=True)
+ return True
+ except Exception as e:
+ logger.error(f"Error creating directory {directory}: {e}")
+ return False
+
+
+def get_relative_path(file_path: str, base_path: str) -> str:
+ """
+ Get relative path from base path.
+
+ Args:
+ file_path: Absolute file path
+ base_path: Base directory path
+
+ Returns:
+ Relative path string
+ """
+ try:
+ return str(Path(file_path).relative_to(base_path))
+ except ValueError:
+ # If file_path is not relative to base_path, return the filename
+ return Path(file_path).name
+
+
+def validate_file_path(file_path: str, must_exist: bool = True) -> bool:
+ """
+ Validate a file path.
+
+ Args:
+ file_path: Path to validate
+ must_exist: Whether the file must exist
+
+ Returns:
+ True if valid, False otherwise
+ """
+ path = Path(file_path)
+
+ if must_exist and not path.exists():
+ logger.error(f"File does not exist: {file_path}")
+ return False
+
+ if must_exist and not path.is_file():
+ logger.error(f"Path is not a file: {file_path}")
+ return False
+
+ return True
+
+
+def get_file_size(file_path: str) -> int:
+ """
+ Get file size in bytes.
+
+ Args:
+ file_path: Path to file
+
+ Returns:
+ File size in bytes, or 0 if error
+ """
+ try:
+ return Path(file_path).stat().st_size
+ except Exception as e:
+ logger.error(f"Error getting file size for {file_path}: {e}")
+ return 0
+
+
+def format_file_size(size_bytes: int) -> str:
+ """
+ Format file size in human-readable format.
+
+ Args:
+ size_bytes: Size in bytes
+
+ Returns:
+ Formatted string (e.g., "1.5 MB")
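+
+    Example:
+        >>> format_file_size(1536)
+        '1.5 KB'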
+ """
+ for unit in ["B", "KB", "MB", "GB"]:
+ if size_bytes < 1024.0:
+ return f"{size_bytes:.1f} {unit}"
+ size_bytes /= 1024.0
+ return f"{size_bytes:.1f} TB"
+
+
+def create_unique_filename(directory: str, base_name: str, extension: str) -> str:
+ """
+ Create a unique filename by adding a number suffix if file exists.
+
+ Args:
+ directory: Directory path
+ base_name: Base filename without extension
+ extension: File extension (with or without dot)
+
+ Returns:
+ Unique filename
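+
+    Example (illustrative names):
+        >>> create_unique_filename("exports", "detections", "csv")
+        'detections.csv'  # or 'detections_1.csv' if that file already exists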
+ """
+ if not extension.startswith("."):
+ extension = "." + extension
+
+ directory_path = Path(directory)
+ filename = f"{base_name}{extension}"
+ file_path = directory_path / filename
+
+ if not file_path.exists():
+ return filename
+
+ # Add number suffix
+ counter = 1
+ while True:
+ filename = f"{base_name}_{counter}{extension}"
+ file_path = directory_path / filename
+ if not file_path.exists():
+ return filename
+ counter += 1
+
+
+def is_image_file(
+ file_path: str, allowed_extensions: Optional[List[str]] = None
+) -> bool:
+ """
+ Check if a file is an image based on extension.
+
+ Args:
+ file_path: Path to file
+ allowed_extensions: List of allowed extensions
+
+ Returns:
+ True if file is an image
+ """
+ if allowed_extensions is None:
+ allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
+
+ extension = Path(file_path).suffix.lower()
+ return extension in [ext.lower() for ext in allowed_extensions]
+
+
+def safe_filename(filename: str) -> str:
+ """
+ Convert a string to a safe filename by removing/replacing invalid characters.
+
+ Args:
+ filename: Original filename
+
+ Returns:
+ Safe filename
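+
+    Example:
+        >>> safe_filename('image<1>:v2.png')
+        'image_1__v2.png'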
+ """
+ # Replace invalid characters
+ invalid_chars = '<>:"/\\|?*'
+ for char in invalid_chars:
+ filename = filename.replace(char, "_")
+
+ # Remove leading/trailing spaces and dots
+ filename = filename.strip(". ")
+
+ # Ensure filename is not empty
+ if not filename:
+ filename = "unnamed"
+
+ return filename
diff --git a/src/utils/logger.py b/src/utils/logger.py
new file mode 100644
index 0000000..5faf6df
--- /dev/null
+++ b/src/utils/logger.py
@@ -0,0 +1,75 @@
+"""
+Logging configuration for the microscopy object detection application.
+"""
+
+import logging
+import sys
+from pathlib import Path
+from typing import Optional
+
+
+def setup_logging(
+ log_file: str = "logs/app.log",
+ level: str = "INFO",
+ log_format: Optional[str] = None,
+) -> logging.Logger:
+ """
+ Setup application logging.
+
+ Args:
+ log_file: Path to log file
+ level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
+ log_format: Custom log format string
+
+ Returns:
+ Configured logger instance
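+
+    Example (illustrative):
+        >>> setup_logging("logs/app.log", level="DEBUG")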
+ """
+ # Create logs directory if it doesn't exist
+ log_path = Path(log_file)
+ log_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # Default format if none provided
+ if log_format is None:
+ log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+
+ # Convert level string to logging constant
+ numeric_level = getattr(logging, level.upper(), logging.INFO)
+
+ # Configure root logger
+ root_logger = logging.getLogger()
+ root_logger.setLevel(numeric_level)
+
+ # Remove existing handlers
+ root_logger.handlers.clear()
+
+ # Console handler
+ console_handler = logging.StreamHandler(sys.stdout)
+ console_handler.setLevel(numeric_level)
+ console_formatter = logging.Formatter(log_format)
+ console_handler.setFormatter(console_formatter)
+ root_logger.addHandler(console_handler)
+
+ # File handler
+ file_handler = logging.FileHandler(log_file)
+ file_handler.setLevel(numeric_level)
+ file_formatter = logging.Formatter(log_format)
+ file_handler.setFormatter(file_formatter)
+ root_logger.addHandler(file_handler)
+
+ # Log initial message
+ root_logger.info("Logging initialized")
+
+ return root_logger
+
+
+def get_logger(name: str) -> logging.Logger:
+ """
+ Get a logger instance for a specific module.
+
+ Args:
+ name: Logger name (typically __name__)
+
+ Returns:
+ Logger instance
+ """
+ return logging.getLogger(name)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29