Adding python files

This commit is contained in:
2025-12-05 09:50:50 +02:00
parent c6143cd11a
commit 6bd2b100ca
24 changed files with 3076 additions and 0 deletions

0
src/__init__.py Normal file
View File

0
src/database/__init__.py Normal file
View File

619
src/database/db_manager.py Normal file
View File

@@ -0,0 +1,619 @@
"""
Database manager for the microscopy object detection application.
Handles all database operations including CRUD operations, queries, and exports.
"""
import sqlite3
import json
from datetime import datetime
from typing import List, Dict, Optional, Tuple, Any
from pathlib import Path
import csv
import hashlib
class DatabaseManager:
"""Manages all database operations for the application."""
def __init__(self, db_path: str = "data/detections.db"):
"""
Initialize database manager.
Args:
db_path: Path to SQLite database file
"""
self.db_path = db_path
self._ensure_database_exists()
def _ensure_database_exists(self) -> None:
"""Create database and tables if they don't exist."""
# Create directory if it doesn't exist
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
# Read schema file and execute
schema_path = Path(__file__).parent / "schema.sql"
with open(schema_path, "r") as f:
schema_sql = f.read()
conn = self.get_connection()
try:
conn.executescript(schema_sql)
conn.commit()
finally:
conn.close()
def get_connection(self) -> sqlite3.Connection:
"""Get database connection with proper settings."""
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row # Enable column access by name
conn.execute("PRAGMA foreign_keys = ON") # Enable foreign keys
return conn
# ==================== Model Operations ====================
def add_model(
self,
model_name: str,
model_version: str,
model_path: str,
base_model: str = "yolov8s.pt",
training_params: Optional[Dict] = None,
metrics: Optional[Dict] = None,
) -> int:
"""
Add a new model to the database.
Args:
model_name: Name of the model
model_version: Version string
model_path: Path to model weights file
base_model: Base model used for training
training_params: Dictionary of training parameters
metrics: Dictionary of validation metrics
Returns:
ID of the inserted model
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO models (model_name, model_version, model_path, base_model, training_params, metrics)
VALUES (?, ?, ?, ?, ?, ?)
""",
(
model_name,
model_version,
model_path,
base_model,
json.dumps(training_params) if training_params else None,
json.dumps(metrics) if metrics else None,
),
)
conn.commit()
return cursor.lastrowid
finally:
conn.close()
def get_models(self, filters: Optional[Dict] = None) -> List[Dict]:
"""
Retrieve models from database.
Args:
filters: Optional filters (e.g., {'model_name': 'my_model'})
Returns:
List of model dictionaries
"""
conn = self.get_connection()
try:
query = "SELECT * FROM models"
params = []
if filters:
conditions = []
for key, value in filters.items():
conditions.append(f"{key} = ?")
params.append(value)
query += " WHERE " + " AND ".join(conditions)
query += " ORDER BY created_at DESC"
cursor = conn.cursor()
cursor.execute(query, params)
models = []
for row in cursor.fetchall():
model = dict(row)
# Parse JSON fields
if model["training_params"]:
model["training_params"] = json.loads(model["training_params"])
if model["metrics"]:
model["metrics"] = json.loads(model["metrics"])
models.append(model)
return models
finally:
conn.close()
def get_model_by_id(self, model_id: int) -> Optional[Dict]:
"""Get model by ID."""
models = self.get_models({"id": model_id})
return models[0] if models else None
def update_model(self, model_id: int, updates: Dict) -> bool:
"""Update model fields."""
conn = self.get_connection()
try:
# Build update query
set_clauses = []
params = []
for key, value in updates.items():
if key in ["training_params", "metrics"] and isinstance(value, dict):
value = json.dumps(value)
set_clauses.append(f"{key} = ?")
params.append(value)
params.append(model_id)
query = f"UPDATE models SET {', '.join(set_clauses)} WHERE id = ?"
cursor = conn.cursor()
cursor.execute(query, params)
conn.commit()
return cursor.rowcount > 0
finally:
conn.close()
# ==================== Image Operations ====================
def add_image(
self,
relative_path: str,
filename: str,
width: int,
height: int,
captured_at: Optional[datetime] = None,
checksum: Optional[str] = None,
) -> int:
"""
Add a new image to the database.
Args:
relative_path: Path relative to image repository
filename: Image filename
width: Image width in pixels
height: Image height in pixels
captured_at: When image was captured (if known)
checksum: MD5 checksum of image file
Returns:
ID of the inserted image
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO images (relative_path, filename, width, height, captured_at, checksum)
VALUES (?, ?, ?, ?, ?, ?)
""",
(relative_path, filename, width, height, captured_at, checksum),
)
conn.commit()
return cursor.lastrowid
except sqlite3.IntegrityError:
# Image already exists, return its ID
cursor.execute(
"SELECT id FROM images WHERE relative_path = ?", (relative_path,)
)
row = cursor.fetchone()
return row["id"] if row else None
finally:
conn.close()
def get_image_by_path(self, relative_path: str) -> Optional[Dict]:
"""Get image by relative path."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"SELECT * FROM images WHERE relative_path = ?", (relative_path,)
)
row = cursor.fetchone()
return dict(row) if row else None
finally:
conn.close()
def get_or_create_image(
self, relative_path: str, filename: str, width: int, height: int
) -> int:
"""Get existing image or create new one."""
existing = self.get_image_by_path(relative_path)
if existing:
return existing["id"]
return self.add_image(relative_path, filename, width, height)
# ==================== Detection Operations ====================
def add_detection(
self,
image_id: int,
model_id: int,
class_name: str,
bbox: Tuple[float, float, float, float], # (x_min, y_min, x_max, y_max)
confidence: float,
metadata: Optional[Dict] = None,
) -> int:
"""
Add a new detection to the database.
Args:
image_id: ID of the image
model_id: ID of the model used
class_name: Detected object class
bbox: Bounding box coordinates (normalized 0-1)
confidence: Detection confidence score
metadata: Additional metadata
Returns:
ID of the inserted detection
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
x_min, y_min, x_max, y_max = bbox
cursor.execute(
"""
INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
image_id,
model_id,
class_name,
x_min,
y_min,
x_max,
y_max,
confidence,
json.dumps(metadata) if metadata else None,
),
)
conn.commit()
return cursor.lastrowid
finally:
conn.close()
def add_detections_batch(self, detections: List[Dict]) -> int:
"""
Add multiple detections in a single transaction.
Args:
detections: List of detection dictionaries
Returns:
Number of detections inserted
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
for det in detections:
bbox = det["bbox"]
cursor.execute(
"""
INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
det["image_id"],
det["model_id"],
det["class_name"],
bbox[0],
bbox[1],
bbox[2],
bbox[3],
det["confidence"],
(
json.dumps(det.get("metadata"))
if det.get("metadata")
else None
),
),
)
conn.commit()
return len(detections)
finally:
conn.close()
def get_detections(
self,
filters: Optional[Dict] = None,
limit: Optional[int] = None,
offset: int = 0,
) -> List[Dict]:
"""
Retrieve detections from database.
Args:
filters: Optional filters for querying
limit: Maximum number of results
offset: Number of results to skip
Returns:
List of detection dictionaries with joined data
"""
conn = self.get_connection()
try:
query = """
SELECT
d.*,
i.relative_path as image_path,
i.filename as image_filename,
i.width as image_width,
i.height as image_height,
m.model_name,
m.model_version
FROM detections d
JOIN images i ON d.image_id = i.id
JOIN models m ON d.model_id = m.id
"""
params = []
if filters:
conditions = []
for key, value in filters.items():
if (
key.startswith("d.")
or key.startswith("i.")
or key.startswith("m.")
):
conditions.append(f"{key} = ?")
else:
conditions.append(f"d.{key} = ?")
params.append(value)
query += " WHERE " + " AND ".join(conditions)
query += " ORDER BY d.detected_at DESC"
if limit:
query += f" LIMIT {limit} OFFSET {offset}"
cursor = conn.cursor()
cursor.execute(query, params)
detections = []
for row in cursor.fetchall():
det = dict(row)
# Parse JSON metadata
if det.get("metadata"):
det["metadata"] = json.loads(det["metadata"])
detections.append(det)
return detections
finally:
conn.close()
def get_detections_for_image(
self, image_id: int, model_id: Optional[int] = None
) -> List[Dict]:
"""Get all detections for a specific image."""
filters = {"image_id": image_id}
if model_id:
filters["model_id"] = model_id
return self.get_detections(filters)
def delete_detections_for_model(self, model_id: int) -> int:
"""Delete all detections for a specific model."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM detections WHERE model_id = ?", (model_id,))
conn.commit()
return cursor.rowcount
finally:
conn.close()
# ==================== Statistics Operations ====================
def get_detection_statistics(
self, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None
) -> Dict:
"""
Get detection statistics for a date range.
Returns:
Dictionary with statistics (count by class, confidence distribution, etc.)
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
# Build date filter
date_filter = ""
params = []
if start_date:
date_filter += " AND detected_at >= ?"
params.append(start_date)
if end_date:
date_filter += " AND detected_at <= ?"
params.append(end_date)
# Total detections
cursor.execute(
f"SELECT COUNT(*) as count FROM detections WHERE 1=1{date_filter}",
params,
)
total_count = cursor.fetchone()["count"]
# Count by class
cursor.execute(
f"""
SELECT class_name, COUNT(*) as count
FROM detections
WHERE 1=1{date_filter}
GROUP BY class_name
ORDER BY count DESC
""",
params,
)
class_counts = {
row["class_name"]: row["count"] for row in cursor.fetchall()
}
# Average confidence
cursor.execute(
f"SELECT AVG(confidence) as avg_conf FROM detections WHERE 1=1{date_filter}",
params,
)
avg_confidence = cursor.fetchone()["avg_conf"] or 0
# Confidence distribution
cursor.execute(
f"""
SELECT
CASE
WHEN confidence < 0.3 THEN 'low'
WHEN confidence < 0.7 THEN 'medium'
ELSE 'high'
END as conf_level,
COUNT(*) as count
FROM detections
WHERE 1=1{date_filter}
GROUP BY conf_level
""",
params,
)
conf_dist = {row["conf_level"]: row["count"] for row in cursor.fetchall()}
return {
"total_detections": total_count,
"class_counts": class_counts,
"average_confidence": avg_confidence,
"confidence_distribution": conf_dist,
}
finally:
conn.close()
def get_class_distribution(self, model_id: Optional[int] = None) -> Dict[str, int]:
"""Get count of detections per class."""
conn = self.get_connection()
try:
cursor = conn.cursor()
query = "SELECT class_name, COUNT(*) as count FROM detections"
params = []
if model_id:
query += " WHERE model_id = ?"
params.append(model_id)
query += " GROUP BY class_name ORDER BY count DESC"
cursor.execute(query, params)
return {row["class_name"]: row["count"] for row in cursor.fetchall()}
finally:
conn.close()
# ==================== Export Operations ====================
def export_detections_to_csv(
self, output_path: str, filters: Optional[Dict] = None
) -> bool:
"""Export detections to CSV file."""
try:
detections = self.get_detections(filters)
with open(output_path, "w", newline="") as csvfile:
if not detections:
return True
fieldnames = [
"id",
"image_path",
"model_name",
"model_version",
"class_name",
"x_min",
"y_min",
"x_max",
"y_max",
"confidence",
"detected_at",
]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
for det in detections:
row = {k: det[k] for k in fieldnames if k in det}
writer.writerow(row)
return True
except Exception as e:
print(f"Error exporting to CSV: {e}")
return False
def export_detections_to_json(
self, output_path: str, filters: Optional[Dict] = None
) -> bool:
"""Export detections to JSON file."""
try:
detections = self.get_detections(filters)
# Convert datetime objects to strings
for det in detections:
if isinstance(det.get("detected_at"), datetime):
det["detected_at"] = det["detected_at"].isoformat()
with open(output_path, "w") as jsonfile:
json.dump(detections, jsonfile, indent=2)
return True
except Exception as e:
print(f"Error exporting to JSON: {e}")
return False
# ==================== Annotation Operations ====================
def add_annotation(
self,
image_id: int,
class_name: str,
bbox: Tuple[float, float, float, float],
annotator: str,
verified: bool = False,
) -> int:
"""Add manual annotation."""
conn = self.get_connection()
try:
cursor = conn.cursor()
x_min, y_min, x_max, y_max = bbox
cursor.execute(
"""
INSERT INTO annotations (image_id, class_name, x_min, y_min, x_max, y_max, annotator, verified)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(image_id, class_name, x_min, y_min, x_max, y_max, annotator, verified),
)
conn.commit()
return cursor.lastrowid
finally:
conn.close()
def get_annotations_for_image(self, image_id: int) -> List[Dict]:
"""Get all annotations for an image."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("SELECT * FROM annotations WHERE image_id = ?", (image_id,))
return [dict(row) for row in cursor.fetchall()]
finally:
conn.close()
@staticmethod
def calculate_checksum(file_path: str) -> str:
"""Calculate MD5 checksum of a file."""
md5_hash = hashlib.md5()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
md5_hash.update(chunk)
return md5_hash.hexdigest()

63
src/database/models.py Normal file
View File

@@ -0,0 +1,63 @@
"""
Data models for the microscopy object detection application.
These dataclasses represent the database entities.
"""
from dataclasses import dataclass
from datetime import datetime
from typing import Optional, Dict, Tuple
@dataclass
class Model:
"""Represents a trained model."""
id: Optional[int]
model_name: str
model_version: str
model_path: str
base_model: str
created_at: datetime
training_params: Optional[Dict]
metrics: Optional[Dict]
@dataclass
class Image:
"""Represents an image in the database."""
id: Optional[int]
relative_path: str
filename: str
width: int
height: int
captured_at: Optional[datetime]
added_at: datetime
checksum: Optional[str]
@dataclass
class Detection:
"""Represents a detection result."""
id: Optional[int]
image_id: int
model_id: int
class_name: str
bbox: Tuple[float, float, float, float] # (x_min, y_min, x_max, y_max)
confidence: float
detected_at: datetime
metadata: Optional[Dict]
@dataclass
class Annotation:
"""Represents a manual annotation."""
id: Optional[int]
image_id: int
class_name: str
bbox: Tuple[float, float, float, float] # (x_min, y_min, x_max, y_max)
annotator: str
created_at: datetime
verified: bool

70
src/database/schema.sql Normal file
View File

@@ -0,0 +1,70 @@
-- Microscopy Object Detection Application - Database Schema
-- SQLite Database Schema for storing models, images, detections, and annotations
-- Models table: stores trained model information
CREATE TABLE IF NOT EXISTS models (
id INTEGER PRIMARY KEY AUTOINCREMENT,
model_name TEXT NOT NULL,
model_version TEXT NOT NULL,
model_path TEXT NOT NULL,
base_model TEXT NOT NULL DEFAULT 'yolov8s.pt',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
training_params TEXT, -- JSON string of training parameters
metrics TEXT, -- JSON string of validation metrics
UNIQUE(model_name, model_version)
);
-- Images table: stores image metadata
CREATE TABLE IF NOT EXISTS images (
id INTEGER PRIMARY KEY AUTOINCREMENT,
relative_path TEXT NOT NULL UNIQUE,
filename TEXT NOT NULL,
width INTEGER,
height INTEGER,
captured_at TIMESTAMP,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
checksum TEXT
);
-- Detections table: stores detection results
CREATE TABLE IF NOT EXISTS detections (
id INTEGER PRIMARY KEY AUTOINCREMENT,
image_id INTEGER NOT NULL,
model_id INTEGER NOT NULL,
class_name TEXT NOT NULL,
x_min REAL NOT NULL CHECK(x_min >= 0 AND x_min <= 1),
y_min REAL NOT NULL CHECK(y_min >= 0 AND y_min <= 1),
x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1),
y_max REAL NOT NULL CHECK(y_max >= 0 AND y_max <= 1),
confidence REAL NOT NULL CHECK(confidence >= 0 AND confidence <= 1),
detected_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
metadata TEXT, -- JSON string for additional metadata
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
);
-- Annotations table: stores manual annotations (future feature)
CREATE TABLE IF NOT EXISTS annotations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
image_id INTEGER NOT NULL,
class_name TEXT NOT NULL,
x_min REAL NOT NULL CHECK(x_min >= 0 AND x_min <= 1),
y_min REAL NOT NULL CHECK(y_min >= 0 AND y_min <= 1),
x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1),
y_max REAL NOT NULL CHECK(y_max >= 0 AND y_max <= 1),
annotator TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
verified BOOLEAN DEFAULT 0,
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE
);
-- Create indexes for performance optimization
CREATE INDEX IF NOT EXISTS idx_detections_image_id ON detections(image_id);
CREATE INDEX IF NOT EXISTS idx_detections_model_id ON detections(model_id);
CREATE INDEX IF NOT EXISTS idx_detections_class_name ON detections(class_name);
CREATE INDEX IF NOT EXISTS idx_detections_detected_at ON detections(detected_at);
CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence);
CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path);
CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at);
CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id);
CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at);

0
src/gui/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,291 @@
"""
Configuration dialog for the microscopy object detection application.
"""
from PySide6.QtWidgets import (
QDialog,
QVBoxLayout,
QHBoxLayout,
QFormLayout,
QPushButton,
QLineEdit,
QSpinBox,
QDoubleSpinBox,
QFileDialog,
QTabWidget,
QWidget,
QLabel,
QGroupBox,
)
from PySide6.QtCore import Qt
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
logger = get_logger(__name__)
class ConfigDialog(QDialog):
"""Configuration dialog window."""
def __init__(self, config_manager: ConfigManager, parent=None):
super().__init__(parent)
self.config_manager = config_manager
self.setWindowTitle("Settings")
self.setMinimumWidth(500)
self.setMinimumHeight(400)
self._setup_ui()
self._load_settings()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
# Create tab widget for different setting categories
self.tab_widget = QTabWidget()
# General settings tab
general_tab = self._create_general_tab()
self.tab_widget.addTab(general_tab, "General")
# Training settings tab
training_tab = self._create_training_tab()
self.tab_widget.addTab(training_tab, "Training")
# Detection settings tab
detection_tab = self._create_detection_tab()
self.tab_widget.addTab(detection_tab, "Detection")
layout.addWidget(self.tab_widget)
# Buttons
button_layout = QHBoxLayout()
button_layout.addStretch()
self.save_button = QPushButton("Save")
self.save_button.clicked.connect(self.accept)
button_layout.addWidget(self.save_button)
self.cancel_button = QPushButton("Cancel")
self.cancel_button.clicked.connect(self.reject)
button_layout.addWidget(self.cancel_button)
layout.addLayout(button_layout)
self.setLayout(layout)
def _create_general_tab(self) -> QWidget:
"""Create general settings tab."""
widget = QWidget()
layout = QVBoxLayout()
# Image repository group
repo_group = QGroupBox("Image Repository")
repo_layout = QFormLayout()
# Repository path
path_layout = QHBoxLayout()
self.repo_path_edit = QLineEdit()
self.repo_path_edit.setPlaceholderText("Path to image repository")
path_layout.addWidget(self.repo_path_edit)
browse_button = QPushButton("Browse...")
browse_button.clicked.connect(self._browse_repository)
path_layout.addWidget(browse_button)
repo_layout.addRow("Base Path:", path_layout)
repo_group.setLayout(repo_layout)
layout.addWidget(repo_group)
# Database group
db_group = QGroupBox("Database")
db_layout = QFormLayout()
self.db_path_edit = QLineEdit()
self.db_path_edit.setPlaceholderText("Path to database file")
db_layout.addRow("Database Path:", self.db_path_edit)
db_group.setLayout(db_layout)
layout.addWidget(db_group)
# Models group
models_group = QGroupBox("Models")
models_layout = QFormLayout()
self.models_dir_edit = QLineEdit()
self.models_dir_edit.setPlaceholderText("Directory for saved models")
models_layout.addRow("Models Directory:", self.models_dir_edit)
self.base_model_edit = QLineEdit()
self.base_model_edit.setPlaceholderText("yolov8s.pt")
models_layout.addRow("Default Base Model:", self.base_model_edit)
models_group.setLayout(models_layout)
layout.addWidget(models_group)
layout.addStretch()
widget.setLayout(layout)
return widget
def _create_training_tab(self) -> QWidget:
"""Create training settings tab."""
widget = QWidget()
layout = QVBoxLayout()
form_layout = QFormLayout()
# Epochs
self.epochs_spin = QSpinBox()
self.epochs_spin.setRange(1, 1000)
self.epochs_spin.setValue(100)
form_layout.addRow("Default Epochs:", self.epochs_spin)
# Batch size
self.batch_size_spin = QSpinBox()
self.batch_size_spin.setRange(1, 128)
self.batch_size_spin.setValue(16)
form_layout.addRow("Default Batch Size:", self.batch_size_spin)
# Image size
self.imgsz_spin = QSpinBox()
self.imgsz_spin.setRange(320, 1280)
self.imgsz_spin.setSingleStep(32)
self.imgsz_spin.setValue(640)
form_layout.addRow("Default Image Size:", self.imgsz_spin)
# Patience
self.patience_spin = QSpinBox()
self.patience_spin.setRange(1, 200)
self.patience_spin.setValue(50)
form_layout.addRow("Default Patience:", self.patience_spin)
# Learning rate
self.lr_spin = QDoubleSpinBox()
self.lr_spin.setRange(0.0001, 0.1)
self.lr_spin.setSingleStep(0.001)
self.lr_spin.setDecimals(4)
self.lr_spin.setValue(0.01)
form_layout.addRow("Default Learning Rate:", self.lr_spin)
layout.addLayout(form_layout)
layout.addStretch()
widget.setLayout(layout)
return widget
def _create_detection_tab(self) -> QWidget:
"""Create detection settings tab."""
widget = QWidget()
layout = QVBoxLayout()
form_layout = QFormLayout()
# Confidence threshold
self.conf_spin = QDoubleSpinBox()
self.conf_spin.setRange(0.0, 1.0)
self.conf_spin.setSingleStep(0.05)
self.conf_spin.setDecimals(2)
self.conf_spin.setValue(0.25)
form_layout.addRow("Default Confidence:", self.conf_spin)
# IoU threshold
self.iou_spin = QDoubleSpinBox()
self.iou_spin.setRange(0.0, 1.0)
self.iou_spin.setSingleStep(0.05)
self.iou_spin.setDecimals(2)
self.iou_spin.setValue(0.45)
form_layout.addRow("Default IoU:", self.iou_spin)
# Max batch size
self.max_batch_spin = QSpinBox()
self.max_batch_spin.setRange(1, 1000)
self.max_batch_spin.setValue(100)
form_layout.addRow("Max Batch Size:", self.max_batch_spin)
layout.addLayout(form_layout)
layout.addStretch()
widget.setLayout(layout)
return widget
def _browse_repository(self):
"""Browse for image repository directory."""
directory = QFileDialog.getExistingDirectory(
self, "Select Image Repository", self.repo_path_edit.text()
)
if directory:
self.repo_path_edit.setText(directory)
def _load_settings(self):
"""Load current settings into dialog."""
# General settings
self.repo_path_edit.setText(
self.config_manager.get("image_repository.base_path", "")
)
self.db_path_edit.setText(
self.config_manager.get("database.path", "data/detections.db")
)
self.models_dir_edit.setText(
self.config_manager.get("models.models_directory", "data/models")
)
self.base_model_edit.setText(
self.config_manager.get("models.default_base_model", "yolov8s.pt")
)
# Training settings
self.epochs_spin.setValue(
self.config_manager.get("training.default_epochs", 100)
)
self.batch_size_spin.setValue(
self.config_manager.get("training.default_batch_size", 16)
)
self.imgsz_spin.setValue(self.config_manager.get("training.default_imgsz", 640))
self.patience_spin.setValue(
self.config_manager.get("training.default_patience", 50)
)
self.lr_spin.setValue(self.config_manager.get("training.default_lr0", 0.01))
# Detection settings
self.conf_spin.setValue(
self.config_manager.get("detection.default_confidence", 0.25)
)
self.iou_spin.setValue(self.config_manager.get("detection.default_iou", 0.45))
self.max_batch_spin.setValue(
self.config_manager.get("detection.max_batch_size", 100)
)
def accept(self):
"""Save settings and close dialog."""
logger.info("Saving configuration")
# Save general settings
self.config_manager.set(
"image_repository.base_path", self.repo_path_edit.text()
)
self.config_manager.set("database.path", self.db_path_edit.text())
self.config_manager.set("models.models_directory", self.models_dir_edit.text())
self.config_manager.set(
"models.default_base_model", self.base_model_edit.text()
)
# Save training settings
self.config_manager.set("training.default_epochs", self.epochs_spin.value())
self.config_manager.set(
"training.default_batch_size", self.batch_size_spin.value()
)
self.config_manager.set("training.default_imgsz", self.imgsz_spin.value())
self.config_manager.set("training.default_patience", self.patience_spin.value())
self.config_manager.set("training.default_lr0", self.lr_spin.value())
# Save detection settings
self.config_manager.set("detection.default_confidence", self.conf_spin.value())
self.config_manager.set("detection.default_iou", self.iou_spin.value())
self.config_manager.set("detection.max_batch_size", self.max_batch_spin.value())
# Save to file
self.config_manager.save_config()
super().accept()

282
src/gui/main_window.py Normal file
View File

@@ -0,0 +1,282 @@
"""
Main window for the microscopy object detection application.
"""
from PySide6.QtWidgets import (
QMainWindow,
QTabWidget,
QMenuBar,
QMenu,
QStatusBar,
QMessageBox,
QWidget,
QVBoxLayout,
QLabel,
)
from PySide6.QtCore import Qt, QTimer
from PySide6.QtGui import QAction, QKeySequence
from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
from src.gui.dialogs.config_dialog import ConfigDialog
from src.gui.tabs.detection_tab import DetectionTab
from src.gui.tabs.training_tab import TrainingTab
from src.gui.tabs.validation_tab import ValidationTab
from src.gui.tabs.results_tab import ResultsTab
from src.gui.tabs.annotation_tab import AnnotationTab
logger = get_logger(__name__)
class MainWindow(QMainWindow):
"""Main application window."""
def __init__(self):
super().__init__()
# Initialize managers
self.config_manager = ConfigManager()
db_path = self.config_manager.get_database_path()
self.db_manager = DatabaseManager(db_path)
logger.info("Main window initializing")
# Setup UI
self.setWindowTitle("Microscopy Object Detection")
self.setMinimumSize(1200, 800)
self._create_menu_bar()
self._create_tab_widget()
self._create_status_bar()
# Center window on screen
self._center_window()
logger.info("Main window initialized")
def _create_menu_bar(self):
"""Create application menu bar."""
menubar = self.menuBar()
# File menu
file_menu = menubar.addMenu("&File")
settings_action = QAction("&Settings", self)
settings_action.setShortcut(QKeySequence("Ctrl+,"))
settings_action.triggered.connect(self._show_settings)
file_menu.addAction(settings_action)
file_menu.addSeparator()
exit_action = QAction("E&xit", self)
exit_action.setShortcut(QKeySequence("Ctrl+Q"))
exit_action.triggered.connect(self.close)
file_menu.addAction(exit_action)
# View menu
view_menu = menubar.addMenu("&View")
refresh_action = QAction("&Refresh", self)
refresh_action.setShortcut(QKeySequence("F5"))
refresh_action.triggered.connect(self._refresh_current_tab)
view_menu.addAction(refresh_action)
# Tools menu
tools_menu = menubar.addMenu("&Tools")
db_stats_action = QAction("Database &Statistics", self)
db_stats_action.triggered.connect(self._show_database_stats)
tools_menu.addAction(db_stats_action)
# Help menu
help_menu = menubar.addMenu("&Help")
about_action = QAction("&About", self)
about_action.triggered.connect(self._show_about)
help_menu.addAction(about_action)
docs_action = QAction("&Documentation", self)
docs_action.triggered.connect(self._show_documentation)
help_menu.addAction(docs_action)
def _create_tab_widget(self):
"""Create main tab widget with all tabs."""
self.tab_widget = QTabWidget()
self.tab_widget.setTabPosition(QTabWidget.North)
# Create tabs
try:
self.detection_tab = DetectionTab(self.db_manager, self.config_manager)
self.training_tab = TrainingTab(self.db_manager, self.config_manager)
self.validation_tab = ValidationTab(self.db_manager, self.config_manager)
self.results_tab = ResultsTab(self.db_manager, self.config_manager)
self.annotation_tab = AnnotationTab(self.db_manager, self.config_manager)
# Add tabs to widget
self.tab_widget.addTab(self.detection_tab, "Detection")
self.tab_widget.addTab(self.training_tab, "Training")
self.tab_widget.addTab(self.validation_tab, "Validation")
self.tab_widget.addTab(self.results_tab, "Results")
self.tab_widget.addTab(self.annotation_tab, "Annotation (Future)")
# Connect tab change signal
self.tab_widget.currentChanged.connect(self._on_tab_changed)
except Exception as e:
logger.error(f"Error creating tabs: {e}")
# Create placeholder
placeholder = QWidget()
layout = QVBoxLayout()
layout.addWidget(QLabel(f"Error creating tabs: {e}"))
placeholder.setLayout(layout)
self.tab_widget.addTab(placeholder, "Error")
self.setCentralWidget(self.tab_widget)
def _create_status_bar(self):
"""Create status bar."""
self.status_bar = QStatusBar()
self.setStatusBar(self.status_bar)
# Add permanent widgets to status bar
self.status_label = QLabel("Ready")
self.status_bar.addWidget(self.status_label)
# Initial status message
self._update_status("Ready")
def _center_window(self):
"""Center window on screen."""
screen = self.screen().geometry()
size = self.geometry()
self.move(
(screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2
)
def _show_settings(self):
"""Show settings dialog."""
logger.info("Opening settings dialog")
dialog = ConfigDialog(self.config_manager, self)
if dialog.exec():
self._apply_settings()
self._update_status("Settings saved")
def _apply_settings(self):
"""Apply changed settings."""
logger.info("Applying settings changes")
# Reload configuration in all tabs if needed
try:
if hasattr(self, "detection_tab"):
self.detection_tab.refresh()
if hasattr(self, "training_tab"):
self.training_tab.refresh()
if hasattr(self, "results_tab"):
self.results_tab.refresh()
except Exception as e:
logger.error(f"Error applying settings: {e}")
def _refresh_current_tab(self):
"""Refresh the current tab."""
current_widget = self.tab_widget.currentWidget()
if hasattr(current_widget, "refresh"):
current_widget.refresh()
self._update_status("Tab refreshed")
def _on_tab_changed(self, index: int):
"""Handle tab change event."""
tab_name = self.tab_widget.tabText(index)
logger.debug(f"Switched to tab: {tab_name}")
self._update_status(f"Viewing: {tab_name}")
def _show_database_stats(self):
"""Show database statistics dialog."""
try:
stats = self.db_manager.get_detection_statistics()
message = f"""
<h3>Database Statistics</h3>
<p><b>Total Detections:</b> {stats.get('total_detections', 0)}</p>
<p><b>Average Confidence:</b> {stats.get('average_confidence', 0):.2%}</p>
<p><b>Classes:</b></p>
<ul>
"""
for class_name, count in stats.get("class_counts", {}).items():
message += f"<li>{class_name}: {count}</li>"
message += "</ul>"
QMessageBox.information(self, "Database Statistics", message)
except Exception as e:
logger.error(f"Error getting database stats: {e}")
QMessageBox.warning(
self, "Error", f"Failed to get database statistics:\n{str(e)}"
)
def _show_about(self):
"""Show about dialog."""
about_text = """
<h2>Microscopy Object Detection Application</h2>
<p><b>Version:</b> 1.0.0</p>
<p>A desktop application for detecting organelles and membrane branching
structures in microscopy images using YOLOv8.</p>
<p><b>Features:</b></p>
<ul>
<li>Object detection with YOLOv8</li>
<li>Model training and validation</li>
<li>Detection results storage</li>
<li>Interactive visualization</li>
<li>Export capabilities</li>
</ul>
<p><b>Technologies:</b></p>
<ul>
<li>Ultralytics YOLOv8</li>
<li>PySide6</li>
<li>pyqtgraph</li>
<li>SQLite</li>
</ul>
"""
QMessageBox.about(self, "About", about_text)
def _show_documentation(self):
"""Show documentation."""
QMessageBox.information(
self,
"Documentation",
"Please refer to README.md and ARCHITECTURE.md files in the project directory.",
)
def _update_status(self, message: str, timeout: int = 5000):
"""
Update status bar message.
Args:
message: Status message to display
timeout: Time in milliseconds to show message (0 for permanent)
"""
self.status_label.setText(message)
if timeout > 0:
QTimer.singleShot(timeout, lambda: self.status_label.setText("Ready"))
def closeEvent(self, event):
"""Handle window close event."""
reply = QMessageBox.question(
self,
"Confirm Exit",
"Are you sure you want to exit?",
QMessageBox.Yes | QMessageBox.No,
QMessageBox.No,
)
if reply == QMessageBox.Yes:
logger.info("Application closing")
event.accept()
else:
event.ignore()

0
src/gui/tabs/__init__.py Normal file
View File

View File

@@ -0,0 +1,48 @@
"""
Annotation tab for the microscopy object detection application.
Future feature for manual annotation.
"""
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
class AnnotationTab(QWidget):
"""Annotation tab placeholder (future feature)."""
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
super().__init__(parent)
self.db_manager = db_manager
self.config_manager = config_manager
self._setup_ui()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
group = QGroupBox("Annotation Tool (Future Feature)")
group_layout = QVBoxLayout()
label = QLabel(
"Annotation functionality will be implemented in future version.\n\n"
"Planned Features:\n"
"- Image browser\n"
"- Drawing tools for bounding boxes\n"
"- Class label assignment\n"
"- Export annotations to YOLO format\n"
"- Annotation verification"
)
group_layout.addWidget(label)
group.setLayout(group_layout)
layout.addWidget(group)
layout.addStretch()
self.setLayout(layout)
def refresh(self):
"""Refresh the tab."""
pass

View File

@@ -0,0 +1,344 @@
"""
Detection tab for the microscopy object detection application.
Handles single image and batch detection.
"""
from PySide6.QtWidgets import (
QWidget,
QVBoxLayout,
QHBoxLayout,
QPushButton,
QLabel,
QComboBox,
QSlider,
QFileDialog,
QMessageBox,
QProgressBar,
QTextEdit,
QGroupBox,
QFormLayout,
)
from PySide6.QtCore import Qt, QThread, Signal
from pathlib import Path
from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
from src.utils.file_utils import get_image_files
from src.model.inference import InferenceEngine
logger = get_logger(__name__)
class DetectionWorker(QThread):
"""Worker thread for running detection."""
progress = Signal(int, int, str) # current, total, message
finished = Signal(list) # results
error = Signal(str) # error message
def __init__(self, engine, image_paths, repo_root, conf):
super().__init__()
self.engine = engine
self.image_paths = image_paths
self.repo_root = repo_root
self.conf = conf
def run(self):
"""Run detection in background thread."""
try:
results = self.engine.detect_batch(
self.image_paths, self.repo_root, self.conf, self.progress.emit
)
self.finished.emit(results)
except Exception as e:
logger.error(f"Detection error: {e}")
self.error.emit(str(e))
class DetectionTab(QWidget):
"""Detection tab for single image and batch detection."""
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
super().__init__(parent)
self.db_manager = db_manager
self.config_manager = config_manager
self.inference_engine = None
self.current_model_id = None
self._setup_ui()
self._connect_signals()
self._load_models()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
# Model selection group
model_group = QGroupBox("Model Selection")
model_layout = QFormLayout()
self.model_combo = QComboBox()
self.model_combo.addItem("No models available", None)
model_layout.addRow("Model:", self.model_combo)
model_group.setLayout(model_layout)
layout.addWidget(model_group)
# Detection settings group
settings_group = QGroupBox("Detection Settings")
settings_layout = QFormLayout()
# Confidence threshold
conf_layout = QHBoxLayout()
self.conf_slider = QSlider(Qt.Horizontal)
self.conf_slider.setRange(0, 100)
self.conf_slider.setValue(25)
self.conf_slider.setTickPosition(QSlider.TicksBelow)
self.conf_slider.setTickInterval(10)
conf_layout.addWidget(self.conf_slider)
self.conf_label = QLabel("0.25")
conf_layout.addWidget(self.conf_label)
settings_layout.addRow("Confidence:", conf_layout)
settings_group.setLayout(settings_layout)
layout.addWidget(settings_group)
# Action buttons
button_layout = QHBoxLayout()
self.single_image_btn = QPushButton("Detect Single Image")
self.single_image_btn.clicked.connect(self._detect_single_image)
button_layout.addWidget(self.single_image_btn)
self.batch_btn = QPushButton("Detect Batch (Folder)")
self.batch_btn.clicked.connect(self._detect_batch)
button_layout.addWidget(self.batch_btn)
layout.addLayout(button_layout)
# Progress bar
self.progress_bar = QProgressBar()
self.progress_bar.setVisible(False)
layout.addWidget(self.progress_bar)
# Results display
results_group = QGroupBox("Detection Results")
results_layout = QVBoxLayout()
self.results_text = QTextEdit()
self.results_text.setReadOnly(True)
self.results_text.setMaximumHeight(200)
results_layout.addWidget(self.results_text)
results_group.setLayout(results_layout)
layout.addWidget(results_group)
layout.addStretch()
self.setLayout(layout)
def _connect_signals(self):
"""Connect signals and slots."""
self.conf_slider.valueChanged.connect(self._update_confidence_label)
self.model_combo.currentIndexChanged.connect(self._on_model_changed)
def _load_models(self):
"""Load available models from database."""
try:
models = self.db_manager.get_models()
self.model_combo.clear()
if not models:
self.model_combo.addItem("No models available", None)
self._set_buttons_enabled(False)
return
# Add base model option
base_model = self.config_manager.get(
"models.default_base_model", "yolov8s.pt"
)
self.model_combo.addItem(
f"Base Model ({base_model})", {"id": 0, "path": base_model}
)
# Add trained models
for model in models:
display_name = f"{model['model_name']} v{model['model_version']}"
self.model_combo.addItem(display_name, model)
self._set_buttons_enabled(True)
except Exception as e:
logger.error(f"Error loading models: {e}")
QMessageBox.warning(self, "Error", f"Failed to load models:\n{str(e)}")
def _on_model_changed(self, index: int):
"""Handle model selection change."""
model_data = self.model_combo.itemData(index)
if model_data and model_data["id"] != 0:
self.current_model_id = model_data["id"]
else:
self.current_model_id = None
def _update_confidence_label(self, value: int):
"""Update confidence label."""
conf = value / 100.0
self.conf_label.setText(f"{conf:.2f}")
def _detect_single_image(self):
"""Detect objects in a single image."""
# Get image file
repo_path = self.config_manager.get_image_repository_path()
start_dir = repo_path if repo_path else ""
file_path, _ = QFileDialog.getOpenFileName(
self,
"Select Image",
start_dir,
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
)
if not file_path:
return
# Run detection
self._run_detection([file_path])
def _detect_batch(self):
"""Detect objects in batch (folder)."""
# Get folder
repo_path = self.config_manager.get_image_repository_path()
start_dir = repo_path if repo_path else ""
folder_path = QFileDialog.getExistingDirectory(self, "Select Folder", start_dir)
if not folder_path:
return
# Get all image files
allowed_ext = self.config_manager.get_allowed_extensions()
image_files = get_image_files(folder_path, allowed_ext, recursive=False)
if not image_files:
QMessageBox.information(
self, "No Images", "No image files found in selected folder."
)
return
# Confirm batch processing
reply = QMessageBox.question(
self,
"Confirm Batch Detection",
f"Process {len(image_files)} images?",
QMessageBox.Yes | QMessageBox.No,
)
if reply == QMessageBox.Yes:
self._run_detection(image_files)
def _run_detection(self, image_paths: list):
"""Run detection on image list."""
try:
# Get selected model
model_data = self.model_combo.currentData()
if not model_data:
QMessageBox.warning(self, "No Model", "Please select a model first.")
return
model_path = model_data["path"]
model_id = model_data["id"]
# Ensure we have a valid model ID (create entry for base model if needed)
if model_id == 0:
# Create database entry for base model
base_model = self.config_manager.get(
"models.default_base_model", "yolov8s.pt"
)
model_id = self.db_manager.add_model(
model_name="Base Model",
model_version="pretrained",
model_path=base_model,
base_model=base_model,
)
# Create inference engine
self.inference_engine = InferenceEngine(
model_path, self.db_manager, model_id
)
# Get confidence threshold
conf = self.conf_slider.value() / 100.0
# Get repository root
repo_root = self.config_manager.get_image_repository_path()
if not repo_root:
repo_root = str(Path(image_paths[0]).parent)
# Show progress bar
self.progress_bar.setVisible(True)
self.progress_bar.setMaximum(len(image_paths))
self._set_buttons_enabled(False)
# Create and start worker thread
self.worker = DetectionWorker(
self.inference_engine, image_paths, repo_root, conf
)
self.worker.progress.connect(self._on_progress)
self.worker.finished.connect(self._on_detection_finished)
self.worker.error.connect(self._on_detection_error)
self.worker.start()
except Exception as e:
logger.error(f"Error starting detection: {e}")
QMessageBox.critical(self, "Error", f"Failed to start detection:\n{str(e)}")
self._set_buttons_enabled(True)
def _on_progress(self, current: int, total: int, message: str):
"""Handle progress update."""
self.progress_bar.setValue(current)
self.results_text.append(f"[{current}/{total}] {message}")
def _on_detection_finished(self, results: list):
"""Handle detection completion."""
self.progress_bar.setVisible(False)
self._set_buttons_enabled(True)
# Calculate statistics
total_detections = sum(r["count"] for r in results)
successful = sum(1 for r in results if r.get("success", False))
summary = f"\n=== Detection Complete ===\n"
summary += f"Processed: {len(results)} images\n"
summary += f"Successful: {successful}\n"
summary += f"Total detections: {total_detections}\n"
self.results_text.append(summary)
QMessageBox.information(
self,
"Detection Complete",
f"Processed {len(results)} images\n{total_detections} objects detected",
)
def _on_detection_error(self, error_msg: str):
"""Handle detection error."""
self.progress_bar.setVisible(False)
self._set_buttons_enabled(True)
self.results_text.append(f"\nERROR: {error_msg}")
QMessageBox.critical(self, "Detection Error", error_msg)
def _set_buttons_enabled(self, enabled: bool):
"""Enable/disable action buttons."""
self.single_image_btn.setEnabled(enabled)
self.batch_btn.setEnabled(enabled)
self.model_combo.setEnabled(enabled)
def refresh(self):
"""Refresh the tab."""
self._load_models()
self.results_text.clear()

View File

@@ -0,0 +1,46 @@
"""
Results tab for the microscopy object detection application.
"""
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
class ResultsTab(QWidget):
"""Results tab placeholder."""
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
super().__init__(parent)
self.db_manager = db_manager
self.config_manager = config_manager
self._setup_ui()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
group = QGroupBox("Results")
group_layout = QVBoxLayout()
label = QLabel(
"Results viewer will be implemented here.\n\n"
"Features:\n"
"- Detection history browser\n"
"- Advanced filtering\n"
"- Statistics dashboard\n"
"- Export functionality"
)
group_layout.addWidget(label)
group.setLayout(group_layout)
layout.addWidget(group)
layout.addStretch()
self.setLayout(layout)
def refresh(self):
"""Refresh the tab."""
pass

View File

@@ -0,0 +1,52 @@
"""
Training tab for the microscopy object detection application.
Handles model training with YOLO.
"""
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
logger = get_logger(__name__)
class TrainingTab(QWidget):
"""Training tab for model training."""
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
super().__init__(parent)
self.db_manager = db_manager
self.config_manager = config_manager
self._setup_ui()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
# Placeholder
group = QGroupBox("Training")
group_layout = QVBoxLayout()
label = QLabel(
"Training functionality will be implemented here.\n\n"
"Features:\n"
"- Dataset selection\n"
"- Training parameter configuration\n"
"- Real-time training progress\n"
"- Loss and metric visualization"
)
group_layout.addWidget(label)
group.setLayout(group_layout)
layout.addWidget(group)
layout.addStretch()
self.setLayout(layout)
def refresh(self):
"""Refresh the tab."""
pass

View File

@@ -0,0 +1,46 @@
"""
Validation tab for the microscopy object detection application.
"""
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
class ValidationTab(QWidget):
"""Validation tab placeholder."""
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
super().__init__(parent)
self.db_manager = db_manager
self.config_manager = config_manager
self._setup_ui()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
group = QGroupBox("Validation")
group_layout = QVBoxLayout()
label = QLabel(
"Validation functionality will be implemented here.\n\n"
"Features:\n"
"- Model validation\n"
"- Metrics visualization\n"
"- Confusion matrix\n"
"- Precision-Recall curves"
)
group_layout.addWidget(label)
group.setLayout(group_layout)
layout.addWidget(group)
layout.addStretch()
self.setLayout(layout)
def refresh(self):
"""Refresh the tab."""
pass

View File

0
src/model/__init__.py Normal file
View File

323
src/model/inference.py Normal file
View File

@@ -0,0 +1,323 @@
"""
Inference engine for the microscopy object detection application.
Handles detection inference and result storage.
"""
from typing import List, Dict, Optional, Callable
from pathlib import Path
from PIL import Image
import cv2
import numpy as np
from src.model.yolo_wrapper import YOLOWrapper
from src.database.db_manager import DatabaseManager
from src.utils.logger import get_logger
from src.utils.file_utils import get_relative_path
logger = get_logger(__name__)
class InferenceEngine:
"""Handles detection inference and result storage."""
def __init__(self, model_path: str, db_manager: DatabaseManager, model_id: int):
"""
Initialize inference engine.
Args:
model_path: Path to YOLO model weights
db_manager: Database manager instance
model_id: ID of the model in database
"""
self.yolo = YOLOWrapper(model_path)
self.yolo.load_model()
self.db_manager = db_manager
self.model_id = model_id
logger.info(f"InferenceEngine initialized with model_id {model_id}")
def detect_single(
self,
image_path: str,
relative_path: str,
conf: float = 0.25,
save_to_db: bool = True,
) -> Dict:
"""
Detect objects in a single image.
Args:
image_path: Absolute path to image file
relative_path: Relative path from repository root
conf: Confidence threshold
save_to_db: Whether to save results to database
Returns:
Dictionary with detection results
"""
try:
# Get image dimensions
img = Image.open(image_path)
width, height = img.size
img.close()
# Perform detection
detections = self.yolo.predict(image_path, conf=conf)
# Add/get image in database
image_id = self.db_manager.get_or_create_image(
relative_path=relative_path,
filename=Path(image_path).name,
width=width,
height=height,
)
# Save detections to database
if save_to_db and detections:
detection_records = []
for det in detections:
# Use normalized bbox from detection
bbox_normalized = det[
"bbox_normalized"
] # [x_min, y_min, x_max, y_max]
record = {
"image_id": image_id,
"model_id": self.model_id,
"class_name": det["class_name"],
"bbox": tuple(bbox_normalized),
"confidence": det["confidence"],
"metadata": {"class_id": det["class_id"]},
}
detection_records.append(record)
self.db_manager.add_detections_batch(detection_records)
logger.info(f"Saved {len(detection_records)} detections to database")
return {
"success": True,
"image_path": image_path,
"image_id": image_id,
"detections": detections,
"count": len(detections),
}
except Exception as e:
logger.error(f"Error detecting objects in {image_path}: {e}")
return {
"success": False,
"image_path": image_path,
"error": str(e),
"detections": [],
"count": 0,
}
def detect_batch(
self,
image_paths: List[str],
repository_root: str,
conf: float = 0.25,
progress_callback: Optional[Callable[[int, int, str], None]] = None,
) -> List[Dict]:
"""
Detect objects in multiple images.
Args:
image_paths: List of absolute image paths
repository_root: Root directory for relative paths
conf: Confidence threshold
progress_callback: Optional callback(current, total, message)
Returns:
List of detection result dictionaries
"""
results = []
total = len(image_paths)
logger.info(f"Starting batch detection on {total} images")
for i, image_path in enumerate(image_paths, 1):
# Calculate relative path
rel_path = get_relative_path(image_path, repository_root)
# Perform detection
result = self.detect_single(image_path, rel_path, conf)
results.append(result)
# Update progress
if progress_callback:
progress_callback(i, total, f"Processed {rel_path}")
if i % 10 == 0:
logger.info(f"Processed {i}/{total} images")
logger.info(f"Batch detection complete: {total} images processed")
return results
def detect_with_visualization(
self,
image_path: str,
conf: float = 0.25,
bbox_thickness: int = 2,
bbox_colors: Optional[Dict[str, str]] = None,
) -> tuple:
"""
Detect objects and return annotated image.
Args:
image_path: Path to image
conf: Confidence threshold
bbox_thickness: Thickness of bounding boxes
bbox_colors: Dictionary mapping class names to hex colors
Returns:
Tuple of (detections, annotated_image_array)
"""
try:
detections = self.yolo.predict(image_path, conf=conf)
# Load image
img = cv2.imread(image_path)
if img is None:
raise ValueError(f"Failed to load image: {image_path}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
height, width = img.shape[:2]
# Default colors if not provided
if bbox_colors is None:
bbox_colors = {}
default_color = self._hex_to_bgr(bbox_colors.get("default", "#00FF00"))
# Draw bounding boxes
for det in detections:
# Get absolute coordinates
bbox_abs = det["bbox_absolute"]
x1, y1, x2, y2 = [int(v) for v in bbox_abs]
# Get color for this class
class_name = det["class_name"]
color_hex = bbox_colors.get(
class_name, bbox_colors.get("default", "#00FF00")
)
color = self._hex_to_bgr(color_hex)
# Draw box
cv2.rectangle(img, (x1, y1), (x2, y2), color, bbox_thickness)
# Prepare label
label = f"{class_name} {det['confidence']:.2f}"
# Draw label background
(label_w, label_h), baseline = cv2.getTextSize(
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
)
cv2.rectangle(
img,
(x1, y1 - label_h - baseline - 5),
(x1 + label_w, y1),
color,
-1,
)
# Draw label text
cv2.putText(
img,
label,
(x1, y1 - baseline - 5),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
(255, 255, 255),
1,
)
return detections, img
except Exception as e:
logger.error(f"Error creating visualization: {e}")
# Return empty detections and original image if possible
try:
img = cv2.imread(image_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return [], img
except:
return [], np.zeros((480, 640, 3), dtype=np.uint8)
def get_detection_summary(self, detections: List[Dict]) -> Dict[str, any]:
"""
Generate summary statistics for detections.
Args:
detections: List of detection dictionaries
Returns:
Dictionary with summary statistics
"""
if not detections:
return {
"total_count": 0,
"class_counts": {},
"avg_confidence": 0.0,
"confidence_range": (0.0, 0.0),
}
# Count by class
class_counts = {}
confidences = []
for det in detections:
class_name = det["class_name"]
class_counts[class_name] = class_counts.get(class_name, 0) + 1
confidences.append(det["confidence"])
return {
"total_count": len(detections),
"class_counts": class_counts,
"avg_confidence": sum(confidences) / len(confidences),
"confidence_range": (min(confidences), max(confidences)),
}
@staticmethod
def _hex_to_bgr(hex_color: str) -> tuple:
"""
Convert hex color to BGR tuple.
Args:
hex_color: Hex color string (e.g., '#FF0000')
Returns:
BGR tuple (B, G, R)
"""
hex_color = hex_color.lstrip("#")
if len(hex_color) != 6:
return (0, 255, 0) # Default green
try:
r = int(hex_color[0:2], 16)
g = int(hex_color[2:4], 16)
b = int(hex_color[4:6], 16)
return (b, g, r) # OpenCV uses BGR
except ValueError:
return (0, 255, 0) # Default green
def change_model(self, model_path: str, model_id: int) -> bool:
"""
Change the current model.
Args:
model_path: Path to new model weights
model_id: ID of new model in database
Returns:
True if successful, False otherwise
"""
try:
self.yolo = YOLOWrapper(model_path)
if self.yolo.load_model():
self.model_id = model_id
logger.info(f"Model changed to {model_path}")
return True
return False
except Exception as e:
logger.error(f"Error changing model: {e}")
return False

364
src/model/yolo_wrapper.py Normal file
View File

@@ -0,0 +1,364 @@
"""
YOLO model wrapper for the microscopy object detection application.
Provides a clean interface to YOLOv8 for training, validation, and inference.
"""
from ultralytics import YOLO
from pathlib import Path
from typing import Optional, List, Dict, Callable, Any
import torch
from src.utils.logger import get_logger
logger = get_logger(__name__)
class YOLOWrapper:
"""Wrapper for YOLOv8 model operations."""
def __init__(self, model_path: str = "yolov8s.pt"):
"""
Initialize YOLO model.
Args:
model_path: Path to model weights (.pt file)
"""
self.model_path = model_path
self.model = None
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"YOLOWrapper initialized with device: {self.device}")
def load_model(self) -> bool:
"""
Load YOLO model from path.
Returns:
True if loaded successfully, False otherwise
"""
try:
logger.info(f"Loading YOLO model from {self.model_path}")
self.model = YOLO(self.model_path)
self.model.to(self.device)
logger.info("Model loaded successfully")
return True
except Exception as e:
logger.error(f"Error loading model: {e}")
return False
def train(
self,
data_yaml: str,
epochs: int = 100,
imgsz: int = 640,
batch: int = 16,
patience: int = 50,
save_dir: str = "data/models",
name: str = "custom_model",
resume: bool = False,
**kwargs,
) -> Dict[str, Any]:
"""
Train the YOLO model.
Args:
data_yaml: Path to data.yaml configuration file
epochs: Number of training epochs
imgsz: Input image size
batch: Batch size
patience: Early stopping patience
save_dir: Directory to save trained model
name: Name for the training run
resume: Resume training from last checkpoint
**kwargs: Additional training arguments
Returns:
Dictionary with training results
"""
if self.model is None:
self.load_model()
try:
logger.info(f"Starting training: {name}")
logger.info(
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
)
# Train the model
results = self.model.train(
data=data_yaml,
epochs=epochs,
imgsz=imgsz,
batch=batch,
patience=patience,
project=save_dir,
name=name,
device=self.device,
resume=resume,
**kwargs,
)
logger.info("Training completed successfully")
return self._format_training_results(results)
except Exception as e:
logger.error(f"Error during training: {e}")
raise
def validate(self, data_yaml: str, split: str = "val", **kwargs) -> Dict[str, Any]:
"""
Validate the model.
Args:
data_yaml: Path to data.yaml configuration file
split: Dataset split to validate on ('val' or 'test')
**kwargs: Additional validation arguments
Returns:
Dictionary with validation metrics
"""
if self.model is None:
self.load_model()
try:
logger.info(f"Starting validation on {split} split")
results = self.model.val(
data=data_yaml, split=split, device=self.device, **kwargs
)
logger.info("Validation completed successfully")
return self._format_validation_results(results)
except Exception as e:
logger.error(f"Error during validation: {e}")
raise
def predict(
self,
source: str,
conf: float = 0.25,
iou: float = 0.45,
save: bool = False,
save_txt: bool = False,
save_conf: bool = False,
**kwargs,
) -> List[Dict]:
"""
Perform inference on image(s).
Args:
source: Path to image or directory
conf: Confidence threshold
iou: IoU threshold for NMS
save: Whether to save annotated images
save_txt: Whether to save labels to .txt files
save_conf: Whether to save confidence in labels
**kwargs: Additional prediction arguments
Returns:
List of detection dictionaries
"""
if self.model is None:
self.load_model()
try:
logger.info(f"Running inference on {source}")
results = self.model.predict(
source=source,
conf=conf,
iou=iou,
save=save,
save_txt=save_txt,
save_conf=save_conf,
device=self.device,
**kwargs,
)
detections = self._format_prediction_results(results)
logger.info(f"Inference complete: {len(detections)} detections")
return detections
except Exception as e:
logger.error(f"Error during inference: {e}")
raise
def export(
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
) -> str:
"""
Export model to different format.
Args:
format: Export format (onnx, torchscript, tflite, etc.)
output_path: Path for exported model
**kwargs: Additional export arguments
Returns:
Path to exported model
"""
if self.model is None:
self.load_model()
try:
logger.info(f"Exporting model to {format} format")
export_path = self.model.export(format=format, **kwargs)
logger.info(f"Model exported to {export_path}")
return str(export_path)
except Exception as e:
logger.error(f"Error exporting model: {e}")
raise
def _format_training_results(self, results) -> Dict[str, Any]:
"""Format training results into dictionary."""
try:
# Get the results dict
results_dict = (
results.results_dict if hasattr(results, "results_dict") else {}
)
formatted = {
"success": True,
"final_epoch": getattr(results, "epoch", 0),
"metrics": {
"mAP50": float(results_dict.get("metrics/mAP50(B)", 0)),
"mAP50-95": float(results_dict.get("metrics/mAP50-95(B)", 0)),
"precision": float(results_dict.get("metrics/precision(B)", 0)),
"recall": float(results_dict.get("metrics/recall(B)", 0)),
},
"best_model_path": str(Path(results.save_dir) / "weights" / "best.pt"),
"last_model_path": str(Path(results.save_dir) / "weights" / "last.pt"),
"save_dir": str(results.save_dir),
}
return formatted
except Exception as e:
logger.error(f"Error formatting training results: {e}")
return {"success": False, "error": str(e)}
def _format_validation_results(self, results) -> Dict[str, Any]:
"""Format validation results into dictionary."""
try:
box_metrics = results.box
formatted = {
"success": True,
"mAP50": float(box_metrics.map50),
"mAP50-95": float(box_metrics.map),
"precision": float(box_metrics.mp),
"recall": float(box_metrics.mr),
"fitness": (
float(results.fitness) if hasattr(results, "fitness") else 0.0
),
}
# Add per-class metrics if available
if hasattr(box_metrics, "ap") and hasattr(results, "names"):
class_metrics = {}
for idx, name in results.names.items():
if idx < len(box_metrics.ap):
class_metrics[name] = {
"ap": float(box_metrics.ap[idx]),
"ap50": (
float(box_metrics.ap50[idx])
if hasattr(box_metrics, "ap50")
else 0.0
),
}
formatted["class_metrics"] = class_metrics
return formatted
except Exception as e:
logger.error(f"Error formatting validation results: {e}")
return {"success": False, "error": str(e)}
def _format_prediction_results(self, results) -> List[Dict]:
"""Format prediction results into list of dictionaries."""
detections = []
try:
for result in results:
boxes = result.boxes
image_path = str(result.path)
orig_shape = result.orig_shape # (height, width)
for i in range(len(boxes)):
# Get normalized coordinates
xyxyn = boxes.xyxyn[i].cpu().numpy() # Normalized [x1, y1, x2, y2]
detection = {
"image_path": image_path,
"class_id": int(boxes.cls[i]),
"class_name": result.names[int(boxes.cls[i])],
"confidence": float(boxes.conf[i]),
"bbox_normalized": [
float(v) for v in xyxyn
], # [x_min, y_min, x_max, y_max]
"bbox_absolute": [
float(v) for v in boxes.xyxy[i].cpu().numpy()
], # Absolute pixels
}
detections.append(detection)
return detections
except Exception as e:
logger.error(f"Error formatting prediction results: {e}")
return []
@staticmethod
def convert_bbox_format(
bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
) -> List[float]:
"""
Convert bounding box between formats.
Formats:
- xywh: [x_center, y_center, width, height]
- xyxy: [x_min, y_min, x_max, y_max]
Args:
bbox: Bounding box coordinates
format_from: Source format
format_to: Target format
Returns:
Converted bounding box
"""
if format_from == "xywh" and format_to == "xyxy":
x, y, w, h = bbox
return [x - w / 2, y - h / 2, x + w / 2, y + h / 2]
elif format_from == "xyxy" and format_to == "xywh":
x1, y1, x2, y2 = bbox
return [(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1]
else:
return bbox
def get_model_info(self) -> Dict[str, Any]:
"""
Get information about the loaded model.
Returns:
Dictionary with model information
"""
if self.model is None:
return {"error": "Model not loaded"}
try:
info = {
"model_path": self.model_path,
"device": self.device,
"task": getattr(self.model, "task", "unknown"),
}
# Try to get class names
if hasattr(self.model, "names"):
info["classes"] = self.model.names
info["num_classes"] = len(self.model.names)
return info
except Exception as e:
logger.error(f"Error getting model info: {e}")
return {"error": str(e)}

0
src/utils/__init__.py Normal file
View File

218
src/utils/config_manager.py Normal file
View File

@@ -0,0 +1,218 @@
"""
Configuration manager for the microscopy object detection application.
Handles loading, saving, and accessing application configuration.
"""
import yaml
from pathlib import Path
from typing import Any, Dict, Optional
from src.utils.logger import get_logger
logger = get_logger(__name__)
class ConfigManager:
"""Manages application configuration."""
def __init__(self, config_path: str = "config/app_config.yaml"):
"""
Initialize configuration manager.
Args:
config_path: Path to configuration file
"""
self.config_path = Path(config_path)
self.config: Dict[str, Any] = {}
self._load_config()
def _load_config(self) -> None:
"""Load configuration from YAML file."""
try:
if self.config_path.exists():
with open(self.config_path, "r") as f:
self.config = yaml.safe_load(f) or {}
logger.info(f"Configuration loaded from {self.config_path}")
else:
logger.warning(f"Configuration file not found: {self.config_path}")
self._create_default_config()
except Exception as e:
logger.error(f"Error loading configuration: {e}")
self._create_default_config()
def _create_default_config(self) -> None:
"""Create default configuration."""
self.config = {
"database": {"path": "data/detections.db"},
"image_repository": {
"base_path": "",
"allowed_extensions": [
".jpg",
".jpeg",
".png",
".tif",
".tiff",
".bmp",
],
},
"models": {
"default_base_model": "yolov8s.pt",
"models_directory": "data/models",
},
"training": {
"default_epochs": 100,
"default_batch_size": 16,
"default_imgsz": 640,
"default_patience": 50,
"default_lr0": 0.01,
},
"detection": {
"default_confidence": 0.25,
"default_iou": 0.45,
"max_batch_size": 100,
},
"visualization": {
"bbox_colors": {
"organelle": "#FF6B6B",
"membrane_branch": "#4ECDC4",
"default": "#00FF00",
},
"bbox_thickness": 2,
"font_size": 12,
},
"export": {"formats": ["csv", "json", "excel"], "default_format": "csv"},
"logging": {
"level": "INFO",
"file": "logs/app.log",
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
},
}
self.save_config()
def save_config(self) -> bool:
"""
Save current configuration to file.
Returns:
True if successful, False otherwise
"""
try:
# Create directory if it doesn't exist
self.config_path.parent.mkdir(parents=True, exist_ok=True)
with open(self.config_path, "w") as f:
yaml.dump(self.config, f, default_flow_style=False, sort_keys=False)
logger.info(f"Configuration saved to {self.config_path}")
return True
except Exception as e:
logger.error(f"Error saving configuration: {e}")
return False
def get(self, key: str, default: Any = None) -> Any:
"""
Get configuration value by key.
Args:
key: Configuration key (can use dot notation, e.g., 'database.path')
default: Default value if key not found
Returns:
Configuration value or default
"""
keys = key.split(".")
value = self.config
for k in keys:
if isinstance(value, dict) and k in value:
value = value[k]
else:
return default
return value
def set(self, key: str, value: Any) -> None:
"""
Set configuration value by key.
Args:
key: Configuration key (can use dot notation)
value: Value to set
"""
keys = key.split(".")
config = self.config
# Navigate to the nested dictionary
for k in keys[:-1]:
if k not in config:
config[k] = {}
config = config[k]
# Set the value
config[keys[-1]] = value
logger.debug(f"Configuration updated: {key} = {value}")
def get_section(self, section: str) -> Dict[str, Any]:
"""
Get entire configuration section.
Args:
section: Section name (e.g., 'database', 'training')
Returns:
Dictionary with section configuration
"""
return self.config.get(section, {})
def update_section(self, section: str, values: Dict[str, Any]) -> None:
"""
Update entire configuration section.
Args:
section: Section name
values: Dictionary with new values
"""
if section not in self.config:
self.config[section] = {}
self.config[section].update(values)
logger.debug(f"Configuration section updated: {section}")
def reload(self) -> None:
"""Reload configuration from file."""
self._load_config()
def get_database_path(self) -> str:
"""Get database path."""
return self.get("database.path", "data/detections.db")
def get_image_repository_path(self) -> str:
"""Get image repository base path."""
return self.get("image_repository.base_path", "")
def set_image_repository_path(self, path: str) -> None:
"""Set image repository base path."""
self.set("image_repository.base_path", path)
self.save_config()
def get_models_directory(self) -> str:
"""Get models directory path."""
return self.get("models.models_directory", "data/models")
def get_default_training_params(self) -> Dict[str, Any]:
"""Get default training parameters."""
return self.get_section("training")
def get_default_detection_params(self) -> Dict[str, Any]:
"""Get default detection parameters."""
return self.get_section("detection")
def get_bbox_colors(self) -> Dict[str, str]:
"""Get bounding box colors for different classes."""
return self.get("visualization.bbox_colors", {})
def get_allowed_extensions(self) -> list:
"""Get list of allowed image file extensions."""
return self.get(
"image_repository.allowed_extensions", [".jpg", ".jpeg", ".png"]
)

235
src/utils/file_utils.py Normal file
View File

@@ -0,0 +1,235 @@
"""
File utility functions for the microscopy object detection application.
"""
import os
from pathlib import Path
from typing import List, Optional
from src.utils.logger import get_logger
logger = get_logger(__name__)
def get_image_files(
directory: str,
allowed_extensions: Optional[List[str]] = None,
recursive: bool = False,
) -> List[str]:
"""
Get all image files in a directory.
Args:
directory: Directory path to search
allowed_extensions: List of allowed file extensions (e.g., ['.jpg', '.png'])
recursive: Whether to search recursively
Returns:
List of absolute paths to image files
"""
if allowed_extensions is None:
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
# Normalize extensions to lowercase
allowed_extensions = [ext.lower() for ext in allowed_extensions]
image_files = []
directory_path = Path(directory)
if not directory_path.exists():
logger.error(f"Directory does not exist: {directory}")
return image_files
try:
if recursive:
# Recursive search
for ext in allowed_extensions:
image_files.extend(directory_path.rglob(f"*{ext}"))
# Also search uppercase extensions
image_files.extend(directory_path.rglob(f"*{ext.upper()}"))
else:
# Top-level search only
for ext in allowed_extensions:
image_files.extend(directory_path.glob(f"*{ext}"))
# Also search uppercase extensions
image_files.extend(directory_path.glob(f"*{ext.upper()}"))
# Convert to absolute paths and sort
image_files = sorted([str(f.absolute()) for f in image_files])
logger.info(f"Found {len(image_files)} image files in {directory}")
except Exception as e:
logger.error(f"Error searching for images: {e}")
return image_files
def ensure_directory(directory: str) -> bool:
"""
Ensure a directory exists, create if it doesn't.
Args:
directory: Directory path
Returns:
True if directory exists or was created successfully
"""
try:
Path(directory).mkdir(parents=True, exist_ok=True)
return True
except Exception as e:
logger.error(f"Error creating directory {directory}: {e}")
return False
def get_relative_path(file_path: str, base_path: str) -> str:
"""
Get relative path from base path.
Args:
file_path: Absolute file path
base_path: Base directory path
Returns:
Relative path string
"""
try:
return str(Path(file_path).relative_to(base_path))
except ValueError:
# If file_path is not relative to base_path, return the filename
return Path(file_path).name
def validate_file_path(file_path: str, must_exist: bool = True) -> bool:
"""
Validate a file path.
Args:
file_path: Path to validate
must_exist: Whether the file must exist
Returns:
True if valid, False otherwise
"""
path = Path(file_path)
if must_exist and not path.exists():
logger.error(f"File does not exist: {file_path}")
return False
if must_exist and not path.is_file():
logger.error(f"Path is not a file: {file_path}")
return False
return True
def get_file_size(file_path: str) -> int:
"""
Get file size in bytes.
Args:
file_path: Path to file
Returns:
File size in bytes, or 0 if error
"""
try:
return Path(file_path).stat().st_size
except Exception as e:
logger.error(f"Error getting file size for {file_path}: {e}")
return 0
def format_file_size(size_bytes: int) -> str:
"""
Format file size in human-readable format.
Args:
size_bytes: Size in bytes
Returns:
Formatted string (e.g., "1.5 MB")
"""
for unit in ["B", "KB", "MB", "GB"]:
if size_bytes < 1024.0:
return f"{size_bytes:.1f} {unit}"
size_bytes /= 1024.0
return f"{size_bytes:.1f} TB"
def create_unique_filename(directory: str, base_name: str, extension: str) -> str:
"""
Create a unique filename by adding a number suffix if file exists.
Args:
directory: Directory path
base_name: Base filename without extension
extension: File extension (with or without dot)
Returns:
Unique filename
"""
if not extension.startswith("."):
extension = "." + extension
directory_path = Path(directory)
filename = f"{base_name}{extension}"
file_path = directory_path / filename
if not file_path.exists():
return filename
# Add number suffix
counter = 1
while True:
filename = f"{base_name}_{counter}{extension}"
file_path = directory_path / filename
if not file_path.exists():
return filename
counter += 1
def is_image_file(
file_path: str, allowed_extensions: Optional[List[str]] = None
) -> bool:
"""
Check if a file is an image based on extension.
Args:
file_path: Path to file
allowed_extensions: List of allowed extensions
Returns:
True if file is an image
"""
if allowed_extensions is None:
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
extension = Path(file_path).suffix.lower()
return extension in [ext.lower() for ext in allowed_extensions]
def safe_filename(filename: str) -> str:
"""
Convert a string to a safe filename by removing/replacing invalid characters.
Args:
filename: Original filename
Returns:
Safe filename
"""
# Replace invalid characters
invalid_chars = '<>:"/\\|?*'
for char in invalid_chars:
filename = filename.replace(char, "_")
# Remove leading/trailing spaces and dots
filename = filename.strip(". ")
# Ensure filename is not empty
if not filename:
filename = "unnamed"
return filename

75
src/utils/logger.py Normal file
View File

@@ -0,0 +1,75 @@
"""
Logging configuration for the microscopy object detection application.
"""
import logging
import sys
from pathlib import Path
from typing import Optional
def setup_logging(
log_file: str = "logs/app.log",
level: str = "INFO",
log_format: Optional[str] = None,
) -> logging.Logger:
"""
Setup application logging.
Args:
log_file: Path to log file
level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
log_format: Custom log format string
Returns:
Configured logger instance
"""
# Create logs directory if it doesn't exist
log_path = Path(log_file)
log_path.parent.mkdir(parents=True, exist_ok=True)
# Default format if none provided
if log_format is None:
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
# Convert level string to logging constant
numeric_level = getattr(logging, level.upper(), logging.INFO)
# Configure root logger
root_logger = logging.getLogger()
root_logger.setLevel(numeric_level)
# Remove existing handlers
root_logger.handlers.clear()
# Console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(numeric_level)
console_formatter = logging.Formatter(log_format)
console_handler.setFormatter(console_formatter)
root_logger.addHandler(console_handler)
# File handler
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(numeric_level)
file_formatter = logging.Formatter(log_format)
file_handler.setFormatter(file_formatter)
root_logger.addHandler(file_handler)
# Log initial message
root_logger.info("Logging initialized")
return root_logger
def get_logger(name: str) -> logging.Logger:
"""
Get a logger instance for a specific module.
Args:
name: Logger name (typically __name__)
Returns:
Logger instance
"""
return logging.getLogger(name)

0
tests/__init__.py Normal file
View File