From 6bd2b100caa2d27a0c5d1273a302574bfbc0b156 Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Fri, 5 Dec 2025 09:50:50 +0200 Subject: [PATCH] Adding python files --- src/__init__.py | 0 src/database/__init__.py | 0 src/database/db_manager.py | 619 +++++++++++++++++++++++++++++++ src/database/models.py | 63 ++++ src/database/schema.sql | 70 ++++ src/gui/__init__.py | 0 src/gui/dialogs/__init__.py | 0 src/gui/dialogs/config_dialog.py | 291 +++++++++++++++ src/gui/main_window.py | 282 ++++++++++++++ src/gui/tabs/__init__.py | 0 src/gui/tabs/annotation_tab.py | 48 +++ src/gui/tabs/detection_tab.py | 344 +++++++++++++++++ src/gui/tabs/results_tab.py | 46 +++ src/gui/tabs/training_tab.py | 52 +++ src/gui/tabs/validation_tab.py | 46 +++ src/gui/widgets/__init__.py | 0 src/model/__init__.py | 0 src/model/inference.py | 323 ++++++++++++++++ src/model/yolo_wrapper.py | 364 ++++++++++++++++++ src/utils/__init__.py | 0 src/utils/config_manager.py | 218 +++++++++++ src/utils/file_utils.py | 235 ++++++++++++ src/utils/logger.py | 75 ++++ tests/__init__.py | 0 24 files changed, 3076 insertions(+) create mode 100644 src/__init__.py create mode 100644 src/database/__init__.py create mode 100644 src/database/db_manager.py create mode 100644 src/database/models.py create mode 100644 src/database/schema.sql create mode 100644 src/gui/__init__.py create mode 100644 src/gui/dialogs/__init__.py create mode 100644 src/gui/dialogs/config_dialog.py create mode 100644 src/gui/main_window.py create mode 100644 src/gui/tabs/__init__.py create mode 100644 src/gui/tabs/annotation_tab.py create mode 100644 src/gui/tabs/detection_tab.py create mode 100644 src/gui/tabs/results_tab.py create mode 100644 src/gui/tabs/training_tab.py create mode 100644 src/gui/tabs/validation_tab.py create mode 100644 src/gui/widgets/__init__.py create mode 100644 src/model/__init__.py create mode 100644 src/model/inference.py create mode 100644 src/model/yolo_wrapper.py create mode 100644 src/utils/__init__.py create mode 100644 src/utils/config_manager.py create mode 100644 src/utils/file_utils.py create mode 100644 src/utils/logger.py create mode 100644 tests/__init__.py diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/database/__init__.py b/src/database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/database/db_manager.py b/src/database/db_manager.py new file mode 100644 index 0000000..4329499 --- /dev/null +++ b/src/database/db_manager.py @@ -0,0 +1,619 @@ +""" +Database manager for the microscopy object detection application. +Handles all database operations including CRUD operations, queries, and exports. +""" + +import sqlite3 +import json +from datetime import datetime +from typing import List, Dict, Optional, Tuple, Any +from pathlib import Path +import csv +import hashlib + + +class DatabaseManager: + """Manages all database operations for the application.""" + + def __init__(self, db_path: str = "data/detections.db"): + """ + Initialize database manager. + + Args: + db_path: Path to SQLite database file + """ + self.db_path = db_path + self._ensure_database_exists() + + def _ensure_database_exists(self) -> None: + """Create database and tables if they don't exist.""" + # Create directory if it doesn't exist + Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) + + # Read schema file and execute + schema_path = Path(__file__).parent / "schema.sql" + with open(schema_path, "r") as f: + schema_sql = f.read() + + conn = self.get_connection() + try: + conn.executescript(schema_sql) + conn.commit() + finally: + conn.close() + + def get_connection(self) -> sqlite3.Connection: + """Get database connection with proper settings.""" + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row # Enable column access by name + conn.execute("PRAGMA foreign_keys = ON") # Enable foreign keys + return conn + + # ==================== Model Operations ==================== + + def add_model( + self, + model_name: str, + model_version: str, + model_path: str, + base_model: str = "yolov8s.pt", + training_params: Optional[Dict] = None, + metrics: Optional[Dict] = None, + ) -> int: + """ + Add a new model to the database. + + Args: + model_name: Name of the model + model_version: Version string + model_path: Path to model weights file + base_model: Base model used for training + training_params: Dictionary of training parameters + metrics: Dictionary of validation metrics + + Returns: + ID of the inserted model + """ + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO models (model_name, model_version, model_path, base_model, training_params, metrics) + VALUES (?, ?, ?, ?, ?, ?) + """, + ( + model_name, + model_version, + model_path, + base_model, + json.dumps(training_params) if training_params else None, + json.dumps(metrics) if metrics else None, + ), + ) + conn.commit() + return cursor.lastrowid + finally: + conn.close() + + def get_models(self, filters: Optional[Dict] = None) -> List[Dict]: + """ + Retrieve models from database. + + Args: + filters: Optional filters (e.g., {'model_name': 'my_model'}) + + Returns: + List of model dictionaries + """ + conn = self.get_connection() + try: + query = "SELECT * FROM models" + params = [] + + if filters: + conditions = [] + for key, value in filters.items(): + conditions.append(f"{key} = ?") + params.append(value) + query += " WHERE " + " AND ".join(conditions) + + query += " ORDER BY created_at DESC" + + cursor = conn.cursor() + cursor.execute(query, params) + + models = [] + for row in cursor.fetchall(): + model = dict(row) + # Parse JSON fields + if model["training_params"]: + model["training_params"] = json.loads(model["training_params"]) + if model["metrics"]: + model["metrics"] = json.loads(model["metrics"]) + models.append(model) + + return models + finally: + conn.close() + + def get_model_by_id(self, model_id: int) -> Optional[Dict]: + """Get model by ID.""" + models = self.get_models({"id": model_id}) + return models[0] if models else None + + def update_model(self, model_id: int, updates: Dict) -> bool: + """Update model fields.""" + conn = self.get_connection() + try: + # Build update query + set_clauses = [] + params = [] + for key, value in updates.items(): + if key in ["training_params", "metrics"] and isinstance(value, dict): + value = json.dumps(value) + set_clauses.append(f"{key} = ?") + params.append(value) + + params.append(model_id) + + query = f"UPDATE models SET {', '.join(set_clauses)} WHERE id = ?" + cursor = conn.cursor() + cursor.execute(query, params) + conn.commit() + return cursor.rowcount > 0 + finally: + conn.close() + + # ==================== Image Operations ==================== + + def add_image( + self, + relative_path: str, + filename: str, + width: int, + height: int, + captured_at: Optional[datetime] = None, + checksum: Optional[str] = None, + ) -> int: + """ + Add a new image to the database. + + Args: + relative_path: Path relative to image repository + filename: Image filename + width: Image width in pixels + height: Image height in pixels + captured_at: When image was captured (if known) + checksum: MD5 checksum of image file + + Returns: + ID of the inserted image + """ + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO images (relative_path, filename, width, height, captured_at, checksum) + VALUES (?, ?, ?, ?, ?, ?) + """, + (relative_path, filename, width, height, captured_at, checksum), + ) + conn.commit() + return cursor.lastrowid + except sqlite3.IntegrityError: + # Image already exists, return its ID + cursor.execute( + "SELECT id FROM images WHERE relative_path = ?", (relative_path,) + ) + row = cursor.fetchone() + return row["id"] if row else None + finally: + conn.close() + + def get_image_by_path(self, relative_path: str) -> Optional[Dict]: + """Get image by relative path.""" + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM images WHERE relative_path = ?", (relative_path,) + ) + row = cursor.fetchone() + return dict(row) if row else None + finally: + conn.close() + + def get_or_create_image( + self, relative_path: str, filename: str, width: int, height: int + ) -> int: + """Get existing image or create new one.""" + existing = self.get_image_by_path(relative_path) + if existing: + return existing["id"] + return self.add_image(relative_path, filename, width, height) + + # ==================== Detection Operations ==================== + + def add_detection( + self, + image_id: int, + model_id: int, + class_name: str, + bbox: Tuple[float, float, float, float], # (x_min, y_min, x_max, y_max) + confidence: float, + metadata: Optional[Dict] = None, + ) -> int: + """ + Add a new detection to the database. + + Args: + image_id: ID of the image + model_id: ID of the model used + class_name: Detected object class + bbox: Bounding box coordinates (normalized 0-1) + confidence: Detection confidence score + metadata: Additional metadata + + Returns: + ID of the inserted detection + """ + conn = self.get_connection() + try: + cursor = conn.cursor() + x_min, y_min, x_max, y_max = bbox + cursor.execute( + """ + INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + image_id, + model_id, + class_name, + x_min, + y_min, + x_max, + y_max, + confidence, + json.dumps(metadata) if metadata else None, + ), + ) + conn.commit() + return cursor.lastrowid + finally: + conn.close() + + def add_detections_batch(self, detections: List[Dict]) -> int: + """ + Add multiple detections in a single transaction. + + Args: + detections: List of detection dictionaries + + Returns: + Number of detections inserted + """ + conn = self.get_connection() + try: + cursor = conn.cursor() + for det in detections: + bbox = det["bbox"] + cursor.execute( + """ + INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + det["image_id"], + det["model_id"], + det["class_name"], + bbox[0], + bbox[1], + bbox[2], + bbox[3], + det["confidence"], + ( + json.dumps(det.get("metadata")) + if det.get("metadata") + else None + ), + ), + ) + conn.commit() + return len(detections) + finally: + conn.close() + + def get_detections( + self, + filters: Optional[Dict] = None, + limit: Optional[int] = None, + offset: int = 0, + ) -> List[Dict]: + """ + Retrieve detections from database. + + Args: + filters: Optional filters for querying + limit: Maximum number of results + offset: Number of results to skip + + Returns: + List of detection dictionaries with joined data + """ + conn = self.get_connection() + try: + query = """ + SELECT + d.*, + i.relative_path as image_path, + i.filename as image_filename, + i.width as image_width, + i.height as image_height, + m.model_name, + m.model_version + FROM detections d + JOIN images i ON d.image_id = i.id + JOIN models m ON d.model_id = m.id + """ + params = [] + + if filters: + conditions = [] + for key, value in filters.items(): + if ( + key.startswith("d.") + or key.startswith("i.") + or key.startswith("m.") + ): + conditions.append(f"{key} = ?") + else: + conditions.append(f"d.{key} = ?") + params.append(value) + query += " WHERE " + " AND ".join(conditions) + + query += " ORDER BY d.detected_at DESC" + + if limit: + query += f" LIMIT {limit} OFFSET {offset}" + + cursor = conn.cursor() + cursor.execute(query, params) + + detections = [] + for row in cursor.fetchall(): + det = dict(row) + # Parse JSON metadata + if det.get("metadata"): + det["metadata"] = json.loads(det["metadata"]) + detections.append(det) + + return detections + finally: + conn.close() + + def get_detections_for_image( + self, image_id: int, model_id: Optional[int] = None + ) -> List[Dict]: + """Get all detections for a specific image.""" + filters = {"image_id": image_id} + if model_id: + filters["model_id"] = model_id + return self.get_detections(filters) + + def delete_detections_for_model(self, model_id: int) -> int: + """Delete all detections for a specific model.""" + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute("DELETE FROM detections WHERE model_id = ?", (model_id,)) + conn.commit() + return cursor.rowcount + finally: + conn.close() + + # ==================== Statistics Operations ==================== + + def get_detection_statistics( + self, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None + ) -> Dict: + """ + Get detection statistics for a date range. + + Returns: + Dictionary with statistics (count by class, confidence distribution, etc.) + """ + conn = self.get_connection() + try: + cursor = conn.cursor() + + # Build date filter + date_filter = "" + params = [] + if start_date: + date_filter += " AND detected_at >= ?" + params.append(start_date) + if end_date: + date_filter += " AND detected_at <= ?" + params.append(end_date) + + # Total detections + cursor.execute( + f"SELECT COUNT(*) as count FROM detections WHERE 1=1{date_filter}", + params, + ) + total_count = cursor.fetchone()["count"] + + # Count by class + cursor.execute( + f""" + SELECT class_name, COUNT(*) as count + FROM detections + WHERE 1=1{date_filter} + GROUP BY class_name + ORDER BY count DESC + """, + params, + ) + class_counts = { + row["class_name"]: row["count"] for row in cursor.fetchall() + } + + # Average confidence + cursor.execute( + f"SELECT AVG(confidence) as avg_conf FROM detections WHERE 1=1{date_filter}", + params, + ) + avg_confidence = cursor.fetchone()["avg_conf"] or 0 + + # Confidence distribution + cursor.execute( + f""" + SELECT + CASE + WHEN confidence < 0.3 THEN 'low' + WHEN confidence < 0.7 THEN 'medium' + ELSE 'high' + END as conf_level, + COUNT(*) as count + FROM detections + WHERE 1=1{date_filter} + GROUP BY conf_level + """, + params, + ) + conf_dist = {row["conf_level"]: row["count"] for row in cursor.fetchall()} + + return { + "total_detections": total_count, + "class_counts": class_counts, + "average_confidence": avg_confidence, + "confidence_distribution": conf_dist, + } + finally: + conn.close() + + def get_class_distribution(self, model_id: Optional[int] = None) -> Dict[str, int]: + """Get count of detections per class.""" + conn = self.get_connection() + try: + cursor = conn.cursor() + query = "SELECT class_name, COUNT(*) as count FROM detections" + params = [] + + if model_id: + query += " WHERE model_id = ?" + params.append(model_id) + + query += " GROUP BY class_name ORDER BY count DESC" + + cursor.execute(query, params) + return {row["class_name"]: row["count"] for row in cursor.fetchall()} + finally: + conn.close() + + # ==================== Export Operations ==================== + + def export_detections_to_csv( + self, output_path: str, filters: Optional[Dict] = None + ) -> bool: + """Export detections to CSV file.""" + try: + detections = self.get_detections(filters) + + with open(output_path, "w", newline="") as csvfile: + if not detections: + return True + + fieldnames = [ + "id", + "image_path", + "model_name", + "model_version", + "class_name", + "x_min", + "y_min", + "x_max", + "y_max", + "confidence", + "detected_at", + ] + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + + for det in detections: + row = {k: det[k] for k in fieldnames if k in det} + writer.writerow(row) + + return True + except Exception as e: + print(f"Error exporting to CSV: {e}") + return False + + def export_detections_to_json( + self, output_path: str, filters: Optional[Dict] = None + ) -> bool: + """Export detections to JSON file.""" + try: + detections = self.get_detections(filters) + + # Convert datetime objects to strings + for det in detections: + if isinstance(det.get("detected_at"), datetime): + det["detected_at"] = det["detected_at"].isoformat() + + with open(output_path, "w") as jsonfile: + json.dump(detections, jsonfile, indent=2) + + return True + except Exception as e: + print(f"Error exporting to JSON: {e}") + return False + + # ==================== Annotation Operations ==================== + + def add_annotation( + self, + image_id: int, + class_name: str, + bbox: Tuple[float, float, float, float], + annotator: str, + verified: bool = False, + ) -> int: + """Add manual annotation.""" + conn = self.get_connection() + try: + cursor = conn.cursor() + x_min, y_min, x_max, y_max = bbox + cursor.execute( + """ + INSERT INTO annotations (image_id, class_name, x_min, y_min, x_max, y_max, annotator, verified) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + (image_id, class_name, x_min, y_min, x_max, y_max, annotator, verified), + ) + conn.commit() + return cursor.lastrowid + finally: + conn.close() + + def get_annotations_for_image(self, image_id: int) -> List[Dict]: + """Get all annotations for an image.""" + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute("SELECT * FROM annotations WHERE image_id = ?", (image_id,)) + return [dict(row) for row in cursor.fetchall()] + finally: + conn.close() + + @staticmethod + def calculate_checksum(file_path: str) -> str: + """Calculate MD5 checksum of a file.""" + md5_hash = hashlib.md5() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + md5_hash.update(chunk) + return md5_hash.hexdigest() diff --git a/src/database/models.py b/src/database/models.py new file mode 100644 index 0000000..cdcd237 --- /dev/null +++ b/src/database/models.py @@ -0,0 +1,63 @@ +""" +Data models for the microscopy object detection application. +These dataclasses represent the database entities. +""" + +from dataclasses import dataclass +from datetime import datetime +from typing import Optional, Dict, Tuple + + +@dataclass +class Model: + """Represents a trained model.""" + + id: Optional[int] + model_name: str + model_version: str + model_path: str + base_model: str + created_at: datetime + training_params: Optional[Dict] + metrics: Optional[Dict] + + +@dataclass +class Image: + """Represents an image in the database.""" + + id: Optional[int] + relative_path: str + filename: str + width: int + height: int + captured_at: Optional[datetime] + added_at: datetime + checksum: Optional[str] + + +@dataclass +class Detection: + """Represents a detection result.""" + + id: Optional[int] + image_id: int + model_id: int + class_name: str + bbox: Tuple[float, float, float, float] # (x_min, y_min, x_max, y_max) + confidence: float + detected_at: datetime + metadata: Optional[Dict] + + +@dataclass +class Annotation: + """Represents a manual annotation.""" + + id: Optional[int] + image_id: int + class_name: str + bbox: Tuple[float, float, float, float] # (x_min, y_min, x_max, y_max) + annotator: str + created_at: datetime + verified: bool diff --git a/src/database/schema.sql b/src/database/schema.sql new file mode 100644 index 0000000..f6080f7 --- /dev/null +++ b/src/database/schema.sql @@ -0,0 +1,70 @@ +-- Microscopy Object Detection Application - Database Schema +-- SQLite Database Schema for storing models, images, detections, and annotations + +-- Models table: stores trained model information +CREATE TABLE IF NOT EXISTS models ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + model_name TEXT NOT NULL, + model_version TEXT NOT NULL, + model_path TEXT NOT NULL, + base_model TEXT NOT NULL DEFAULT 'yolov8s.pt', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + training_params TEXT, -- JSON string of training parameters + metrics TEXT, -- JSON string of validation metrics + UNIQUE(model_name, model_version) +); + +-- Images table: stores image metadata +CREATE TABLE IF NOT EXISTS images ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + relative_path TEXT NOT NULL UNIQUE, + filename TEXT NOT NULL, + width INTEGER, + height INTEGER, + captured_at TIMESTAMP, + added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + checksum TEXT +); + +-- Detections table: stores detection results +CREATE TABLE IF NOT EXISTS detections ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + image_id INTEGER NOT NULL, + model_id INTEGER NOT NULL, + class_name TEXT NOT NULL, + x_min REAL NOT NULL CHECK(x_min >= 0 AND x_min <= 1), + y_min REAL NOT NULL CHECK(y_min >= 0 AND y_min <= 1), + x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1), + y_max REAL NOT NULL CHECK(y_max >= 0 AND y_max <= 1), + confidence REAL NOT NULL CHECK(confidence >= 0 AND confidence <= 1), + detected_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + metadata TEXT, -- JSON string for additional metadata + FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE, + FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE +); + +-- Annotations table: stores manual annotations (future feature) +CREATE TABLE IF NOT EXISTS annotations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + image_id INTEGER NOT NULL, + class_name TEXT NOT NULL, + x_min REAL NOT NULL CHECK(x_min >= 0 AND x_min <= 1), + y_min REAL NOT NULL CHECK(y_min >= 0 AND y_min <= 1), + x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1), + y_max REAL NOT NULL CHECK(y_max >= 0 AND y_max <= 1), + annotator TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + verified BOOLEAN DEFAULT 0, + FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE +); + +-- Create indexes for performance optimization +CREATE INDEX IF NOT EXISTS idx_detections_image_id ON detections(image_id); +CREATE INDEX IF NOT EXISTS idx_detections_model_id ON detections(model_id); +CREATE INDEX IF NOT EXISTS idx_detections_class_name ON detections(class_name); +CREATE INDEX IF NOT EXISTS idx_detections_detected_at ON detections(detected_at); +CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence); +CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path); +CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at); +CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id); +CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at); \ No newline at end of file diff --git a/src/gui/__init__.py b/src/gui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/gui/dialogs/__init__.py b/src/gui/dialogs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/gui/dialogs/config_dialog.py b/src/gui/dialogs/config_dialog.py new file mode 100644 index 0000000..9abfe1b --- /dev/null +++ b/src/gui/dialogs/config_dialog.py @@ -0,0 +1,291 @@ +""" +Configuration dialog for the microscopy object detection application. +""" + +from PySide6.QtWidgets import ( + QDialog, + QVBoxLayout, + QHBoxLayout, + QFormLayout, + QPushButton, + QLineEdit, + QSpinBox, + QDoubleSpinBox, + QFileDialog, + QTabWidget, + QWidget, + QLabel, + QGroupBox, +) +from PySide6.QtCore import Qt + +from src.utils.config_manager import ConfigManager +from src.utils.logger import get_logger + + +logger = get_logger(__name__) + + +class ConfigDialog(QDialog): + """Configuration dialog window.""" + + def __init__(self, config_manager: ConfigManager, parent=None): + super().__init__(parent) + + self.config_manager = config_manager + + self.setWindowTitle("Settings") + self.setMinimumWidth(500) + self.setMinimumHeight(400) + + self._setup_ui() + self._load_settings() + + def _setup_ui(self): + """Setup user interface.""" + layout = QVBoxLayout() + + # Create tab widget for different setting categories + self.tab_widget = QTabWidget() + + # General settings tab + general_tab = self._create_general_tab() + self.tab_widget.addTab(general_tab, "General") + + # Training settings tab + training_tab = self._create_training_tab() + self.tab_widget.addTab(training_tab, "Training") + + # Detection settings tab + detection_tab = self._create_detection_tab() + self.tab_widget.addTab(detection_tab, "Detection") + + layout.addWidget(self.tab_widget) + + # Buttons + button_layout = QHBoxLayout() + button_layout.addStretch() + + self.save_button = QPushButton("Save") + self.save_button.clicked.connect(self.accept) + button_layout.addWidget(self.save_button) + + self.cancel_button = QPushButton("Cancel") + self.cancel_button.clicked.connect(self.reject) + button_layout.addWidget(self.cancel_button) + + layout.addLayout(button_layout) + + self.setLayout(layout) + + def _create_general_tab(self) -> QWidget: + """Create general settings tab.""" + widget = QWidget() + layout = QVBoxLayout() + + # Image repository group + repo_group = QGroupBox("Image Repository") + repo_layout = QFormLayout() + + # Repository path + path_layout = QHBoxLayout() + self.repo_path_edit = QLineEdit() + self.repo_path_edit.setPlaceholderText("Path to image repository") + path_layout.addWidget(self.repo_path_edit) + + browse_button = QPushButton("Browse...") + browse_button.clicked.connect(self._browse_repository) + path_layout.addWidget(browse_button) + + repo_layout.addRow("Base Path:", path_layout) + repo_group.setLayout(repo_layout) + layout.addWidget(repo_group) + + # Database group + db_group = QGroupBox("Database") + db_layout = QFormLayout() + + self.db_path_edit = QLineEdit() + self.db_path_edit.setPlaceholderText("Path to database file") + db_layout.addRow("Database Path:", self.db_path_edit) + + db_group.setLayout(db_layout) + layout.addWidget(db_group) + + # Models group + models_group = QGroupBox("Models") + models_layout = QFormLayout() + + self.models_dir_edit = QLineEdit() + self.models_dir_edit.setPlaceholderText("Directory for saved models") + models_layout.addRow("Models Directory:", self.models_dir_edit) + + self.base_model_edit = QLineEdit() + self.base_model_edit.setPlaceholderText("yolov8s.pt") + models_layout.addRow("Default Base Model:", self.base_model_edit) + + models_group.setLayout(models_layout) + layout.addWidget(models_group) + + layout.addStretch() + widget.setLayout(layout) + return widget + + def _create_training_tab(self) -> QWidget: + """Create training settings tab.""" + widget = QWidget() + layout = QVBoxLayout() + + form_layout = QFormLayout() + + # Epochs + self.epochs_spin = QSpinBox() + self.epochs_spin.setRange(1, 1000) + self.epochs_spin.setValue(100) + form_layout.addRow("Default Epochs:", self.epochs_spin) + + # Batch size + self.batch_size_spin = QSpinBox() + self.batch_size_spin.setRange(1, 128) + self.batch_size_spin.setValue(16) + form_layout.addRow("Default Batch Size:", self.batch_size_spin) + + # Image size + self.imgsz_spin = QSpinBox() + self.imgsz_spin.setRange(320, 1280) + self.imgsz_spin.setSingleStep(32) + self.imgsz_spin.setValue(640) + form_layout.addRow("Default Image Size:", self.imgsz_spin) + + # Patience + self.patience_spin = QSpinBox() + self.patience_spin.setRange(1, 200) + self.patience_spin.setValue(50) + form_layout.addRow("Default Patience:", self.patience_spin) + + # Learning rate + self.lr_spin = QDoubleSpinBox() + self.lr_spin.setRange(0.0001, 0.1) + self.lr_spin.setSingleStep(0.001) + self.lr_spin.setDecimals(4) + self.lr_spin.setValue(0.01) + form_layout.addRow("Default Learning Rate:", self.lr_spin) + + layout.addLayout(form_layout) + layout.addStretch() + widget.setLayout(layout) + return widget + + def _create_detection_tab(self) -> QWidget: + """Create detection settings tab.""" + widget = QWidget() + layout = QVBoxLayout() + + form_layout = QFormLayout() + + # Confidence threshold + self.conf_spin = QDoubleSpinBox() + self.conf_spin.setRange(0.0, 1.0) + self.conf_spin.setSingleStep(0.05) + self.conf_spin.setDecimals(2) + self.conf_spin.setValue(0.25) + form_layout.addRow("Default Confidence:", self.conf_spin) + + # IoU threshold + self.iou_spin = QDoubleSpinBox() + self.iou_spin.setRange(0.0, 1.0) + self.iou_spin.setSingleStep(0.05) + self.iou_spin.setDecimals(2) + self.iou_spin.setValue(0.45) + form_layout.addRow("Default IoU:", self.iou_spin) + + # Max batch size + self.max_batch_spin = QSpinBox() + self.max_batch_spin.setRange(1, 1000) + self.max_batch_spin.setValue(100) + form_layout.addRow("Max Batch Size:", self.max_batch_spin) + + layout.addLayout(form_layout) + layout.addStretch() + widget.setLayout(layout) + return widget + + def _browse_repository(self): + """Browse for image repository directory.""" + directory = QFileDialog.getExistingDirectory( + self, "Select Image Repository", self.repo_path_edit.text() + ) + + if directory: + self.repo_path_edit.setText(directory) + + def _load_settings(self): + """Load current settings into dialog.""" + # General settings + self.repo_path_edit.setText( + self.config_manager.get("image_repository.base_path", "") + ) + self.db_path_edit.setText( + self.config_manager.get("database.path", "data/detections.db") + ) + self.models_dir_edit.setText( + self.config_manager.get("models.models_directory", "data/models") + ) + self.base_model_edit.setText( + self.config_manager.get("models.default_base_model", "yolov8s.pt") + ) + + # Training settings + self.epochs_spin.setValue( + self.config_manager.get("training.default_epochs", 100) + ) + self.batch_size_spin.setValue( + self.config_manager.get("training.default_batch_size", 16) + ) + self.imgsz_spin.setValue(self.config_manager.get("training.default_imgsz", 640)) + self.patience_spin.setValue( + self.config_manager.get("training.default_patience", 50) + ) + self.lr_spin.setValue(self.config_manager.get("training.default_lr0", 0.01)) + + # Detection settings + self.conf_spin.setValue( + self.config_manager.get("detection.default_confidence", 0.25) + ) + self.iou_spin.setValue(self.config_manager.get("detection.default_iou", 0.45)) + self.max_batch_spin.setValue( + self.config_manager.get("detection.max_batch_size", 100) + ) + + def accept(self): + """Save settings and close dialog.""" + logger.info("Saving configuration") + + # Save general settings + self.config_manager.set( + "image_repository.base_path", self.repo_path_edit.text() + ) + self.config_manager.set("database.path", self.db_path_edit.text()) + self.config_manager.set("models.models_directory", self.models_dir_edit.text()) + self.config_manager.set( + "models.default_base_model", self.base_model_edit.text() + ) + + # Save training settings + self.config_manager.set("training.default_epochs", self.epochs_spin.value()) + self.config_manager.set( + "training.default_batch_size", self.batch_size_spin.value() + ) + self.config_manager.set("training.default_imgsz", self.imgsz_spin.value()) + self.config_manager.set("training.default_patience", self.patience_spin.value()) + self.config_manager.set("training.default_lr0", self.lr_spin.value()) + + # Save detection settings + self.config_manager.set("detection.default_confidence", self.conf_spin.value()) + self.config_manager.set("detection.default_iou", self.iou_spin.value()) + self.config_manager.set("detection.max_batch_size", self.max_batch_spin.value()) + + # Save to file + self.config_manager.save_config() + + super().accept() diff --git a/src/gui/main_window.py b/src/gui/main_window.py new file mode 100644 index 0000000..71e322b --- /dev/null +++ b/src/gui/main_window.py @@ -0,0 +1,282 @@ +""" +Main window for the microscopy object detection application. +""" + +from PySide6.QtWidgets import ( + QMainWindow, + QTabWidget, + QMenuBar, + QMenu, + QStatusBar, + QMessageBox, + QWidget, + QVBoxLayout, + QLabel, +) +from PySide6.QtCore import Qt, QTimer +from PySide6.QtGui import QAction, QKeySequence + +from src.database.db_manager import DatabaseManager +from src.utils.config_manager import ConfigManager +from src.utils.logger import get_logger +from src.gui.dialogs.config_dialog import ConfigDialog +from src.gui.tabs.detection_tab import DetectionTab +from src.gui.tabs.training_tab import TrainingTab +from src.gui.tabs.validation_tab import ValidationTab +from src.gui.tabs.results_tab import ResultsTab +from src.gui.tabs.annotation_tab import AnnotationTab + + +logger = get_logger(__name__) + + +class MainWindow(QMainWindow): + """Main application window.""" + + def __init__(self): + super().__init__() + + # Initialize managers + self.config_manager = ConfigManager() + + db_path = self.config_manager.get_database_path() + self.db_manager = DatabaseManager(db_path) + + logger.info("Main window initializing") + + # Setup UI + self.setWindowTitle("Microscopy Object Detection") + self.setMinimumSize(1200, 800) + + self._create_menu_bar() + self._create_tab_widget() + self._create_status_bar() + + # Center window on screen + self._center_window() + + logger.info("Main window initialized") + + def _create_menu_bar(self): + """Create application menu bar.""" + menubar = self.menuBar() + + # File menu + file_menu = menubar.addMenu("&File") + + settings_action = QAction("&Settings", self) + settings_action.setShortcut(QKeySequence("Ctrl+,")) + settings_action.triggered.connect(self._show_settings) + file_menu.addAction(settings_action) + + file_menu.addSeparator() + + exit_action = QAction("E&xit", self) + exit_action.setShortcut(QKeySequence("Ctrl+Q")) + exit_action.triggered.connect(self.close) + file_menu.addAction(exit_action) + + # View menu + view_menu = menubar.addMenu("&View") + + refresh_action = QAction("&Refresh", self) + refresh_action.setShortcut(QKeySequence("F5")) + refresh_action.triggered.connect(self._refresh_current_tab) + view_menu.addAction(refresh_action) + + # Tools menu + tools_menu = menubar.addMenu("&Tools") + + db_stats_action = QAction("Database &Statistics", self) + db_stats_action.triggered.connect(self._show_database_stats) + tools_menu.addAction(db_stats_action) + + # Help menu + help_menu = menubar.addMenu("&Help") + + about_action = QAction("&About", self) + about_action.triggered.connect(self._show_about) + help_menu.addAction(about_action) + + docs_action = QAction("&Documentation", self) + docs_action.triggered.connect(self._show_documentation) + help_menu.addAction(docs_action) + + def _create_tab_widget(self): + """Create main tab widget with all tabs.""" + self.tab_widget = QTabWidget() + self.tab_widget.setTabPosition(QTabWidget.North) + + # Create tabs + try: + self.detection_tab = DetectionTab(self.db_manager, self.config_manager) + self.training_tab = TrainingTab(self.db_manager, self.config_manager) + self.validation_tab = ValidationTab(self.db_manager, self.config_manager) + self.results_tab = ResultsTab(self.db_manager, self.config_manager) + self.annotation_tab = AnnotationTab(self.db_manager, self.config_manager) + + # Add tabs to widget + self.tab_widget.addTab(self.detection_tab, "Detection") + self.tab_widget.addTab(self.training_tab, "Training") + self.tab_widget.addTab(self.validation_tab, "Validation") + self.tab_widget.addTab(self.results_tab, "Results") + self.tab_widget.addTab(self.annotation_tab, "Annotation (Future)") + + # Connect tab change signal + self.tab_widget.currentChanged.connect(self._on_tab_changed) + + except Exception as e: + logger.error(f"Error creating tabs: {e}") + # Create placeholder + placeholder = QWidget() + layout = QVBoxLayout() + layout.addWidget(QLabel(f"Error creating tabs: {e}")) + placeholder.setLayout(layout) + self.tab_widget.addTab(placeholder, "Error") + + self.setCentralWidget(self.tab_widget) + + def _create_status_bar(self): + """Create status bar.""" + self.status_bar = QStatusBar() + self.setStatusBar(self.status_bar) + + # Add permanent widgets to status bar + self.status_label = QLabel("Ready") + self.status_bar.addWidget(self.status_label) + + # Initial status message + self._update_status("Ready") + + def _center_window(self): + """Center window on screen.""" + screen = self.screen().geometry() + size = self.geometry() + self.move( + (screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2 + ) + + def _show_settings(self): + """Show settings dialog.""" + logger.info("Opening settings dialog") + dialog = ConfigDialog(self.config_manager, self) + if dialog.exec(): + self._apply_settings() + self._update_status("Settings saved") + + def _apply_settings(self): + """Apply changed settings.""" + logger.info("Applying settings changes") + # Reload configuration in all tabs if needed + try: + if hasattr(self, "detection_tab"): + self.detection_tab.refresh() + if hasattr(self, "training_tab"): + self.training_tab.refresh() + if hasattr(self, "results_tab"): + self.results_tab.refresh() + except Exception as e: + logger.error(f"Error applying settings: {e}") + + def _refresh_current_tab(self): + """Refresh the current tab.""" + current_widget = self.tab_widget.currentWidget() + if hasattr(current_widget, "refresh"): + current_widget.refresh() + self._update_status("Tab refreshed") + + def _on_tab_changed(self, index: int): + """Handle tab change event.""" + tab_name = self.tab_widget.tabText(index) + logger.debug(f"Switched to tab: {tab_name}") + self._update_status(f"Viewing: {tab_name}") + + def _show_database_stats(self): + """Show database statistics dialog.""" + try: + stats = self.db_manager.get_detection_statistics() + + message = f""" +

Database Statistics

+

Total Detections: {stats.get('total_detections', 0)}

+

Average Confidence: {stats.get('average_confidence', 0):.2%}

+

Classes:

+ " + + QMessageBox.information(self, "Database Statistics", message) + + except Exception as e: + logger.error(f"Error getting database stats: {e}") + QMessageBox.warning( + self, "Error", f"Failed to get database statistics:\n{str(e)}" + ) + + def _show_about(self): + """Show about dialog.""" + about_text = """ +

Microscopy Object Detection Application

+

Version: 1.0.0

+

A desktop application for detecting organelles and membrane branching + structures in microscopy images using YOLOv8.

+ +

Features:

+ + +

Technologies:

+ + """ + + QMessageBox.about(self, "About", about_text) + + def _show_documentation(self): + """Show documentation.""" + QMessageBox.information( + self, + "Documentation", + "Please refer to README.md and ARCHITECTURE.md files in the project directory.", + ) + + def _update_status(self, message: str, timeout: int = 5000): + """ + Update status bar message. + + Args: + message: Status message to display + timeout: Time in milliseconds to show message (0 for permanent) + """ + self.status_label.setText(message) + if timeout > 0: + QTimer.singleShot(timeout, lambda: self.status_label.setText("Ready")) + + def closeEvent(self, event): + """Handle window close event.""" + reply = QMessageBox.question( + self, + "Confirm Exit", + "Are you sure you want to exit?", + QMessageBox.Yes | QMessageBox.No, + QMessageBox.No, + ) + + if reply == QMessageBox.Yes: + logger.info("Application closing") + event.accept() + else: + event.ignore() diff --git a/src/gui/tabs/__init__.py b/src/gui/tabs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/gui/tabs/annotation_tab.py b/src/gui/tabs/annotation_tab.py new file mode 100644 index 0000000..7b2f5fc --- /dev/null +++ b/src/gui/tabs/annotation_tab.py @@ -0,0 +1,48 @@ +""" +Annotation tab for the microscopy object detection application. +Future feature for manual annotation. +""" + +from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox + +from src.database.db_manager import DatabaseManager +from src.utils.config_manager import ConfigManager + + +class AnnotationTab(QWidget): + """Annotation tab placeholder (future feature).""" + + def __init__( + self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None + ): + super().__init__(parent) + self.db_manager = db_manager + self.config_manager = config_manager + + self._setup_ui() + + def _setup_ui(self): + """Setup user interface.""" + layout = QVBoxLayout() + + group = QGroupBox("Annotation Tool (Future Feature)") + group_layout = QVBoxLayout() + label = QLabel( + "Annotation functionality will be implemented in future version.\n\n" + "Planned Features:\n" + "- Image browser\n" + "- Drawing tools for bounding boxes\n" + "- Class label assignment\n" + "- Export annotations to YOLO format\n" + "- Annotation verification" + ) + group_layout.addWidget(label) + group.setLayout(group_layout) + + layout.addWidget(group) + layout.addStretch() + self.setLayout(layout) + + def refresh(self): + """Refresh the tab.""" + pass diff --git a/src/gui/tabs/detection_tab.py b/src/gui/tabs/detection_tab.py new file mode 100644 index 0000000..4fe71ce --- /dev/null +++ b/src/gui/tabs/detection_tab.py @@ -0,0 +1,344 @@ +""" +Detection tab for the microscopy object detection application. +Handles single image and batch detection. +""" + +from PySide6.QtWidgets import ( + QWidget, + QVBoxLayout, + QHBoxLayout, + QPushButton, + QLabel, + QComboBox, + QSlider, + QFileDialog, + QMessageBox, + QProgressBar, + QTextEdit, + QGroupBox, + QFormLayout, +) +from PySide6.QtCore import Qt, QThread, Signal +from pathlib import Path + +from src.database.db_manager import DatabaseManager +from src.utils.config_manager import ConfigManager +from src.utils.logger import get_logger +from src.utils.file_utils import get_image_files +from src.model.inference import InferenceEngine + + +logger = get_logger(__name__) + + +class DetectionWorker(QThread): + """Worker thread for running detection.""" + + progress = Signal(int, int, str) # current, total, message + finished = Signal(list) # results + error = Signal(str) # error message + + def __init__(self, engine, image_paths, repo_root, conf): + super().__init__() + self.engine = engine + self.image_paths = image_paths + self.repo_root = repo_root + self.conf = conf + + def run(self): + """Run detection in background thread.""" + try: + results = self.engine.detect_batch( + self.image_paths, self.repo_root, self.conf, self.progress.emit + ) + self.finished.emit(results) + except Exception as e: + logger.error(f"Detection error: {e}") + self.error.emit(str(e)) + + +class DetectionTab(QWidget): + """Detection tab for single image and batch detection.""" + + def __init__( + self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None + ): + super().__init__(parent) + self.db_manager = db_manager + self.config_manager = config_manager + self.inference_engine = None + self.current_model_id = None + + self._setup_ui() + self._connect_signals() + self._load_models() + + def _setup_ui(self): + """Setup user interface.""" + layout = QVBoxLayout() + + # Model selection group + model_group = QGroupBox("Model Selection") + model_layout = QFormLayout() + + self.model_combo = QComboBox() + self.model_combo.addItem("No models available", None) + model_layout.addRow("Model:", self.model_combo) + + model_group.setLayout(model_layout) + layout.addWidget(model_group) + + # Detection settings group + settings_group = QGroupBox("Detection Settings") + settings_layout = QFormLayout() + + # Confidence threshold + conf_layout = QHBoxLayout() + self.conf_slider = QSlider(Qt.Horizontal) + self.conf_slider.setRange(0, 100) + self.conf_slider.setValue(25) + self.conf_slider.setTickPosition(QSlider.TicksBelow) + self.conf_slider.setTickInterval(10) + conf_layout.addWidget(self.conf_slider) + + self.conf_label = QLabel("0.25") + conf_layout.addWidget(self.conf_label) + + settings_layout.addRow("Confidence:", conf_layout) + settings_group.setLayout(settings_layout) + layout.addWidget(settings_group) + + # Action buttons + button_layout = QHBoxLayout() + + self.single_image_btn = QPushButton("Detect Single Image") + self.single_image_btn.clicked.connect(self._detect_single_image) + button_layout.addWidget(self.single_image_btn) + + self.batch_btn = QPushButton("Detect Batch (Folder)") + self.batch_btn.clicked.connect(self._detect_batch) + button_layout.addWidget(self.batch_btn) + + layout.addLayout(button_layout) + + # Progress bar + self.progress_bar = QProgressBar() + self.progress_bar.setVisible(False) + layout.addWidget(self.progress_bar) + + # Results display + results_group = QGroupBox("Detection Results") + results_layout = QVBoxLayout() + + self.results_text = QTextEdit() + self.results_text.setReadOnly(True) + self.results_text.setMaximumHeight(200) + results_layout.addWidget(self.results_text) + + results_group.setLayout(results_layout) + layout.addWidget(results_group) + + layout.addStretch() + self.setLayout(layout) + + def _connect_signals(self): + """Connect signals and slots.""" + self.conf_slider.valueChanged.connect(self._update_confidence_label) + self.model_combo.currentIndexChanged.connect(self._on_model_changed) + + def _load_models(self): + """Load available models from database.""" + try: + models = self.db_manager.get_models() + self.model_combo.clear() + + if not models: + self.model_combo.addItem("No models available", None) + self._set_buttons_enabled(False) + return + + # Add base model option + base_model = self.config_manager.get( + "models.default_base_model", "yolov8s.pt" + ) + self.model_combo.addItem( + f"Base Model ({base_model})", {"id": 0, "path": base_model} + ) + + # Add trained models + for model in models: + display_name = f"{model['model_name']} v{model['model_version']}" + self.model_combo.addItem(display_name, model) + + self._set_buttons_enabled(True) + + except Exception as e: + logger.error(f"Error loading models: {e}") + QMessageBox.warning(self, "Error", f"Failed to load models:\n{str(e)}") + + def _on_model_changed(self, index: int): + """Handle model selection change.""" + model_data = self.model_combo.itemData(index) + if model_data and model_data["id"] != 0: + self.current_model_id = model_data["id"] + else: + self.current_model_id = None + + def _update_confidence_label(self, value: int): + """Update confidence label.""" + conf = value / 100.0 + self.conf_label.setText(f"{conf:.2f}") + + def _detect_single_image(self): + """Detect objects in a single image.""" + # Get image file + repo_path = self.config_manager.get_image_repository_path() + start_dir = repo_path if repo_path else "" + + file_path, _ = QFileDialog.getOpenFileName( + self, + "Select Image", + start_dir, + "Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)", + ) + + if not file_path: + return + + # Run detection + self._run_detection([file_path]) + + def _detect_batch(self): + """Detect objects in batch (folder).""" + # Get folder + repo_path = self.config_manager.get_image_repository_path() + start_dir = repo_path if repo_path else "" + + folder_path = QFileDialog.getExistingDirectory(self, "Select Folder", start_dir) + + if not folder_path: + return + + # Get all image files + allowed_ext = self.config_manager.get_allowed_extensions() + image_files = get_image_files(folder_path, allowed_ext, recursive=False) + + if not image_files: + QMessageBox.information( + self, "No Images", "No image files found in selected folder." + ) + return + + # Confirm batch processing + reply = QMessageBox.question( + self, + "Confirm Batch Detection", + f"Process {len(image_files)} images?", + QMessageBox.Yes | QMessageBox.No, + ) + + if reply == QMessageBox.Yes: + self._run_detection(image_files) + + def _run_detection(self, image_paths: list): + """Run detection on image list.""" + try: + # Get selected model + model_data = self.model_combo.currentData() + if not model_data: + QMessageBox.warning(self, "No Model", "Please select a model first.") + return + + model_path = model_data["path"] + model_id = model_data["id"] + + # Ensure we have a valid model ID (create entry for base model if needed) + if model_id == 0: + # Create database entry for base model + base_model = self.config_manager.get( + "models.default_base_model", "yolov8s.pt" + ) + model_id = self.db_manager.add_model( + model_name="Base Model", + model_version="pretrained", + model_path=base_model, + base_model=base_model, + ) + + # Create inference engine + self.inference_engine = InferenceEngine( + model_path, self.db_manager, model_id + ) + + # Get confidence threshold + conf = self.conf_slider.value() / 100.0 + + # Get repository root + repo_root = self.config_manager.get_image_repository_path() + if not repo_root: + repo_root = str(Path(image_paths[0]).parent) + + # Show progress bar + self.progress_bar.setVisible(True) + self.progress_bar.setMaximum(len(image_paths)) + self._set_buttons_enabled(False) + + # Create and start worker thread + self.worker = DetectionWorker( + self.inference_engine, image_paths, repo_root, conf + ) + self.worker.progress.connect(self._on_progress) + self.worker.finished.connect(self._on_detection_finished) + self.worker.error.connect(self._on_detection_error) + self.worker.start() + + except Exception as e: + logger.error(f"Error starting detection: {e}") + QMessageBox.critical(self, "Error", f"Failed to start detection:\n{str(e)}") + self._set_buttons_enabled(True) + + def _on_progress(self, current: int, total: int, message: str): + """Handle progress update.""" + self.progress_bar.setValue(current) + self.results_text.append(f"[{current}/{total}] {message}") + + def _on_detection_finished(self, results: list): + """Handle detection completion.""" + self.progress_bar.setVisible(False) + self._set_buttons_enabled(True) + + # Calculate statistics + total_detections = sum(r["count"] for r in results) + successful = sum(1 for r in results if r.get("success", False)) + + summary = f"\n=== Detection Complete ===\n" + summary += f"Processed: {len(results)} images\n" + summary += f"Successful: {successful}\n" + summary += f"Total detections: {total_detections}\n" + + self.results_text.append(summary) + + QMessageBox.information( + self, + "Detection Complete", + f"Processed {len(results)} images\n{total_detections} objects detected", + ) + + def _on_detection_error(self, error_msg: str): + """Handle detection error.""" + self.progress_bar.setVisible(False) + self._set_buttons_enabled(True) + + self.results_text.append(f"\nERROR: {error_msg}") + QMessageBox.critical(self, "Detection Error", error_msg) + + def _set_buttons_enabled(self, enabled: bool): + """Enable/disable action buttons.""" + self.single_image_btn.setEnabled(enabled) + self.batch_btn.setEnabled(enabled) + self.model_combo.setEnabled(enabled) + + def refresh(self): + """Refresh the tab.""" + self._load_models() + self.results_text.clear() diff --git a/src/gui/tabs/results_tab.py b/src/gui/tabs/results_tab.py new file mode 100644 index 0000000..71e523c --- /dev/null +++ b/src/gui/tabs/results_tab.py @@ -0,0 +1,46 @@ +""" +Results tab for the microscopy object detection application. +""" + +from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox + +from src.database.db_manager import DatabaseManager +from src.utils.config_manager import ConfigManager + + +class ResultsTab(QWidget): + """Results tab placeholder.""" + + def __init__( + self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None + ): + super().__init__(parent) + self.db_manager = db_manager + self.config_manager = config_manager + + self._setup_ui() + + def _setup_ui(self): + """Setup user interface.""" + layout = QVBoxLayout() + + group = QGroupBox("Results") + group_layout = QVBoxLayout() + label = QLabel( + "Results viewer will be implemented here.\n\n" + "Features:\n" + "- Detection history browser\n" + "- Advanced filtering\n" + "- Statistics dashboard\n" + "- Export functionality" + ) + group_layout.addWidget(label) + group.setLayout(group_layout) + + layout.addWidget(group) + layout.addStretch() + self.setLayout(layout) + + def refresh(self): + """Refresh the tab.""" + pass diff --git a/src/gui/tabs/training_tab.py b/src/gui/tabs/training_tab.py new file mode 100644 index 0000000..8c3e778 --- /dev/null +++ b/src/gui/tabs/training_tab.py @@ -0,0 +1,52 @@ +""" +Training tab for the microscopy object detection application. +Handles model training with YOLO. +""" + +from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox + +from src.database.db_manager import DatabaseManager +from src.utils.config_manager import ConfigManager +from src.utils.logger import get_logger + + +logger = get_logger(__name__) + + +class TrainingTab(QWidget): + """Training tab for model training.""" + + def __init__( + self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None + ): + super().__init__(parent) + self.db_manager = db_manager + self.config_manager = config_manager + + self._setup_ui() + + def _setup_ui(self): + """Setup user interface.""" + layout = QVBoxLayout() + + # Placeholder + group = QGroupBox("Training") + group_layout = QVBoxLayout() + label = QLabel( + "Training functionality will be implemented here.\n\n" + "Features:\n" + "- Dataset selection\n" + "- Training parameter configuration\n" + "- Real-time training progress\n" + "- Loss and metric visualization" + ) + group_layout.addWidget(label) + group.setLayout(group_layout) + + layout.addWidget(group) + layout.addStretch() + self.setLayout(layout) + + def refresh(self): + """Refresh the tab.""" + pass diff --git a/src/gui/tabs/validation_tab.py b/src/gui/tabs/validation_tab.py new file mode 100644 index 0000000..2e7749a --- /dev/null +++ b/src/gui/tabs/validation_tab.py @@ -0,0 +1,46 @@ +""" +Validation tab for the microscopy object detection application. +""" + +from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox + +from src.database.db_manager import DatabaseManager +from src.utils.config_manager import ConfigManager + + +class ValidationTab(QWidget): + """Validation tab placeholder.""" + + def __init__( + self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None + ): + super().__init__(parent) + self.db_manager = db_manager + self.config_manager = config_manager + + self._setup_ui() + + def _setup_ui(self): + """Setup user interface.""" + layout = QVBoxLayout() + + group = QGroupBox("Validation") + group_layout = QVBoxLayout() + label = QLabel( + "Validation functionality will be implemented here.\n\n" + "Features:\n" + "- Model validation\n" + "- Metrics visualization\n" + "- Confusion matrix\n" + "- Precision-Recall curves" + ) + group_layout.addWidget(label) + group.setLayout(group_layout) + + layout.addWidget(group) + layout.addStretch() + self.setLayout(layout) + + def refresh(self): + """Refresh the tab.""" + pass diff --git a/src/gui/widgets/__init__.py b/src/gui/widgets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/model/__init__.py b/src/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/model/inference.py b/src/model/inference.py new file mode 100644 index 0000000..1fc5ab8 --- /dev/null +++ b/src/model/inference.py @@ -0,0 +1,323 @@ +""" +Inference engine for the microscopy object detection application. +Handles detection inference and result storage. +""" + +from typing import List, Dict, Optional, Callable +from pathlib import Path +from PIL import Image +import cv2 +import numpy as np + +from src.model.yolo_wrapper import YOLOWrapper +from src.database.db_manager import DatabaseManager +from src.utils.logger import get_logger +from src.utils.file_utils import get_relative_path + + +logger = get_logger(__name__) + + +class InferenceEngine: + """Handles detection inference and result storage.""" + + def __init__(self, model_path: str, db_manager: DatabaseManager, model_id: int): + """ + Initialize inference engine. + + Args: + model_path: Path to YOLO model weights + db_manager: Database manager instance + model_id: ID of the model in database + """ + self.yolo = YOLOWrapper(model_path) + self.yolo.load_model() + self.db_manager = db_manager + self.model_id = model_id + logger.info(f"InferenceEngine initialized with model_id {model_id}") + + def detect_single( + self, + image_path: str, + relative_path: str, + conf: float = 0.25, + save_to_db: bool = True, + ) -> Dict: + """ + Detect objects in a single image. + + Args: + image_path: Absolute path to image file + relative_path: Relative path from repository root + conf: Confidence threshold + save_to_db: Whether to save results to database + + Returns: + Dictionary with detection results + """ + try: + # Get image dimensions + img = Image.open(image_path) + width, height = img.size + img.close() + + # Perform detection + detections = self.yolo.predict(image_path, conf=conf) + + # Add/get image in database + image_id = self.db_manager.get_or_create_image( + relative_path=relative_path, + filename=Path(image_path).name, + width=width, + height=height, + ) + + # Save detections to database + if save_to_db and detections: + detection_records = [] + for det in detections: + # Use normalized bbox from detection + bbox_normalized = det[ + "bbox_normalized" + ] # [x_min, y_min, x_max, y_max] + + record = { + "image_id": image_id, + "model_id": self.model_id, + "class_name": det["class_name"], + "bbox": tuple(bbox_normalized), + "confidence": det["confidence"], + "metadata": {"class_id": det["class_id"]}, + } + detection_records.append(record) + + self.db_manager.add_detections_batch(detection_records) + logger.info(f"Saved {len(detection_records)} detections to database") + + return { + "success": True, + "image_path": image_path, + "image_id": image_id, + "detections": detections, + "count": len(detections), + } + + except Exception as e: + logger.error(f"Error detecting objects in {image_path}: {e}") + return { + "success": False, + "image_path": image_path, + "error": str(e), + "detections": [], + "count": 0, + } + + def detect_batch( + self, + image_paths: List[str], + repository_root: str, + conf: float = 0.25, + progress_callback: Optional[Callable[[int, int, str], None]] = None, + ) -> List[Dict]: + """ + Detect objects in multiple images. + + Args: + image_paths: List of absolute image paths + repository_root: Root directory for relative paths + conf: Confidence threshold + progress_callback: Optional callback(current, total, message) + + Returns: + List of detection result dictionaries + """ + results = [] + total = len(image_paths) + + logger.info(f"Starting batch detection on {total} images") + + for i, image_path in enumerate(image_paths, 1): + # Calculate relative path + rel_path = get_relative_path(image_path, repository_root) + + # Perform detection + result = self.detect_single(image_path, rel_path, conf) + results.append(result) + + # Update progress + if progress_callback: + progress_callback(i, total, f"Processed {rel_path}") + + if i % 10 == 0: + logger.info(f"Processed {i}/{total} images") + + logger.info(f"Batch detection complete: {total} images processed") + return results + + def detect_with_visualization( + self, + image_path: str, + conf: float = 0.25, + bbox_thickness: int = 2, + bbox_colors: Optional[Dict[str, str]] = None, + ) -> tuple: + """ + Detect objects and return annotated image. + + Args: + image_path: Path to image + conf: Confidence threshold + bbox_thickness: Thickness of bounding boxes + bbox_colors: Dictionary mapping class names to hex colors + + Returns: + Tuple of (detections, annotated_image_array) + """ + try: + detections = self.yolo.predict(image_path, conf=conf) + + # Load image + img = cv2.imread(image_path) + if img is None: + raise ValueError(f"Failed to load image: {image_path}") + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + height, width = img.shape[:2] + + # Default colors if not provided + if bbox_colors is None: + bbox_colors = {} + default_color = self._hex_to_bgr(bbox_colors.get("default", "#00FF00")) + + # Draw bounding boxes + for det in detections: + # Get absolute coordinates + bbox_abs = det["bbox_absolute"] + x1, y1, x2, y2 = [int(v) for v in bbox_abs] + + # Get color for this class + class_name = det["class_name"] + color_hex = bbox_colors.get( + class_name, bbox_colors.get("default", "#00FF00") + ) + color = self._hex_to_bgr(color_hex) + + # Draw box + cv2.rectangle(img, (x1, y1), (x2, y2), color, bbox_thickness) + + # Prepare label + label = f"{class_name} {det['confidence']:.2f}" + + # Draw label background + (label_w, label_h), baseline = cv2.getTextSize( + label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1 + ) + cv2.rectangle( + img, + (x1, y1 - label_h - baseline - 5), + (x1 + label_w, y1), + color, + -1, + ) + + # Draw label text + cv2.putText( + img, + label, + (x1, y1 - baseline - 5), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 255, 255), + 1, + ) + + return detections, img + + except Exception as e: + logger.error(f"Error creating visualization: {e}") + # Return empty detections and original image if possible + try: + img = cv2.imread(image_path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return [], img + except: + return [], np.zeros((480, 640, 3), dtype=np.uint8) + + def get_detection_summary(self, detections: List[Dict]) -> Dict[str, any]: + """ + Generate summary statistics for detections. + + Args: + detections: List of detection dictionaries + + Returns: + Dictionary with summary statistics + """ + if not detections: + return { + "total_count": 0, + "class_counts": {}, + "avg_confidence": 0.0, + "confidence_range": (0.0, 0.0), + } + + # Count by class + class_counts = {} + confidences = [] + + for det in detections: + class_name = det["class_name"] + class_counts[class_name] = class_counts.get(class_name, 0) + 1 + confidences.append(det["confidence"]) + + return { + "total_count": len(detections), + "class_counts": class_counts, + "avg_confidence": sum(confidences) / len(confidences), + "confidence_range": (min(confidences), max(confidences)), + } + + @staticmethod + def _hex_to_bgr(hex_color: str) -> tuple: + """ + Convert hex color to BGR tuple. + + Args: + hex_color: Hex color string (e.g., '#FF0000') + + Returns: + BGR tuple (B, G, R) + """ + hex_color = hex_color.lstrip("#") + if len(hex_color) != 6: + return (0, 255, 0) # Default green + + try: + r = int(hex_color[0:2], 16) + g = int(hex_color[2:4], 16) + b = int(hex_color[4:6], 16) + return (b, g, r) # OpenCV uses BGR + except ValueError: + return (0, 255, 0) # Default green + + def change_model(self, model_path: str, model_id: int) -> bool: + """ + Change the current model. + + Args: + model_path: Path to new model weights + model_id: ID of new model in database + + Returns: + True if successful, False otherwise + """ + try: + self.yolo = YOLOWrapper(model_path) + if self.yolo.load_model(): + self.model_id = model_id + logger.info(f"Model changed to {model_path}") + return True + return False + except Exception as e: + logger.error(f"Error changing model: {e}") + return False diff --git a/src/model/yolo_wrapper.py b/src/model/yolo_wrapper.py new file mode 100644 index 0000000..d4a8050 --- /dev/null +++ b/src/model/yolo_wrapper.py @@ -0,0 +1,364 @@ +""" +YOLO model wrapper for the microscopy object detection application. +Provides a clean interface to YOLOv8 for training, validation, and inference. +""" + +from ultralytics import YOLO +from pathlib import Path +from typing import Optional, List, Dict, Callable, Any +import torch +from src.utils.logger import get_logger + + +logger = get_logger(__name__) + + +class YOLOWrapper: + """Wrapper for YOLOv8 model operations.""" + + def __init__(self, model_path: str = "yolov8s.pt"): + """ + Initialize YOLO model. + + Args: + model_path: Path to model weights (.pt file) + """ + self.model_path = model_path + self.model = None + self.device = "cuda" if torch.cuda.is_available() else "cpu" + logger.info(f"YOLOWrapper initialized with device: {self.device}") + + def load_model(self) -> bool: + """ + Load YOLO model from path. + + Returns: + True if loaded successfully, False otherwise + """ + try: + logger.info(f"Loading YOLO model from {self.model_path}") + self.model = YOLO(self.model_path) + self.model.to(self.device) + logger.info("Model loaded successfully") + return True + except Exception as e: + logger.error(f"Error loading model: {e}") + return False + + def train( + self, + data_yaml: str, + epochs: int = 100, + imgsz: int = 640, + batch: int = 16, + patience: int = 50, + save_dir: str = "data/models", + name: str = "custom_model", + resume: bool = False, + **kwargs, + ) -> Dict[str, Any]: + """ + Train the YOLO model. + + Args: + data_yaml: Path to data.yaml configuration file + epochs: Number of training epochs + imgsz: Input image size + batch: Batch size + patience: Early stopping patience + save_dir: Directory to save trained model + name: Name for the training run + resume: Resume training from last checkpoint + **kwargs: Additional training arguments + + Returns: + Dictionary with training results + """ + if self.model is None: + self.load_model() + + try: + logger.info(f"Starting training: {name}") + logger.info( + f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}" + ) + + # Train the model + results = self.model.train( + data=data_yaml, + epochs=epochs, + imgsz=imgsz, + batch=batch, + patience=patience, + project=save_dir, + name=name, + device=self.device, + resume=resume, + **kwargs, + ) + + logger.info("Training completed successfully") + return self._format_training_results(results) + + except Exception as e: + logger.error(f"Error during training: {e}") + raise + + def validate(self, data_yaml: str, split: str = "val", **kwargs) -> Dict[str, Any]: + """ + Validate the model. + + Args: + data_yaml: Path to data.yaml configuration file + split: Dataset split to validate on ('val' or 'test') + **kwargs: Additional validation arguments + + Returns: + Dictionary with validation metrics + """ + if self.model is None: + self.load_model() + + try: + logger.info(f"Starting validation on {split} split") + results = self.model.val( + data=data_yaml, split=split, device=self.device, **kwargs + ) + + logger.info("Validation completed successfully") + return self._format_validation_results(results) + + except Exception as e: + logger.error(f"Error during validation: {e}") + raise + + def predict( + self, + source: str, + conf: float = 0.25, + iou: float = 0.45, + save: bool = False, + save_txt: bool = False, + save_conf: bool = False, + **kwargs, + ) -> List[Dict]: + """ + Perform inference on image(s). + + Args: + source: Path to image or directory + conf: Confidence threshold + iou: IoU threshold for NMS + save: Whether to save annotated images + save_txt: Whether to save labels to .txt files + save_conf: Whether to save confidence in labels + **kwargs: Additional prediction arguments + + Returns: + List of detection dictionaries + """ + if self.model is None: + self.load_model() + + try: + logger.info(f"Running inference on {source}") + results = self.model.predict( + source=source, + conf=conf, + iou=iou, + save=save, + save_txt=save_txt, + save_conf=save_conf, + device=self.device, + **kwargs, + ) + + detections = self._format_prediction_results(results) + logger.info(f"Inference complete: {len(detections)} detections") + return detections + + except Exception as e: + logger.error(f"Error during inference: {e}") + raise + + def export( + self, format: str = "onnx", output_path: Optional[str] = None, **kwargs + ) -> str: + """ + Export model to different format. + + Args: + format: Export format (onnx, torchscript, tflite, etc.) + output_path: Path for exported model + **kwargs: Additional export arguments + + Returns: + Path to exported model + """ + if self.model is None: + self.load_model() + + try: + logger.info(f"Exporting model to {format} format") + export_path = self.model.export(format=format, **kwargs) + logger.info(f"Model exported to {export_path}") + return str(export_path) + + except Exception as e: + logger.error(f"Error exporting model: {e}") + raise + + def _format_training_results(self, results) -> Dict[str, Any]: + """Format training results into dictionary.""" + try: + # Get the results dict + results_dict = ( + results.results_dict if hasattr(results, "results_dict") else {} + ) + + formatted = { + "success": True, + "final_epoch": getattr(results, "epoch", 0), + "metrics": { + "mAP50": float(results_dict.get("metrics/mAP50(B)", 0)), + "mAP50-95": float(results_dict.get("metrics/mAP50-95(B)", 0)), + "precision": float(results_dict.get("metrics/precision(B)", 0)), + "recall": float(results_dict.get("metrics/recall(B)", 0)), + }, + "best_model_path": str(Path(results.save_dir) / "weights" / "best.pt"), + "last_model_path": str(Path(results.save_dir) / "weights" / "last.pt"), + "save_dir": str(results.save_dir), + } + + return formatted + + except Exception as e: + logger.error(f"Error formatting training results: {e}") + return {"success": False, "error": str(e)} + + def _format_validation_results(self, results) -> Dict[str, Any]: + """Format validation results into dictionary.""" + try: + box_metrics = results.box + + formatted = { + "success": True, + "mAP50": float(box_metrics.map50), + "mAP50-95": float(box_metrics.map), + "precision": float(box_metrics.mp), + "recall": float(box_metrics.mr), + "fitness": ( + float(results.fitness) if hasattr(results, "fitness") else 0.0 + ), + } + + # Add per-class metrics if available + if hasattr(box_metrics, "ap") and hasattr(results, "names"): + class_metrics = {} + for idx, name in results.names.items(): + if idx < len(box_metrics.ap): + class_metrics[name] = { + "ap": float(box_metrics.ap[idx]), + "ap50": ( + float(box_metrics.ap50[idx]) + if hasattr(box_metrics, "ap50") + else 0.0 + ), + } + formatted["class_metrics"] = class_metrics + + return formatted + + except Exception as e: + logger.error(f"Error formatting validation results: {e}") + return {"success": False, "error": str(e)} + + def _format_prediction_results(self, results) -> List[Dict]: + """Format prediction results into list of dictionaries.""" + detections = [] + + try: + for result in results: + boxes = result.boxes + image_path = str(result.path) + orig_shape = result.orig_shape # (height, width) + + for i in range(len(boxes)): + # Get normalized coordinates + xyxyn = boxes.xyxyn[i].cpu().numpy() # Normalized [x1, y1, x2, y2] + + detection = { + "image_path": image_path, + "class_id": int(boxes.cls[i]), + "class_name": result.names[int(boxes.cls[i])], + "confidence": float(boxes.conf[i]), + "bbox_normalized": [ + float(v) for v in xyxyn + ], # [x_min, y_min, x_max, y_max] + "bbox_absolute": [ + float(v) for v in boxes.xyxy[i].cpu().numpy() + ], # Absolute pixels + } + detections.append(detection) + + return detections + + except Exception as e: + logger.error(f"Error formatting prediction results: {e}") + return [] + + @staticmethod + def convert_bbox_format( + bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy" + ) -> List[float]: + """ + Convert bounding box between formats. + + Formats: + - xywh: [x_center, y_center, width, height] + - xyxy: [x_min, y_min, x_max, y_max] + + Args: + bbox: Bounding box coordinates + format_from: Source format + format_to: Target format + + Returns: + Converted bounding box + """ + if format_from == "xywh" and format_to == "xyxy": + x, y, w, h = bbox + return [x - w / 2, y - h / 2, x + w / 2, y + h / 2] + elif format_from == "xyxy" and format_to == "xywh": + x1, y1, x2, y2 = bbox + return [(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1] + else: + return bbox + + def get_model_info(self) -> Dict[str, Any]: + """ + Get information about the loaded model. + + Returns: + Dictionary with model information + """ + if self.model is None: + return {"error": "Model not loaded"} + + try: + info = { + "model_path": self.model_path, + "device": self.device, + "task": getattr(self.model, "task", "unknown"), + } + + # Try to get class names + if hasattr(self.model, "names"): + info["classes"] = self.model.names + info["num_classes"] = len(self.model.names) + + return info + + except Exception as e: + logger.error(f"Error getting model info: {e}") + return {"error": str(e)} diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/config_manager.py b/src/utils/config_manager.py new file mode 100644 index 0000000..385d2a9 --- /dev/null +++ b/src/utils/config_manager.py @@ -0,0 +1,218 @@ +""" +Configuration manager for the microscopy object detection application. +Handles loading, saving, and accessing application configuration. +""" + +import yaml +from pathlib import Path +from typing import Any, Dict, Optional +from src.utils.logger import get_logger + + +logger = get_logger(__name__) + + +class ConfigManager: + """Manages application configuration.""" + + def __init__(self, config_path: str = "config/app_config.yaml"): + """ + Initialize configuration manager. + + Args: + config_path: Path to configuration file + """ + self.config_path = Path(config_path) + self.config: Dict[str, Any] = {} + self._load_config() + + def _load_config(self) -> None: + """Load configuration from YAML file.""" + try: + if self.config_path.exists(): + with open(self.config_path, "r") as f: + self.config = yaml.safe_load(f) or {} + logger.info(f"Configuration loaded from {self.config_path}") + else: + logger.warning(f"Configuration file not found: {self.config_path}") + self._create_default_config() + except Exception as e: + logger.error(f"Error loading configuration: {e}") + self._create_default_config() + + def _create_default_config(self) -> None: + """Create default configuration.""" + self.config = { + "database": {"path": "data/detections.db"}, + "image_repository": { + "base_path": "", + "allowed_extensions": [ + ".jpg", + ".jpeg", + ".png", + ".tif", + ".tiff", + ".bmp", + ], + }, + "models": { + "default_base_model": "yolov8s.pt", + "models_directory": "data/models", + }, + "training": { + "default_epochs": 100, + "default_batch_size": 16, + "default_imgsz": 640, + "default_patience": 50, + "default_lr0": 0.01, + }, + "detection": { + "default_confidence": 0.25, + "default_iou": 0.45, + "max_batch_size": 100, + }, + "visualization": { + "bbox_colors": { + "organelle": "#FF6B6B", + "membrane_branch": "#4ECDC4", + "default": "#00FF00", + }, + "bbox_thickness": 2, + "font_size": 12, + }, + "export": {"formats": ["csv", "json", "excel"], "default_format": "csv"}, + "logging": { + "level": "INFO", + "file": "logs/app.log", + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + }, + } + self.save_config() + + def save_config(self) -> bool: + """ + Save current configuration to file. + + Returns: + True if successful, False otherwise + """ + try: + # Create directory if it doesn't exist + self.config_path.parent.mkdir(parents=True, exist_ok=True) + + with open(self.config_path, "w") as f: + yaml.dump(self.config, f, default_flow_style=False, sort_keys=False) + + logger.info(f"Configuration saved to {self.config_path}") + return True + except Exception as e: + logger.error(f"Error saving configuration: {e}") + return False + + def get(self, key: str, default: Any = None) -> Any: + """ + Get configuration value by key. + + Args: + key: Configuration key (can use dot notation, e.g., 'database.path') + default: Default value if key not found + + Returns: + Configuration value or default + """ + keys = key.split(".") + value = self.config + + for k in keys: + if isinstance(value, dict) and k in value: + value = value[k] + else: + return default + + return value + + def set(self, key: str, value: Any) -> None: + """ + Set configuration value by key. + + Args: + key: Configuration key (can use dot notation) + value: Value to set + """ + keys = key.split(".") + config = self.config + + # Navigate to the nested dictionary + for k in keys[:-1]: + if k not in config: + config[k] = {} + config = config[k] + + # Set the value + config[keys[-1]] = value + logger.debug(f"Configuration updated: {key} = {value}") + + def get_section(self, section: str) -> Dict[str, Any]: + """ + Get entire configuration section. + + Args: + section: Section name (e.g., 'database', 'training') + + Returns: + Dictionary with section configuration + """ + return self.config.get(section, {}) + + def update_section(self, section: str, values: Dict[str, Any]) -> None: + """ + Update entire configuration section. + + Args: + section: Section name + values: Dictionary with new values + """ + if section not in self.config: + self.config[section] = {} + + self.config[section].update(values) + logger.debug(f"Configuration section updated: {section}") + + def reload(self) -> None: + """Reload configuration from file.""" + self._load_config() + + def get_database_path(self) -> str: + """Get database path.""" + return self.get("database.path", "data/detections.db") + + def get_image_repository_path(self) -> str: + """Get image repository base path.""" + return self.get("image_repository.base_path", "") + + def set_image_repository_path(self, path: str) -> None: + """Set image repository base path.""" + self.set("image_repository.base_path", path) + self.save_config() + + def get_models_directory(self) -> str: + """Get models directory path.""" + return self.get("models.models_directory", "data/models") + + def get_default_training_params(self) -> Dict[str, Any]: + """Get default training parameters.""" + return self.get_section("training") + + def get_default_detection_params(self) -> Dict[str, Any]: + """Get default detection parameters.""" + return self.get_section("detection") + + def get_bbox_colors(self) -> Dict[str, str]: + """Get bounding box colors for different classes.""" + return self.get("visualization.bbox_colors", {}) + + def get_allowed_extensions(self) -> list: + """Get list of allowed image file extensions.""" + return self.get( + "image_repository.allowed_extensions", [".jpg", ".jpeg", ".png"] + ) diff --git a/src/utils/file_utils.py b/src/utils/file_utils.py new file mode 100644 index 0000000..019852e --- /dev/null +++ b/src/utils/file_utils.py @@ -0,0 +1,235 @@ +""" +File utility functions for the microscopy object detection application. +""" + +import os +from pathlib import Path +from typing import List, Optional +from src.utils.logger import get_logger + + +logger = get_logger(__name__) + + +def get_image_files( + directory: str, + allowed_extensions: Optional[List[str]] = None, + recursive: bool = False, +) -> List[str]: + """ + Get all image files in a directory. + + Args: + directory: Directory path to search + allowed_extensions: List of allowed file extensions (e.g., ['.jpg', '.png']) + recursive: Whether to search recursively + + Returns: + List of absolute paths to image files + """ + if allowed_extensions is None: + allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"] + + # Normalize extensions to lowercase + allowed_extensions = [ext.lower() for ext in allowed_extensions] + + image_files = [] + directory_path = Path(directory) + + if not directory_path.exists(): + logger.error(f"Directory does not exist: {directory}") + return image_files + + try: + if recursive: + # Recursive search + for ext in allowed_extensions: + image_files.extend(directory_path.rglob(f"*{ext}")) + # Also search uppercase extensions + image_files.extend(directory_path.rglob(f"*{ext.upper()}")) + else: + # Top-level search only + for ext in allowed_extensions: + image_files.extend(directory_path.glob(f"*{ext}")) + # Also search uppercase extensions + image_files.extend(directory_path.glob(f"*{ext.upper()}")) + + # Convert to absolute paths and sort + image_files = sorted([str(f.absolute()) for f in image_files]) + logger.info(f"Found {len(image_files)} image files in {directory}") + + except Exception as e: + logger.error(f"Error searching for images: {e}") + + return image_files + + +def ensure_directory(directory: str) -> bool: + """ + Ensure a directory exists, create if it doesn't. + + Args: + directory: Directory path + + Returns: + True if directory exists or was created successfully + """ + try: + Path(directory).mkdir(parents=True, exist_ok=True) + return True + except Exception as e: + logger.error(f"Error creating directory {directory}: {e}") + return False + + +def get_relative_path(file_path: str, base_path: str) -> str: + """ + Get relative path from base path. + + Args: + file_path: Absolute file path + base_path: Base directory path + + Returns: + Relative path string + """ + try: + return str(Path(file_path).relative_to(base_path)) + except ValueError: + # If file_path is not relative to base_path, return the filename + return Path(file_path).name + + +def validate_file_path(file_path: str, must_exist: bool = True) -> bool: + """ + Validate a file path. + + Args: + file_path: Path to validate + must_exist: Whether the file must exist + + Returns: + True if valid, False otherwise + """ + path = Path(file_path) + + if must_exist and not path.exists(): + logger.error(f"File does not exist: {file_path}") + return False + + if must_exist and not path.is_file(): + logger.error(f"Path is not a file: {file_path}") + return False + + return True + + +def get_file_size(file_path: str) -> int: + """ + Get file size in bytes. + + Args: + file_path: Path to file + + Returns: + File size in bytes, or 0 if error + """ + try: + return Path(file_path).stat().st_size + except Exception as e: + logger.error(f"Error getting file size for {file_path}: {e}") + return 0 + + +def format_file_size(size_bytes: int) -> str: + """ + Format file size in human-readable format. + + Args: + size_bytes: Size in bytes + + Returns: + Formatted string (e.g., "1.5 MB") + """ + for unit in ["B", "KB", "MB", "GB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.1f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.1f} TB" + + +def create_unique_filename(directory: str, base_name: str, extension: str) -> str: + """ + Create a unique filename by adding a number suffix if file exists. + + Args: + directory: Directory path + base_name: Base filename without extension + extension: File extension (with or without dot) + + Returns: + Unique filename + """ + if not extension.startswith("."): + extension = "." + extension + + directory_path = Path(directory) + filename = f"{base_name}{extension}" + file_path = directory_path / filename + + if not file_path.exists(): + return filename + + # Add number suffix + counter = 1 + while True: + filename = f"{base_name}_{counter}{extension}" + file_path = directory_path / filename + if not file_path.exists(): + return filename + counter += 1 + + +def is_image_file( + file_path: str, allowed_extensions: Optional[List[str]] = None +) -> bool: + """ + Check if a file is an image based on extension. + + Args: + file_path: Path to file + allowed_extensions: List of allowed extensions + + Returns: + True if file is an image + """ + if allowed_extensions is None: + allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"] + + extension = Path(file_path).suffix.lower() + return extension in [ext.lower() for ext in allowed_extensions] + + +def safe_filename(filename: str) -> str: + """ + Convert a string to a safe filename by removing/replacing invalid characters. + + Args: + filename: Original filename + + Returns: + Safe filename + """ + # Replace invalid characters + invalid_chars = '<>:"/\\|?*' + for char in invalid_chars: + filename = filename.replace(char, "_") + + # Remove leading/trailing spaces and dots + filename = filename.strip(". ") + + # Ensure filename is not empty + if not filename: + filename = "unnamed" + + return filename diff --git a/src/utils/logger.py b/src/utils/logger.py new file mode 100644 index 0000000..5faf6df --- /dev/null +++ b/src/utils/logger.py @@ -0,0 +1,75 @@ +""" +Logging configuration for the microscopy object detection application. +""" + +import logging +import sys +from pathlib import Path +from typing import Optional + + +def setup_logging( + log_file: str = "logs/app.log", + level: str = "INFO", + log_format: Optional[str] = None, +) -> logging.Logger: + """ + Setup application logging. + + Args: + log_file: Path to log file + level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) + log_format: Custom log format string + + Returns: + Configured logger instance + """ + # Create logs directory if it doesn't exist + log_path = Path(log_file) + log_path.parent.mkdir(parents=True, exist_ok=True) + + # Default format if none provided + if log_format is None: + log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + + # Convert level string to logging constant + numeric_level = getattr(logging, level.upper(), logging.INFO) + + # Configure root logger + root_logger = logging.getLogger() + root_logger.setLevel(numeric_level) + + # Remove existing handlers + root_logger.handlers.clear() + + # Console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(numeric_level) + console_formatter = logging.Formatter(log_format) + console_handler.setFormatter(console_formatter) + root_logger.addHandler(console_handler) + + # File handler + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(numeric_level) + file_formatter = logging.Formatter(log_format) + file_handler.setFormatter(file_formatter) + root_logger.addHandler(file_handler) + + # Log initial message + root_logger.info("Logging initialized") + + return root_logger + + +def get_logger(name: str) -> logging.Logger: + """ + Get a logger instance for a specific module. + + Args: + name: Logger name (typically __name__) + + Returns: + Logger instance + """ + return logging.getLogger(name) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29