Adding python files
This commit is contained in:
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
0
src/database/__init__.py
Normal file
0
src/database/__init__.py
Normal file
619
src/database/db_manager.py
Normal file
619
src/database/db_manager.py
Normal file
@@ -0,0 +1,619 @@
|
||||
"""
|
||||
Database manager for the microscopy object detection application.
|
||||
Handles all database operations including CRUD operations, queries, and exports.
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Optional, Tuple, Any
|
||||
from pathlib import Path
|
||||
import csv
|
||||
import hashlib
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""Manages all database operations for the application."""
|
||||
|
||||
def __init__(self, db_path: str = "data/detections.db"):
|
||||
"""
|
||||
Initialize database manager.
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database file
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self._ensure_database_exists()
|
||||
|
||||
def _ensure_database_exists(self) -> None:
|
||||
"""Create database and tables if they don't exist."""
|
||||
# Create directory if it doesn't exist
|
||||
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Read schema file and execute
|
||||
schema_path = Path(__file__).parent / "schema.sql"
|
||||
with open(schema_path, "r") as f:
|
||||
schema_sql = f.read()
|
||||
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
conn.executescript(schema_sql)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_connection(self) -> sqlite3.Connection:
|
||||
"""Get database connection with proper settings."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row # Enable column access by name
|
||||
conn.execute("PRAGMA foreign_keys = ON") # Enable foreign keys
|
||||
return conn
|
||||
|
||||
# ==================== Model Operations ====================
|
||||
|
||||
def add_model(
|
||||
self,
|
||||
model_name: str,
|
||||
model_version: str,
|
||||
model_path: str,
|
||||
base_model: str = "yolov8s.pt",
|
||||
training_params: Optional[Dict] = None,
|
||||
metrics: Optional[Dict] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Add a new model to the database.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model
|
||||
model_version: Version string
|
||||
model_path: Path to model weights file
|
||||
base_model: Base model used for training
|
||||
training_params: Dictionary of training parameters
|
||||
metrics: Dictionary of validation metrics
|
||||
|
||||
Returns:
|
||||
ID of the inserted model
|
||||
"""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO models (model_name, model_version, model_path, base_model, training_params, metrics)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
model_name,
|
||||
model_version,
|
||||
model_path,
|
||||
base_model,
|
||||
json.dumps(training_params) if training_params else None,
|
||||
json.dumps(metrics) if metrics else None,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.lastrowid
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_models(self, filters: Optional[Dict] = None) -> List[Dict]:
|
||||
"""
|
||||
Retrieve models from database.
|
||||
|
||||
Args:
|
||||
filters: Optional filters (e.g., {'model_name': 'my_model'})
|
||||
|
||||
Returns:
|
||||
List of model dictionaries
|
||||
"""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
query = "SELECT * FROM models"
|
||||
params = []
|
||||
|
||||
if filters:
|
||||
conditions = []
|
||||
for key, value in filters.items():
|
||||
conditions.append(f"{key} = ?")
|
||||
params.append(value)
|
||||
query += " WHERE " + " AND ".join(conditions)
|
||||
|
||||
query += " ORDER BY created_at DESC"
|
||||
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(query, params)
|
||||
|
||||
models = []
|
||||
for row in cursor.fetchall():
|
||||
model = dict(row)
|
||||
# Parse JSON fields
|
||||
if model["training_params"]:
|
||||
model["training_params"] = json.loads(model["training_params"])
|
||||
if model["metrics"]:
|
||||
model["metrics"] = json.loads(model["metrics"])
|
||||
models.append(model)
|
||||
|
||||
return models
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_model_by_id(self, model_id: int) -> Optional[Dict]:
|
||||
"""Get model by ID."""
|
||||
models = self.get_models({"id": model_id})
|
||||
return models[0] if models else None
|
||||
|
||||
def update_model(self, model_id: int, updates: Dict) -> bool:
|
||||
"""Update model fields."""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
# Build update query
|
||||
set_clauses = []
|
||||
params = []
|
||||
for key, value in updates.items():
|
||||
if key in ["training_params", "metrics"] and isinstance(value, dict):
|
||||
value = json.dumps(value)
|
||||
set_clauses.append(f"{key} = ?")
|
||||
params.append(value)
|
||||
|
||||
params.append(model_id)
|
||||
|
||||
query = f"UPDATE models SET {', '.join(set_clauses)} WHERE id = ?"
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(query, params)
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# ==================== Image Operations ====================
|
||||
|
||||
def add_image(
|
||||
self,
|
||||
relative_path: str,
|
||||
filename: str,
|
||||
width: int,
|
||||
height: int,
|
||||
captured_at: Optional[datetime] = None,
|
||||
checksum: Optional[str] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Add a new image to the database.
|
||||
|
||||
Args:
|
||||
relative_path: Path relative to image repository
|
||||
filename: Image filename
|
||||
width: Image width in pixels
|
||||
height: Image height in pixels
|
||||
captured_at: When image was captured (if known)
|
||||
checksum: MD5 checksum of image file
|
||||
|
||||
Returns:
|
||||
ID of the inserted image
|
||||
"""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO images (relative_path, filename, width, height, captured_at, checksum)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(relative_path, filename, width, height, captured_at, checksum),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.lastrowid
|
||||
except sqlite3.IntegrityError:
|
||||
# Image already exists, return its ID
|
||||
cursor.execute(
|
||||
"SELECT id FROM images WHERE relative_path = ?", (relative_path,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return row["id"] if row else None
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_image_by_path(self, relative_path: str) -> Optional[Dict]:
|
||||
"""Get image by relative path."""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT * FROM images WHERE relative_path = ?", (relative_path,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return dict(row) if row else None
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_or_create_image(
|
||||
self, relative_path: str, filename: str, width: int, height: int
|
||||
) -> int:
|
||||
"""Get existing image or create new one."""
|
||||
existing = self.get_image_by_path(relative_path)
|
||||
if existing:
|
||||
return existing["id"]
|
||||
return self.add_image(relative_path, filename, width, height)
|
||||
|
||||
# ==================== Detection Operations ====================
|
||||
|
||||
def add_detection(
|
||||
self,
|
||||
image_id: int,
|
||||
model_id: int,
|
||||
class_name: str,
|
||||
bbox: Tuple[float, float, float, float], # (x_min, y_min, x_max, y_max)
|
||||
confidence: float,
|
||||
metadata: Optional[Dict] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Add a new detection to the database.
|
||||
|
||||
Args:
|
||||
image_id: ID of the image
|
||||
model_id: ID of the model used
|
||||
class_name: Detected object class
|
||||
bbox: Bounding box coordinates (normalized 0-1)
|
||||
confidence: Detection confidence score
|
||||
metadata: Additional metadata
|
||||
|
||||
Returns:
|
||||
ID of the inserted detection
|
||||
"""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
x_min, y_min, x_max, y_max = bbox
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
image_id,
|
||||
model_id,
|
||||
class_name,
|
||||
x_min,
|
||||
y_min,
|
||||
x_max,
|
||||
y_max,
|
||||
confidence,
|
||||
json.dumps(metadata) if metadata else None,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.lastrowid
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def add_detections_batch(self, detections: List[Dict]) -> int:
|
||||
"""
|
||||
Add multiple detections in a single transaction.
|
||||
|
||||
Args:
|
||||
detections: List of detection dictionaries
|
||||
|
||||
Returns:
|
||||
Number of detections inserted
|
||||
"""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
for det in detections:
|
||||
bbox = det["bbox"]
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO detections (image_id, model_id, class_name, x_min, y_min, x_max, y_max, confidence, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
det["image_id"],
|
||||
det["model_id"],
|
||||
det["class_name"],
|
||||
bbox[0],
|
||||
bbox[1],
|
||||
bbox[2],
|
||||
bbox[3],
|
||||
det["confidence"],
|
||||
(
|
||||
json.dumps(det.get("metadata"))
|
||||
if det.get("metadata")
|
||||
else None
|
||||
),
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return len(detections)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_detections(
|
||||
self,
|
||||
filters: Optional[Dict] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: int = 0,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Retrieve detections from database.
|
||||
|
||||
Args:
|
||||
filters: Optional filters for querying
|
||||
limit: Maximum number of results
|
||||
offset: Number of results to skip
|
||||
|
||||
Returns:
|
||||
List of detection dictionaries with joined data
|
||||
"""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
query = """
|
||||
SELECT
|
||||
d.*,
|
||||
i.relative_path as image_path,
|
||||
i.filename as image_filename,
|
||||
i.width as image_width,
|
||||
i.height as image_height,
|
||||
m.model_name,
|
||||
m.model_version
|
||||
FROM detections d
|
||||
JOIN images i ON d.image_id = i.id
|
||||
JOIN models m ON d.model_id = m.id
|
||||
"""
|
||||
params = []
|
||||
|
||||
if filters:
|
||||
conditions = []
|
||||
for key, value in filters.items():
|
||||
if (
|
||||
key.startswith("d.")
|
||||
or key.startswith("i.")
|
||||
or key.startswith("m.")
|
||||
):
|
||||
conditions.append(f"{key} = ?")
|
||||
else:
|
||||
conditions.append(f"d.{key} = ?")
|
||||
params.append(value)
|
||||
query += " WHERE " + " AND ".join(conditions)
|
||||
|
||||
query += " ORDER BY d.detected_at DESC"
|
||||
|
||||
if limit:
|
||||
query += f" LIMIT {limit} OFFSET {offset}"
|
||||
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(query, params)
|
||||
|
||||
detections = []
|
||||
for row in cursor.fetchall():
|
||||
det = dict(row)
|
||||
# Parse JSON metadata
|
||||
if det.get("metadata"):
|
||||
det["metadata"] = json.loads(det["metadata"])
|
||||
detections.append(det)
|
||||
|
||||
return detections
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_detections_for_image(
|
||||
self, image_id: int, model_id: Optional[int] = None
|
||||
) -> List[Dict]:
|
||||
"""Get all detections for a specific image."""
|
||||
filters = {"image_id": image_id}
|
||||
if model_id:
|
||||
filters["model_id"] = model_id
|
||||
return self.get_detections(filters)
|
||||
|
||||
def delete_detections_for_model(self, model_id: int) -> int:
|
||||
"""Delete all detections for a specific model."""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM detections WHERE model_id = ?", (model_id,))
|
||||
conn.commit()
|
||||
return cursor.rowcount
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# ==================== Statistics Operations ====================
|
||||
|
||||
def get_detection_statistics(
|
||||
self, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None
|
||||
) -> Dict:
|
||||
"""
|
||||
Get detection statistics for a date range.
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics (count by class, confidence distribution, etc.)
|
||||
"""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Build date filter
|
||||
date_filter = ""
|
||||
params = []
|
||||
if start_date:
|
||||
date_filter += " AND detected_at >= ?"
|
||||
params.append(start_date)
|
||||
if end_date:
|
||||
date_filter += " AND detected_at <= ?"
|
||||
params.append(end_date)
|
||||
|
||||
# Total detections
|
||||
cursor.execute(
|
||||
f"SELECT COUNT(*) as count FROM detections WHERE 1=1{date_filter}",
|
||||
params,
|
||||
)
|
||||
total_count = cursor.fetchone()["count"]
|
||||
|
||||
# Count by class
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT class_name, COUNT(*) as count
|
||||
FROM detections
|
||||
WHERE 1=1{date_filter}
|
||||
GROUP BY class_name
|
||||
ORDER BY count DESC
|
||||
""",
|
||||
params,
|
||||
)
|
||||
class_counts = {
|
||||
row["class_name"]: row["count"] for row in cursor.fetchall()
|
||||
}
|
||||
|
||||
# Average confidence
|
||||
cursor.execute(
|
||||
f"SELECT AVG(confidence) as avg_conf FROM detections WHERE 1=1{date_filter}",
|
||||
params,
|
||||
)
|
||||
avg_confidence = cursor.fetchone()["avg_conf"] or 0
|
||||
|
||||
# Confidence distribution
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT
|
||||
CASE
|
||||
WHEN confidence < 0.3 THEN 'low'
|
||||
WHEN confidence < 0.7 THEN 'medium'
|
||||
ELSE 'high'
|
||||
END as conf_level,
|
||||
COUNT(*) as count
|
||||
FROM detections
|
||||
WHERE 1=1{date_filter}
|
||||
GROUP BY conf_level
|
||||
""",
|
||||
params,
|
||||
)
|
||||
conf_dist = {row["conf_level"]: row["count"] for row in cursor.fetchall()}
|
||||
|
||||
return {
|
||||
"total_detections": total_count,
|
||||
"class_counts": class_counts,
|
||||
"average_confidence": avg_confidence,
|
||||
"confidence_distribution": conf_dist,
|
||||
}
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_class_distribution(self, model_id: Optional[int] = None) -> Dict[str, int]:
|
||||
"""Get count of detections per class."""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
query = "SELECT class_name, COUNT(*) as count FROM detections"
|
||||
params = []
|
||||
|
||||
if model_id:
|
||||
query += " WHERE model_id = ?"
|
||||
params.append(model_id)
|
||||
|
||||
query += " GROUP BY class_name ORDER BY count DESC"
|
||||
|
||||
cursor.execute(query, params)
|
||||
return {row["class_name"]: row["count"] for row in cursor.fetchall()}
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# ==================== Export Operations ====================
|
||||
|
||||
def export_detections_to_csv(
|
||||
self, output_path: str, filters: Optional[Dict] = None
|
||||
) -> bool:
|
||||
"""Export detections to CSV file."""
|
||||
try:
|
||||
detections = self.get_detections(filters)
|
||||
|
||||
with open(output_path, "w", newline="") as csvfile:
|
||||
if not detections:
|
||||
return True
|
||||
|
||||
fieldnames = [
|
||||
"id",
|
||||
"image_path",
|
||||
"model_name",
|
||||
"model_version",
|
||||
"class_name",
|
||||
"x_min",
|
||||
"y_min",
|
||||
"x_max",
|
||||
"y_max",
|
||||
"confidence",
|
||||
"detected_at",
|
||||
]
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
|
||||
for det in detections:
|
||||
row = {k: det[k] for k in fieldnames if k in det}
|
||||
writer.writerow(row)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error exporting to CSV: {e}")
|
||||
return False
|
||||
|
||||
def export_detections_to_json(
|
||||
self, output_path: str, filters: Optional[Dict] = None
|
||||
) -> bool:
|
||||
"""Export detections to JSON file."""
|
||||
try:
|
||||
detections = self.get_detections(filters)
|
||||
|
||||
# Convert datetime objects to strings
|
||||
for det in detections:
|
||||
if isinstance(det.get("detected_at"), datetime):
|
||||
det["detected_at"] = det["detected_at"].isoformat()
|
||||
|
||||
with open(output_path, "w") as jsonfile:
|
||||
json.dump(detections, jsonfile, indent=2)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error exporting to JSON: {e}")
|
||||
return False
|
||||
|
||||
# ==================== Annotation Operations ====================
|
||||
|
||||
def add_annotation(
|
||||
self,
|
||||
image_id: int,
|
||||
class_name: str,
|
||||
bbox: Tuple[float, float, float, float],
|
||||
annotator: str,
|
||||
verified: bool = False,
|
||||
) -> int:
|
||||
"""Add manual annotation."""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
x_min, y_min, x_max, y_max = bbox
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO annotations (image_id, class_name, x_min, y_min, x_max, y_max, annotator, verified)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(image_id, class_name, x_min, y_min, x_max, y_max, annotator, verified),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.lastrowid
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_annotations_for_image(self, image_id: int) -> List[Dict]:
|
||||
"""Get all annotations for an image."""
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM annotations WHERE image_id = ?", (image_id,))
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
@staticmethod
|
||||
def calculate_checksum(file_path: str) -> str:
|
||||
"""Calculate MD5 checksum of a file."""
|
||||
md5_hash = hashlib.md5()
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(4096), b""):
|
||||
md5_hash.update(chunk)
|
||||
return md5_hash.hexdigest()
|
||||
63
src/database/models.py
Normal file
63
src/database/models.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""
|
||||
Data models for the microscopy object detection application.
|
||||
These dataclasses represent the database entities.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
class Model:
|
||||
"""Represents a trained model."""
|
||||
|
||||
id: Optional[int]
|
||||
model_name: str
|
||||
model_version: str
|
||||
model_path: str
|
||||
base_model: str
|
||||
created_at: datetime
|
||||
training_params: Optional[Dict]
|
||||
metrics: Optional[Dict]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Image:
|
||||
"""Represents an image in the database."""
|
||||
|
||||
id: Optional[int]
|
||||
relative_path: str
|
||||
filename: str
|
||||
width: int
|
||||
height: int
|
||||
captured_at: Optional[datetime]
|
||||
added_at: datetime
|
||||
checksum: Optional[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Detection:
|
||||
"""Represents a detection result."""
|
||||
|
||||
id: Optional[int]
|
||||
image_id: int
|
||||
model_id: int
|
||||
class_name: str
|
||||
bbox: Tuple[float, float, float, float] # (x_min, y_min, x_max, y_max)
|
||||
confidence: float
|
||||
detected_at: datetime
|
||||
metadata: Optional[Dict]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Annotation:
|
||||
"""Represents a manual annotation."""
|
||||
|
||||
id: Optional[int]
|
||||
image_id: int
|
||||
class_name: str
|
||||
bbox: Tuple[float, float, float, float] # (x_min, y_min, x_max, y_max)
|
||||
annotator: str
|
||||
created_at: datetime
|
||||
verified: bool
|
||||
70
src/database/schema.sql
Normal file
70
src/database/schema.sql
Normal file
@@ -0,0 +1,70 @@
|
||||
-- Microscopy Object Detection Application - Database Schema
|
||||
-- SQLite Database Schema for storing models, images, detections, and annotations
|
||||
|
||||
-- Models table: stores trained model information
|
||||
CREATE TABLE IF NOT EXISTS models (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
model_name TEXT NOT NULL,
|
||||
model_version TEXT NOT NULL,
|
||||
model_path TEXT NOT NULL,
|
||||
base_model TEXT NOT NULL DEFAULT 'yolov8s.pt',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
training_params TEXT, -- JSON string of training parameters
|
||||
metrics TEXT, -- JSON string of validation metrics
|
||||
UNIQUE(model_name, model_version)
|
||||
);
|
||||
|
||||
-- Images table: stores image metadata
|
||||
CREATE TABLE IF NOT EXISTS images (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
relative_path TEXT NOT NULL UNIQUE,
|
||||
filename TEXT NOT NULL,
|
||||
width INTEGER,
|
||||
height INTEGER,
|
||||
captured_at TIMESTAMP,
|
||||
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
checksum TEXT
|
||||
);
|
||||
|
||||
-- Detections table: stores detection results
|
||||
CREATE TABLE IF NOT EXISTS detections (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
image_id INTEGER NOT NULL,
|
||||
model_id INTEGER NOT NULL,
|
||||
class_name TEXT NOT NULL,
|
||||
x_min REAL NOT NULL CHECK(x_min >= 0 AND x_min <= 1),
|
||||
y_min REAL NOT NULL CHECK(y_min >= 0 AND y_min <= 1),
|
||||
x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1),
|
||||
y_max REAL NOT NULL CHECK(y_max >= 0 AND y_max <= 1),
|
||||
confidence REAL NOT NULL CHECK(confidence >= 0 AND confidence <= 1),
|
||||
detected_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
metadata TEXT, -- JSON string for additional metadata
|
||||
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Annotations table: stores manual annotations (future feature)
|
||||
CREATE TABLE IF NOT EXISTS annotations (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
image_id INTEGER NOT NULL,
|
||||
class_name TEXT NOT NULL,
|
||||
x_min REAL NOT NULL CHECK(x_min >= 0 AND x_min <= 1),
|
||||
y_min REAL NOT NULL CHECK(y_min >= 0 AND y_min <= 1),
|
||||
x_max REAL NOT NULL CHECK(x_max >= 0 AND x_max <= 1),
|
||||
y_max REAL NOT NULL CHECK(y_max >= 0 AND y_max <= 1),
|
||||
annotator TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
verified BOOLEAN DEFAULT 0,
|
||||
FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Create indexes for performance optimization
|
||||
CREATE INDEX IF NOT EXISTS idx_detections_image_id ON detections(image_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_detections_model_id ON detections(model_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_detections_class_name ON detections(class_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_detections_detected_at ON detections(detected_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence);
|
||||
CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path);
|
||||
CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at);
|
||||
0
src/gui/__init__.py
Normal file
0
src/gui/__init__.py
Normal file
0
src/gui/dialogs/__init__.py
Normal file
0
src/gui/dialogs/__init__.py
Normal file
291
src/gui/dialogs/config_dialog.py
Normal file
291
src/gui/dialogs/config_dialog.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
Configuration dialog for the microscopy object detection application.
|
||||
"""
|
||||
|
||||
from PySide6.QtWidgets import (
|
||||
QDialog,
|
||||
QVBoxLayout,
|
||||
QHBoxLayout,
|
||||
QFormLayout,
|
||||
QPushButton,
|
||||
QLineEdit,
|
||||
QSpinBox,
|
||||
QDoubleSpinBox,
|
||||
QFileDialog,
|
||||
QTabWidget,
|
||||
QWidget,
|
||||
QLabel,
|
||||
QGroupBox,
|
||||
)
|
||||
from PySide6.QtCore import Qt
|
||||
|
||||
from src.utils.config_manager import ConfigManager
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ConfigDialog(QDialog):
|
||||
"""Configuration dialog window."""
|
||||
|
||||
def __init__(self, config_manager: ConfigManager, parent=None):
|
||||
super().__init__(parent)
|
||||
|
||||
self.config_manager = config_manager
|
||||
|
||||
self.setWindowTitle("Settings")
|
||||
self.setMinimumWidth(500)
|
||||
self.setMinimumHeight(400)
|
||||
|
||||
self._setup_ui()
|
||||
self._load_settings()
|
||||
|
||||
def _setup_ui(self):
|
||||
"""Setup user interface."""
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# Create tab widget for different setting categories
|
||||
self.tab_widget = QTabWidget()
|
||||
|
||||
# General settings tab
|
||||
general_tab = self._create_general_tab()
|
||||
self.tab_widget.addTab(general_tab, "General")
|
||||
|
||||
# Training settings tab
|
||||
training_tab = self._create_training_tab()
|
||||
self.tab_widget.addTab(training_tab, "Training")
|
||||
|
||||
# Detection settings tab
|
||||
detection_tab = self._create_detection_tab()
|
||||
self.tab_widget.addTab(detection_tab, "Detection")
|
||||
|
||||
layout.addWidget(self.tab_widget)
|
||||
|
||||
# Buttons
|
||||
button_layout = QHBoxLayout()
|
||||
button_layout.addStretch()
|
||||
|
||||
self.save_button = QPushButton("Save")
|
||||
self.save_button.clicked.connect(self.accept)
|
||||
button_layout.addWidget(self.save_button)
|
||||
|
||||
self.cancel_button = QPushButton("Cancel")
|
||||
self.cancel_button.clicked.connect(self.reject)
|
||||
button_layout.addWidget(self.cancel_button)
|
||||
|
||||
layout.addLayout(button_layout)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
def _create_general_tab(self) -> QWidget:
|
||||
"""Create general settings tab."""
|
||||
widget = QWidget()
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# Image repository group
|
||||
repo_group = QGroupBox("Image Repository")
|
||||
repo_layout = QFormLayout()
|
||||
|
||||
# Repository path
|
||||
path_layout = QHBoxLayout()
|
||||
self.repo_path_edit = QLineEdit()
|
||||
self.repo_path_edit.setPlaceholderText("Path to image repository")
|
||||
path_layout.addWidget(self.repo_path_edit)
|
||||
|
||||
browse_button = QPushButton("Browse...")
|
||||
browse_button.clicked.connect(self._browse_repository)
|
||||
path_layout.addWidget(browse_button)
|
||||
|
||||
repo_layout.addRow("Base Path:", path_layout)
|
||||
repo_group.setLayout(repo_layout)
|
||||
layout.addWidget(repo_group)
|
||||
|
||||
# Database group
|
||||
db_group = QGroupBox("Database")
|
||||
db_layout = QFormLayout()
|
||||
|
||||
self.db_path_edit = QLineEdit()
|
||||
self.db_path_edit.setPlaceholderText("Path to database file")
|
||||
db_layout.addRow("Database Path:", self.db_path_edit)
|
||||
|
||||
db_group.setLayout(db_layout)
|
||||
layout.addWidget(db_group)
|
||||
|
||||
# Models group
|
||||
models_group = QGroupBox("Models")
|
||||
models_layout = QFormLayout()
|
||||
|
||||
self.models_dir_edit = QLineEdit()
|
||||
self.models_dir_edit.setPlaceholderText("Directory for saved models")
|
||||
models_layout.addRow("Models Directory:", self.models_dir_edit)
|
||||
|
||||
self.base_model_edit = QLineEdit()
|
||||
self.base_model_edit.setPlaceholderText("yolov8s.pt")
|
||||
models_layout.addRow("Default Base Model:", self.base_model_edit)
|
||||
|
||||
models_group.setLayout(models_layout)
|
||||
layout.addWidget(models_group)
|
||||
|
||||
layout.addStretch()
|
||||
widget.setLayout(layout)
|
||||
return widget
|
||||
|
||||
def _create_training_tab(self) -> QWidget:
|
||||
"""Create training settings tab."""
|
||||
widget = QWidget()
|
||||
layout = QVBoxLayout()
|
||||
|
||||
form_layout = QFormLayout()
|
||||
|
||||
# Epochs
|
||||
self.epochs_spin = QSpinBox()
|
||||
self.epochs_spin.setRange(1, 1000)
|
||||
self.epochs_spin.setValue(100)
|
||||
form_layout.addRow("Default Epochs:", self.epochs_spin)
|
||||
|
||||
# Batch size
|
||||
self.batch_size_spin = QSpinBox()
|
||||
self.batch_size_spin.setRange(1, 128)
|
||||
self.batch_size_spin.setValue(16)
|
||||
form_layout.addRow("Default Batch Size:", self.batch_size_spin)
|
||||
|
||||
# Image size
|
||||
self.imgsz_spin = QSpinBox()
|
||||
self.imgsz_spin.setRange(320, 1280)
|
||||
self.imgsz_spin.setSingleStep(32)
|
||||
self.imgsz_spin.setValue(640)
|
||||
form_layout.addRow("Default Image Size:", self.imgsz_spin)
|
||||
|
||||
# Patience
|
||||
self.patience_spin = QSpinBox()
|
||||
self.patience_spin.setRange(1, 200)
|
||||
self.patience_spin.setValue(50)
|
||||
form_layout.addRow("Default Patience:", self.patience_spin)
|
||||
|
||||
# Learning rate
|
||||
self.lr_spin = QDoubleSpinBox()
|
||||
self.lr_spin.setRange(0.0001, 0.1)
|
||||
self.lr_spin.setSingleStep(0.001)
|
||||
self.lr_spin.setDecimals(4)
|
||||
self.lr_spin.setValue(0.01)
|
||||
form_layout.addRow("Default Learning Rate:", self.lr_spin)
|
||||
|
||||
layout.addLayout(form_layout)
|
||||
layout.addStretch()
|
||||
widget.setLayout(layout)
|
||||
return widget
|
||||
|
||||
def _create_detection_tab(self) -> QWidget:
|
||||
"""Create detection settings tab."""
|
||||
widget = QWidget()
|
||||
layout = QVBoxLayout()
|
||||
|
||||
form_layout = QFormLayout()
|
||||
|
||||
# Confidence threshold
|
||||
self.conf_spin = QDoubleSpinBox()
|
||||
self.conf_spin.setRange(0.0, 1.0)
|
||||
self.conf_spin.setSingleStep(0.05)
|
||||
self.conf_spin.setDecimals(2)
|
||||
self.conf_spin.setValue(0.25)
|
||||
form_layout.addRow("Default Confidence:", self.conf_spin)
|
||||
|
||||
# IoU threshold
|
||||
self.iou_spin = QDoubleSpinBox()
|
||||
self.iou_spin.setRange(0.0, 1.0)
|
||||
self.iou_spin.setSingleStep(0.05)
|
||||
self.iou_spin.setDecimals(2)
|
||||
self.iou_spin.setValue(0.45)
|
||||
form_layout.addRow("Default IoU:", self.iou_spin)
|
||||
|
||||
# Max batch size
|
||||
self.max_batch_spin = QSpinBox()
|
||||
self.max_batch_spin.setRange(1, 1000)
|
||||
self.max_batch_spin.setValue(100)
|
||||
form_layout.addRow("Max Batch Size:", self.max_batch_spin)
|
||||
|
||||
layout.addLayout(form_layout)
|
||||
layout.addStretch()
|
||||
widget.setLayout(layout)
|
||||
return widget
|
||||
|
||||
def _browse_repository(self):
|
||||
"""Browse for image repository directory."""
|
||||
directory = QFileDialog.getExistingDirectory(
|
||||
self, "Select Image Repository", self.repo_path_edit.text()
|
||||
)
|
||||
|
||||
if directory:
|
||||
self.repo_path_edit.setText(directory)
|
||||
|
||||
def _load_settings(self):
|
||||
"""Load current settings into dialog."""
|
||||
# General settings
|
||||
self.repo_path_edit.setText(
|
||||
self.config_manager.get("image_repository.base_path", "")
|
||||
)
|
||||
self.db_path_edit.setText(
|
||||
self.config_manager.get("database.path", "data/detections.db")
|
||||
)
|
||||
self.models_dir_edit.setText(
|
||||
self.config_manager.get("models.models_directory", "data/models")
|
||||
)
|
||||
self.base_model_edit.setText(
|
||||
self.config_manager.get("models.default_base_model", "yolov8s.pt")
|
||||
)
|
||||
|
||||
# Training settings
|
||||
self.epochs_spin.setValue(
|
||||
self.config_manager.get("training.default_epochs", 100)
|
||||
)
|
||||
self.batch_size_spin.setValue(
|
||||
self.config_manager.get("training.default_batch_size", 16)
|
||||
)
|
||||
self.imgsz_spin.setValue(self.config_manager.get("training.default_imgsz", 640))
|
||||
self.patience_spin.setValue(
|
||||
self.config_manager.get("training.default_patience", 50)
|
||||
)
|
||||
self.lr_spin.setValue(self.config_manager.get("training.default_lr0", 0.01))
|
||||
|
||||
# Detection settings
|
||||
self.conf_spin.setValue(
|
||||
self.config_manager.get("detection.default_confidence", 0.25)
|
||||
)
|
||||
self.iou_spin.setValue(self.config_manager.get("detection.default_iou", 0.45))
|
||||
self.max_batch_spin.setValue(
|
||||
self.config_manager.get("detection.max_batch_size", 100)
|
||||
)
|
||||
|
||||
def accept(self):
|
||||
"""Save settings and close dialog."""
|
||||
logger.info("Saving configuration")
|
||||
|
||||
# Save general settings
|
||||
self.config_manager.set(
|
||||
"image_repository.base_path", self.repo_path_edit.text()
|
||||
)
|
||||
self.config_manager.set("database.path", self.db_path_edit.text())
|
||||
self.config_manager.set("models.models_directory", self.models_dir_edit.text())
|
||||
self.config_manager.set(
|
||||
"models.default_base_model", self.base_model_edit.text()
|
||||
)
|
||||
|
||||
# Save training settings
|
||||
self.config_manager.set("training.default_epochs", self.epochs_spin.value())
|
||||
self.config_manager.set(
|
||||
"training.default_batch_size", self.batch_size_spin.value()
|
||||
)
|
||||
self.config_manager.set("training.default_imgsz", self.imgsz_spin.value())
|
||||
self.config_manager.set("training.default_patience", self.patience_spin.value())
|
||||
self.config_manager.set("training.default_lr0", self.lr_spin.value())
|
||||
|
||||
# Save detection settings
|
||||
self.config_manager.set("detection.default_confidence", self.conf_spin.value())
|
||||
self.config_manager.set("detection.default_iou", self.iou_spin.value())
|
||||
self.config_manager.set("detection.max_batch_size", self.max_batch_spin.value())
|
||||
|
||||
# Save to file
|
||||
self.config_manager.save_config()
|
||||
|
||||
super().accept()
|
||||
282
src/gui/main_window.py
Normal file
282
src/gui/main_window.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
Main window for the microscopy object detection application.
|
||||
"""
|
||||
|
||||
from PySide6.QtWidgets import (
|
||||
QMainWindow,
|
||||
QTabWidget,
|
||||
QMenuBar,
|
||||
QMenu,
|
||||
QStatusBar,
|
||||
QMessageBox,
|
||||
QWidget,
|
||||
QVBoxLayout,
|
||||
QLabel,
|
||||
)
|
||||
from PySide6.QtCore import Qt, QTimer
|
||||
from PySide6.QtGui import QAction, QKeySequence
|
||||
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.config_manager import ConfigManager
|
||||
from src.utils.logger import get_logger
|
||||
from src.gui.dialogs.config_dialog import ConfigDialog
|
||||
from src.gui.tabs.detection_tab import DetectionTab
|
||||
from src.gui.tabs.training_tab import TrainingTab
|
||||
from src.gui.tabs.validation_tab import ValidationTab
|
||||
from src.gui.tabs.results_tab import ResultsTab
|
||||
from src.gui.tabs.annotation_tab import AnnotationTab
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MainWindow(QMainWindow):
|
||||
"""Main application window."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# Initialize managers
|
||||
self.config_manager = ConfigManager()
|
||||
|
||||
db_path = self.config_manager.get_database_path()
|
||||
self.db_manager = DatabaseManager(db_path)
|
||||
|
||||
logger.info("Main window initializing")
|
||||
|
||||
# Setup UI
|
||||
self.setWindowTitle("Microscopy Object Detection")
|
||||
self.setMinimumSize(1200, 800)
|
||||
|
||||
self._create_menu_bar()
|
||||
self._create_tab_widget()
|
||||
self._create_status_bar()
|
||||
|
||||
# Center window on screen
|
||||
self._center_window()
|
||||
|
||||
logger.info("Main window initialized")
|
||||
|
||||
def _create_menu_bar(self):
|
||||
"""Create application menu bar."""
|
||||
menubar = self.menuBar()
|
||||
|
||||
# File menu
|
||||
file_menu = menubar.addMenu("&File")
|
||||
|
||||
settings_action = QAction("&Settings", self)
|
||||
settings_action.setShortcut(QKeySequence("Ctrl+,"))
|
||||
settings_action.triggered.connect(self._show_settings)
|
||||
file_menu.addAction(settings_action)
|
||||
|
||||
file_menu.addSeparator()
|
||||
|
||||
exit_action = QAction("E&xit", self)
|
||||
exit_action.setShortcut(QKeySequence("Ctrl+Q"))
|
||||
exit_action.triggered.connect(self.close)
|
||||
file_menu.addAction(exit_action)
|
||||
|
||||
# View menu
|
||||
view_menu = menubar.addMenu("&View")
|
||||
|
||||
refresh_action = QAction("&Refresh", self)
|
||||
refresh_action.setShortcut(QKeySequence("F5"))
|
||||
refresh_action.triggered.connect(self._refresh_current_tab)
|
||||
view_menu.addAction(refresh_action)
|
||||
|
||||
# Tools menu
|
||||
tools_menu = menubar.addMenu("&Tools")
|
||||
|
||||
db_stats_action = QAction("Database &Statistics", self)
|
||||
db_stats_action.triggered.connect(self._show_database_stats)
|
||||
tools_menu.addAction(db_stats_action)
|
||||
|
||||
# Help menu
|
||||
help_menu = menubar.addMenu("&Help")
|
||||
|
||||
about_action = QAction("&About", self)
|
||||
about_action.triggered.connect(self._show_about)
|
||||
help_menu.addAction(about_action)
|
||||
|
||||
docs_action = QAction("&Documentation", self)
|
||||
docs_action.triggered.connect(self._show_documentation)
|
||||
help_menu.addAction(docs_action)
|
||||
|
||||
def _create_tab_widget(self):
|
||||
"""Create main tab widget with all tabs."""
|
||||
self.tab_widget = QTabWidget()
|
||||
self.tab_widget.setTabPosition(QTabWidget.North)
|
||||
|
||||
# Create tabs
|
||||
try:
|
||||
self.detection_tab = DetectionTab(self.db_manager, self.config_manager)
|
||||
self.training_tab = TrainingTab(self.db_manager, self.config_manager)
|
||||
self.validation_tab = ValidationTab(self.db_manager, self.config_manager)
|
||||
self.results_tab = ResultsTab(self.db_manager, self.config_manager)
|
||||
self.annotation_tab = AnnotationTab(self.db_manager, self.config_manager)
|
||||
|
||||
# Add tabs to widget
|
||||
self.tab_widget.addTab(self.detection_tab, "Detection")
|
||||
self.tab_widget.addTab(self.training_tab, "Training")
|
||||
self.tab_widget.addTab(self.validation_tab, "Validation")
|
||||
self.tab_widget.addTab(self.results_tab, "Results")
|
||||
self.tab_widget.addTab(self.annotation_tab, "Annotation (Future)")
|
||||
|
||||
# Connect tab change signal
|
||||
self.tab_widget.currentChanged.connect(self._on_tab_changed)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating tabs: {e}")
|
||||
# Create placeholder
|
||||
placeholder = QWidget()
|
||||
layout = QVBoxLayout()
|
||||
layout.addWidget(QLabel(f"Error creating tabs: {e}"))
|
||||
placeholder.setLayout(layout)
|
||||
self.tab_widget.addTab(placeholder, "Error")
|
||||
|
||||
self.setCentralWidget(self.tab_widget)
|
||||
|
||||
def _create_status_bar(self):
|
||||
"""Create status bar."""
|
||||
self.status_bar = QStatusBar()
|
||||
self.setStatusBar(self.status_bar)
|
||||
|
||||
# Add permanent widgets to status bar
|
||||
self.status_label = QLabel("Ready")
|
||||
self.status_bar.addWidget(self.status_label)
|
||||
|
||||
# Initial status message
|
||||
self._update_status("Ready")
|
||||
|
||||
def _center_window(self):
|
||||
"""Center window on screen."""
|
||||
screen = self.screen().geometry()
|
||||
size = self.geometry()
|
||||
self.move(
|
||||
(screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2
|
||||
)
|
||||
|
||||
def _show_settings(self):
|
||||
"""Show settings dialog."""
|
||||
logger.info("Opening settings dialog")
|
||||
dialog = ConfigDialog(self.config_manager, self)
|
||||
if dialog.exec():
|
||||
self._apply_settings()
|
||||
self._update_status("Settings saved")
|
||||
|
||||
def _apply_settings(self):
|
||||
"""Apply changed settings."""
|
||||
logger.info("Applying settings changes")
|
||||
# Reload configuration in all tabs if needed
|
||||
try:
|
||||
if hasattr(self, "detection_tab"):
|
||||
self.detection_tab.refresh()
|
||||
if hasattr(self, "training_tab"):
|
||||
self.training_tab.refresh()
|
||||
if hasattr(self, "results_tab"):
|
||||
self.results_tab.refresh()
|
||||
except Exception as e:
|
||||
logger.error(f"Error applying settings: {e}")
|
||||
|
||||
def _refresh_current_tab(self):
|
||||
"""Refresh the current tab."""
|
||||
current_widget = self.tab_widget.currentWidget()
|
||||
if hasattr(current_widget, "refresh"):
|
||||
current_widget.refresh()
|
||||
self._update_status("Tab refreshed")
|
||||
|
||||
def _on_tab_changed(self, index: int):
|
||||
"""Handle tab change event."""
|
||||
tab_name = self.tab_widget.tabText(index)
|
||||
logger.debug(f"Switched to tab: {tab_name}")
|
||||
self._update_status(f"Viewing: {tab_name}")
|
||||
|
||||
def _show_database_stats(self):
|
||||
"""Show database statistics dialog."""
|
||||
try:
|
||||
stats = self.db_manager.get_detection_statistics()
|
||||
|
||||
message = f"""
|
||||
<h3>Database Statistics</h3>
|
||||
<p><b>Total Detections:</b> {stats.get('total_detections', 0)}</p>
|
||||
<p><b>Average Confidence:</b> {stats.get('average_confidence', 0):.2%}</p>
|
||||
<p><b>Classes:</b></p>
|
||||
<ul>
|
||||
"""
|
||||
|
||||
for class_name, count in stats.get("class_counts", {}).items():
|
||||
message += f"<li>{class_name}: {count}</li>"
|
||||
|
||||
message += "</ul>"
|
||||
|
||||
QMessageBox.information(self, "Database Statistics", message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting database stats: {e}")
|
||||
QMessageBox.warning(
|
||||
self, "Error", f"Failed to get database statistics:\n{str(e)}"
|
||||
)
|
||||
|
||||
def _show_about(self):
|
||||
"""Show about dialog."""
|
||||
about_text = """
|
||||
<h2>Microscopy Object Detection Application</h2>
|
||||
<p><b>Version:</b> 1.0.0</p>
|
||||
<p>A desktop application for detecting organelles and membrane branching
|
||||
structures in microscopy images using YOLOv8.</p>
|
||||
|
||||
<p><b>Features:</b></p>
|
||||
<ul>
|
||||
<li>Object detection with YOLOv8</li>
|
||||
<li>Model training and validation</li>
|
||||
<li>Detection results storage</li>
|
||||
<li>Interactive visualization</li>
|
||||
<li>Export capabilities</li>
|
||||
</ul>
|
||||
|
||||
<p><b>Technologies:</b></p>
|
||||
<ul>
|
||||
<li>Ultralytics YOLOv8</li>
|
||||
<li>PySide6</li>
|
||||
<li>pyqtgraph</li>
|
||||
<li>SQLite</li>
|
||||
</ul>
|
||||
"""
|
||||
|
||||
QMessageBox.about(self, "About", about_text)
|
||||
|
||||
def _show_documentation(self):
|
||||
"""Show documentation."""
|
||||
QMessageBox.information(
|
||||
self,
|
||||
"Documentation",
|
||||
"Please refer to README.md and ARCHITECTURE.md files in the project directory.",
|
||||
)
|
||||
|
||||
def _update_status(self, message: str, timeout: int = 5000):
|
||||
"""
|
||||
Update status bar message.
|
||||
|
||||
Args:
|
||||
message: Status message to display
|
||||
timeout: Time in milliseconds to show message (0 for permanent)
|
||||
"""
|
||||
self.status_label.setText(message)
|
||||
if timeout > 0:
|
||||
QTimer.singleShot(timeout, lambda: self.status_label.setText("Ready"))
|
||||
|
||||
def closeEvent(self, event):
|
||||
"""Handle window close event."""
|
||||
reply = QMessageBox.question(
|
||||
self,
|
||||
"Confirm Exit",
|
||||
"Are you sure you want to exit?",
|
||||
QMessageBox.Yes | QMessageBox.No,
|
||||
QMessageBox.No,
|
||||
)
|
||||
|
||||
if reply == QMessageBox.Yes:
|
||||
logger.info("Application closing")
|
||||
event.accept()
|
||||
else:
|
||||
event.ignore()
|
||||
0
src/gui/tabs/__init__.py
Normal file
0
src/gui/tabs/__init__.py
Normal file
48
src/gui/tabs/annotation_tab.py
Normal file
48
src/gui/tabs/annotation_tab.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
Annotation tab for the microscopy object detection application.
|
||||
Future feature for manual annotation.
|
||||
"""
|
||||
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
|
||||
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.config_manager import ConfigManager
|
||||
|
||||
|
||||
class AnnotationTab(QWidget):
|
||||
"""Annotation tab placeholder (future feature)."""
|
||||
|
||||
def __init__(
|
||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
||||
):
|
||||
super().__init__(parent)
|
||||
self.db_manager = db_manager
|
||||
self.config_manager = config_manager
|
||||
|
||||
self._setup_ui()
|
||||
|
||||
def _setup_ui(self):
|
||||
"""Setup user interface."""
|
||||
layout = QVBoxLayout()
|
||||
|
||||
group = QGroupBox("Annotation Tool (Future Feature)")
|
||||
group_layout = QVBoxLayout()
|
||||
label = QLabel(
|
||||
"Annotation functionality will be implemented in future version.\n\n"
|
||||
"Planned Features:\n"
|
||||
"- Image browser\n"
|
||||
"- Drawing tools for bounding boxes\n"
|
||||
"- Class label assignment\n"
|
||||
"- Export annotations to YOLO format\n"
|
||||
"- Annotation verification"
|
||||
)
|
||||
group_layout.addWidget(label)
|
||||
group.setLayout(group_layout)
|
||||
|
||||
layout.addWidget(group)
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def refresh(self):
|
||||
"""Refresh the tab."""
|
||||
pass
|
||||
344
src/gui/tabs/detection_tab.py
Normal file
344
src/gui/tabs/detection_tab.py
Normal file
@@ -0,0 +1,344 @@
|
||||
"""
|
||||
Detection tab for the microscopy object detection application.
|
||||
Handles single image and batch detection.
|
||||
"""
|
||||
|
||||
from PySide6.QtWidgets import (
|
||||
QWidget,
|
||||
QVBoxLayout,
|
||||
QHBoxLayout,
|
||||
QPushButton,
|
||||
QLabel,
|
||||
QComboBox,
|
||||
QSlider,
|
||||
QFileDialog,
|
||||
QMessageBox,
|
||||
QProgressBar,
|
||||
QTextEdit,
|
||||
QGroupBox,
|
||||
QFormLayout,
|
||||
)
|
||||
from PySide6.QtCore import Qt, QThread, Signal
|
||||
from pathlib import Path
|
||||
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.config_manager import ConfigManager
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.file_utils import get_image_files
|
||||
from src.model.inference import InferenceEngine
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DetectionWorker(QThread):
|
||||
"""Worker thread for running detection."""
|
||||
|
||||
progress = Signal(int, int, str) # current, total, message
|
||||
finished = Signal(list) # results
|
||||
error = Signal(str) # error message
|
||||
|
||||
def __init__(self, engine, image_paths, repo_root, conf):
|
||||
super().__init__()
|
||||
self.engine = engine
|
||||
self.image_paths = image_paths
|
||||
self.repo_root = repo_root
|
||||
self.conf = conf
|
||||
|
||||
def run(self):
|
||||
"""Run detection in background thread."""
|
||||
try:
|
||||
results = self.engine.detect_batch(
|
||||
self.image_paths, self.repo_root, self.conf, self.progress.emit
|
||||
)
|
||||
self.finished.emit(results)
|
||||
except Exception as e:
|
||||
logger.error(f"Detection error: {e}")
|
||||
self.error.emit(str(e))
|
||||
|
||||
|
||||
class DetectionTab(QWidget):
|
||||
"""Detection tab for single image and batch detection."""
|
||||
|
||||
def __init__(
|
||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
||||
):
|
||||
super().__init__(parent)
|
||||
self.db_manager = db_manager
|
||||
self.config_manager = config_manager
|
||||
self.inference_engine = None
|
||||
self.current_model_id = None
|
||||
|
||||
self._setup_ui()
|
||||
self._connect_signals()
|
||||
self._load_models()
|
||||
|
||||
def _setup_ui(self):
|
||||
"""Setup user interface."""
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# Model selection group
|
||||
model_group = QGroupBox("Model Selection")
|
||||
model_layout = QFormLayout()
|
||||
|
||||
self.model_combo = QComboBox()
|
||||
self.model_combo.addItem("No models available", None)
|
||||
model_layout.addRow("Model:", self.model_combo)
|
||||
|
||||
model_group.setLayout(model_layout)
|
||||
layout.addWidget(model_group)
|
||||
|
||||
# Detection settings group
|
||||
settings_group = QGroupBox("Detection Settings")
|
||||
settings_layout = QFormLayout()
|
||||
|
||||
# Confidence threshold
|
||||
conf_layout = QHBoxLayout()
|
||||
self.conf_slider = QSlider(Qt.Horizontal)
|
||||
self.conf_slider.setRange(0, 100)
|
||||
self.conf_slider.setValue(25)
|
||||
self.conf_slider.setTickPosition(QSlider.TicksBelow)
|
||||
self.conf_slider.setTickInterval(10)
|
||||
conf_layout.addWidget(self.conf_slider)
|
||||
|
||||
self.conf_label = QLabel("0.25")
|
||||
conf_layout.addWidget(self.conf_label)
|
||||
|
||||
settings_layout.addRow("Confidence:", conf_layout)
|
||||
settings_group.setLayout(settings_layout)
|
||||
layout.addWidget(settings_group)
|
||||
|
||||
# Action buttons
|
||||
button_layout = QHBoxLayout()
|
||||
|
||||
self.single_image_btn = QPushButton("Detect Single Image")
|
||||
self.single_image_btn.clicked.connect(self._detect_single_image)
|
||||
button_layout.addWidget(self.single_image_btn)
|
||||
|
||||
self.batch_btn = QPushButton("Detect Batch (Folder)")
|
||||
self.batch_btn.clicked.connect(self._detect_batch)
|
||||
button_layout.addWidget(self.batch_btn)
|
||||
|
||||
layout.addLayout(button_layout)
|
||||
|
||||
# Progress bar
|
||||
self.progress_bar = QProgressBar()
|
||||
self.progress_bar.setVisible(False)
|
||||
layout.addWidget(self.progress_bar)
|
||||
|
||||
# Results display
|
||||
results_group = QGroupBox("Detection Results")
|
||||
results_layout = QVBoxLayout()
|
||||
|
||||
self.results_text = QTextEdit()
|
||||
self.results_text.setReadOnly(True)
|
||||
self.results_text.setMaximumHeight(200)
|
||||
results_layout.addWidget(self.results_text)
|
||||
|
||||
results_group.setLayout(results_layout)
|
||||
layout.addWidget(results_group)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def _connect_signals(self):
|
||||
"""Connect signals and slots."""
|
||||
self.conf_slider.valueChanged.connect(self._update_confidence_label)
|
||||
self.model_combo.currentIndexChanged.connect(self._on_model_changed)
|
||||
|
||||
def _load_models(self):
|
||||
"""Load available models from database."""
|
||||
try:
|
||||
models = self.db_manager.get_models()
|
||||
self.model_combo.clear()
|
||||
|
||||
if not models:
|
||||
self.model_combo.addItem("No models available", None)
|
||||
self._set_buttons_enabled(False)
|
||||
return
|
||||
|
||||
# Add base model option
|
||||
base_model = self.config_manager.get(
|
||||
"models.default_base_model", "yolov8s.pt"
|
||||
)
|
||||
self.model_combo.addItem(
|
||||
f"Base Model ({base_model})", {"id": 0, "path": base_model}
|
||||
)
|
||||
|
||||
# Add trained models
|
||||
for model in models:
|
||||
display_name = f"{model['model_name']} v{model['model_version']}"
|
||||
self.model_combo.addItem(display_name, model)
|
||||
|
||||
self._set_buttons_enabled(True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading models: {e}")
|
||||
QMessageBox.warning(self, "Error", f"Failed to load models:\n{str(e)}")
|
||||
|
||||
def _on_model_changed(self, index: int):
|
||||
"""Handle model selection change."""
|
||||
model_data = self.model_combo.itemData(index)
|
||||
if model_data and model_data["id"] != 0:
|
||||
self.current_model_id = model_data["id"]
|
||||
else:
|
||||
self.current_model_id = None
|
||||
|
||||
def _update_confidence_label(self, value: int):
|
||||
"""Update confidence label."""
|
||||
conf = value / 100.0
|
||||
self.conf_label.setText(f"{conf:.2f}")
|
||||
|
||||
def _detect_single_image(self):
|
||||
"""Detect objects in a single image."""
|
||||
# Get image file
|
||||
repo_path = self.config_manager.get_image_repository_path()
|
||||
start_dir = repo_path if repo_path else ""
|
||||
|
||||
file_path, _ = QFileDialog.getOpenFileName(
|
||||
self,
|
||||
"Select Image",
|
||||
start_dir,
|
||||
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
|
||||
)
|
||||
|
||||
if not file_path:
|
||||
return
|
||||
|
||||
# Run detection
|
||||
self._run_detection([file_path])
|
||||
|
||||
def _detect_batch(self):
|
||||
"""Detect objects in batch (folder)."""
|
||||
# Get folder
|
||||
repo_path = self.config_manager.get_image_repository_path()
|
||||
start_dir = repo_path if repo_path else ""
|
||||
|
||||
folder_path = QFileDialog.getExistingDirectory(self, "Select Folder", start_dir)
|
||||
|
||||
if not folder_path:
|
||||
return
|
||||
|
||||
# Get all image files
|
||||
allowed_ext = self.config_manager.get_allowed_extensions()
|
||||
image_files = get_image_files(folder_path, allowed_ext, recursive=False)
|
||||
|
||||
if not image_files:
|
||||
QMessageBox.information(
|
||||
self, "No Images", "No image files found in selected folder."
|
||||
)
|
||||
return
|
||||
|
||||
# Confirm batch processing
|
||||
reply = QMessageBox.question(
|
||||
self,
|
||||
"Confirm Batch Detection",
|
||||
f"Process {len(image_files)} images?",
|
||||
QMessageBox.Yes | QMessageBox.No,
|
||||
)
|
||||
|
||||
if reply == QMessageBox.Yes:
|
||||
self._run_detection(image_files)
|
||||
|
||||
def _run_detection(self, image_paths: list):
|
||||
"""Run detection on image list."""
|
||||
try:
|
||||
# Get selected model
|
||||
model_data = self.model_combo.currentData()
|
||||
if not model_data:
|
||||
QMessageBox.warning(self, "No Model", "Please select a model first.")
|
||||
return
|
||||
|
||||
model_path = model_data["path"]
|
||||
model_id = model_data["id"]
|
||||
|
||||
# Ensure we have a valid model ID (create entry for base model if needed)
|
||||
if model_id == 0:
|
||||
# Create database entry for base model
|
||||
base_model = self.config_manager.get(
|
||||
"models.default_base_model", "yolov8s.pt"
|
||||
)
|
||||
model_id = self.db_manager.add_model(
|
||||
model_name="Base Model",
|
||||
model_version="pretrained",
|
||||
model_path=base_model,
|
||||
base_model=base_model,
|
||||
)
|
||||
|
||||
# Create inference engine
|
||||
self.inference_engine = InferenceEngine(
|
||||
model_path, self.db_manager, model_id
|
||||
)
|
||||
|
||||
# Get confidence threshold
|
||||
conf = self.conf_slider.value() / 100.0
|
||||
|
||||
# Get repository root
|
||||
repo_root = self.config_manager.get_image_repository_path()
|
||||
if not repo_root:
|
||||
repo_root = str(Path(image_paths[0]).parent)
|
||||
|
||||
# Show progress bar
|
||||
self.progress_bar.setVisible(True)
|
||||
self.progress_bar.setMaximum(len(image_paths))
|
||||
self._set_buttons_enabled(False)
|
||||
|
||||
# Create and start worker thread
|
||||
self.worker = DetectionWorker(
|
||||
self.inference_engine, image_paths, repo_root, conf
|
||||
)
|
||||
self.worker.progress.connect(self._on_progress)
|
||||
self.worker.finished.connect(self._on_detection_finished)
|
||||
self.worker.error.connect(self._on_detection_error)
|
||||
self.worker.start()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting detection: {e}")
|
||||
QMessageBox.critical(self, "Error", f"Failed to start detection:\n{str(e)}")
|
||||
self._set_buttons_enabled(True)
|
||||
|
||||
def _on_progress(self, current: int, total: int, message: str):
|
||||
"""Handle progress update."""
|
||||
self.progress_bar.setValue(current)
|
||||
self.results_text.append(f"[{current}/{total}] {message}")
|
||||
|
||||
def _on_detection_finished(self, results: list):
|
||||
"""Handle detection completion."""
|
||||
self.progress_bar.setVisible(False)
|
||||
self._set_buttons_enabled(True)
|
||||
|
||||
# Calculate statistics
|
||||
total_detections = sum(r["count"] for r in results)
|
||||
successful = sum(1 for r in results if r.get("success", False))
|
||||
|
||||
summary = f"\n=== Detection Complete ===\n"
|
||||
summary += f"Processed: {len(results)} images\n"
|
||||
summary += f"Successful: {successful}\n"
|
||||
summary += f"Total detections: {total_detections}\n"
|
||||
|
||||
self.results_text.append(summary)
|
||||
|
||||
QMessageBox.information(
|
||||
self,
|
||||
"Detection Complete",
|
||||
f"Processed {len(results)} images\n{total_detections} objects detected",
|
||||
)
|
||||
|
||||
def _on_detection_error(self, error_msg: str):
|
||||
"""Handle detection error."""
|
||||
self.progress_bar.setVisible(False)
|
||||
self._set_buttons_enabled(True)
|
||||
|
||||
self.results_text.append(f"\nERROR: {error_msg}")
|
||||
QMessageBox.critical(self, "Detection Error", error_msg)
|
||||
|
||||
def _set_buttons_enabled(self, enabled: bool):
|
||||
"""Enable/disable action buttons."""
|
||||
self.single_image_btn.setEnabled(enabled)
|
||||
self.batch_btn.setEnabled(enabled)
|
||||
self.model_combo.setEnabled(enabled)
|
||||
|
||||
def refresh(self):
|
||||
"""Refresh the tab."""
|
||||
self._load_models()
|
||||
self.results_text.clear()
|
||||
46
src/gui/tabs/results_tab.py
Normal file
46
src/gui/tabs/results_tab.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
Results tab for the microscopy object detection application.
|
||||
"""
|
||||
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
|
||||
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.config_manager import ConfigManager
|
||||
|
||||
|
||||
class ResultsTab(QWidget):
|
||||
"""Results tab placeholder."""
|
||||
|
||||
def __init__(
|
||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
||||
):
|
||||
super().__init__(parent)
|
||||
self.db_manager = db_manager
|
||||
self.config_manager = config_manager
|
||||
|
||||
self._setup_ui()
|
||||
|
||||
def _setup_ui(self):
|
||||
"""Setup user interface."""
|
||||
layout = QVBoxLayout()
|
||||
|
||||
group = QGroupBox("Results")
|
||||
group_layout = QVBoxLayout()
|
||||
label = QLabel(
|
||||
"Results viewer will be implemented here.\n\n"
|
||||
"Features:\n"
|
||||
"- Detection history browser\n"
|
||||
"- Advanced filtering\n"
|
||||
"- Statistics dashboard\n"
|
||||
"- Export functionality"
|
||||
)
|
||||
group_layout.addWidget(label)
|
||||
group.setLayout(group_layout)
|
||||
|
||||
layout.addWidget(group)
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def refresh(self):
|
||||
"""Refresh the tab."""
|
||||
pass
|
||||
52
src/gui/tabs/training_tab.py
Normal file
52
src/gui/tabs/training_tab.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
Training tab for the microscopy object detection application.
|
||||
Handles model training with YOLO.
|
||||
"""
|
||||
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
|
||||
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.config_manager import ConfigManager
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TrainingTab(QWidget):
|
||||
"""Training tab for model training."""
|
||||
|
||||
def __init__(
|
||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
||||
):
|
||||
super().__init__(parent)
|
||||
self.db_manager = db_manager
|
||||
self.config_manager = config_manager
|
||||
|
||||
self._setup_ui()
|
||||
|
||||
def _setup_ui(self):
|
||||
"""Setup user interface."""
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# Placeholder
|
||||
group = QGroupBox("Training")
|
||||
group_layout = QVBoxLayout()
|
||||
label = QLabel(
|
||||
"Training functionality will be implemented here.\n\n"
|
||||
"Features:\n"
|
||||
"- Dataset selection\n"
|
||||
"- Training parameter configuration\n"
|
||||
"- Real-time training progress\n"
|
||||
"- Loss and metric visualization"
|
||||
)
|
||||
group_layout.addWidget(label)
|
||||
group.setLayout(group_layout)
|
||||
|
||||
layout.addWidget(group)
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def refresh(self):
|
||||
"""Refresh the tab."""
|
||||
pass
|
||||
46
src/gui/tabs/validation_tab.py
Normal file
46
src/gui/tabs/validation_tab.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
Validation tab for the microscopy object detection application.
|
||||
"""
|
||||
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
|
||||
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.config_manager import ConfigManager
|
||||
|
||||
|
||||
class ValidationTab(QWidget):
|
||||
"""Validation tab placeholder."""
|
||||
|
||||
def __init__(
|
||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
||||
):
|
||||
super().__init__(parent)
|
||||
self.db_manager = db_manager
|
||||
self.config_manager = config_manager
|
||||
|
||||
self._setup_ui()
|
||||
|
||||
def _setup_ui(self):
|
||||
"""Setup user interface."""
|
||||
layout = QVBoxLayout()
|
||||
|
||||
group = QGroupBox("Validation")
|
||||
group_layout = QVBoxLayout()
|
||||
label = QLabel(
|
||||
"Validation functionality will be implemented here.\n\n"
|
||||
"Features:\n"
|
||||
"- Model validation\n"
|
||||
"- Metrics visualization\n"
|
||||
"- Confusion matrix\n"
|
||||
"- Precision-Recall curves"
|
||||
)
|
||||
group_layout.addWidget(label)
|
||||
group.setLayout(group_layout)
|
||||
|
||||
layout.addWidget(group)
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def refresh(self):
|
||||
"""Refresh the tab."""
|
||||
pass
|
||||
0
src/gui/widgets/__init__.py
Normal file
0
src/gui/widgets/__init__.py
Normal file
0
src/model/__init__.py
Normal file
0
src/model/__init__.py
Normal file
323
src/model/inference.py
Normal file
323
src/model/inference.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""
|
||||
Inference engine for the microscopy object detection application.
|
||||
Handles detection inference and result storage.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Optional, Callable
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from src.model.yolo_wrapper import YOLOWrapper
|
||||
from src.database.db_manager import DatabaseManager
|
||||
from src.utils.logger import get_logger
|
||||
from src.utils.file_utils import get_relative_path
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class InferenceEngine:
|
||||
"""Handles detection inference and result storage."""
|
||||
|
||||
def __init__(self, model_path: str, db_manager: DatabaseManager, model_id: int):
|
||||
"""
|
||||
Initialize inference engine.
|
||||
|
||||
Args:
|
||||
model_path: Path to YOLO model weights
|
||||
db_manager: Database manager instance
|
||||
model_id: ID of the model in database
|
||||
"""
|
||||
self.yolo = YOLOWrapper(model_path)
|
||||
self.yolo.load_model()
|
||||
self.db_manager = db_manager
|
||||
self.model_id = model_id
|
||||
logger.info(f"InferenceEngine initialized with model_id {model_id}")
|
||||
|
||||
def detect_single(
|
||||
self,
|
||||
image_path: str,
|
||||
relative_path: str,
|
||||
conf: float = 0.25,
|
||||
save_to_db: bool = True,
|
||||
) -> Dict:
|
||||
"""
|
||||
Detect objects in a single image.
|
||||
|
||||
Args:
|
||||
image_path: Absolute path to image file
|
||||
relative_path: Relative path from repository root
|
||||
conf: Confidence threshold
|
||||
save_to_db: Whether to save results to database
|
||||
|
||||
Returns:
|
||||
Dictionary with detection results
|
||||
"""
|
||||
try:
|
||||
# Get image dimensions
|
||||
img = Image.open(image_path)
|
||||
width, height = img.size
|
||||
img.close()
|
||||
|
||||
# Perform detection
|
||||
detections = self.yolo.predict(image_path, conf=conf)
|
||||
|
||||
# Add/get image in database
|
||||
image_id = self.db_manager.get_or_create_image(
|
||||
relative_path=relative_path,
|
||||
filename=Path(image_path).name,
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
|
||||
# Save detections to database
|
||||
if save_to_db and detections:
|
||||
detection_records = []
|
||||
for det in detections:
|
||||
# Use normalized bbox from detection
|
||||
bbox_normalized = det[
|
||||
"bbox_normalized"
|
||||
] # [x_min, y_min, x_max, y_max]
|
||||
|
||||
record = {
|
||||
"image_id": image_id,
|
||||
"model_id": self.model_id,
|
||||
"class_name": det["class_name"],
|
||||
"bbox": tuple(bbox_normalized),
|
||||
"confidence": det["confidence"],
|
||||
"metadata": {"class_id": det["class_id"]},
|
||||
}
|
||||
detection_records.append(record)
|
||||
|
||||
self.db_manager.add_detections_batch(detection_records)
|
||||
logger.info(f"Saved {len(detection_records)} detections to database")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"image_path": image_path,
|
||||
"image_id": image_id,
|
||||
"detections": detections,
|
||||
"count": len(detections),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error detecting objects in {image_path}: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"image_path": image_path,
|
||||
"error": str(e),
|
||||
"detections": [],
|
||||
"count": 0,
|
||||
}
|
||||
|
||||
def detect_batch(
|
||||
self,
|
||||
image_paths: List[str],
|
||||
repository_root: str,
|
||||
conf: float = 0.25,
|
||||
progress_callback: Optional[Callable[[int, int, str], None]] = None,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Detect objects in multiple images.
|
||||
|
||||
Args:
|
||||
image_paths: List of absolute image paths
|
||||
repository_root: Root directory for relative paths
|
||||
conf: Confidence threshold
|
||||
progress_callback: Optional callback(current, total, message)
|
||||
|
||||
Returns:
|
||||
List of detection result dictionaries
|
||||
"""
|
||||
results = []
|
||||
total = len(image_paths)
|
||||
|
||||
logger.info(f"Starting batch detection on {total} images")
|
||||
|
||||
for i, image_path in enumerate(image_paths, 1):
|
||||
# Calculate relative path
|
||||
rel_path = get_relative_path(image_path, repository_root)
|
||||
|
||||
# Perform detection
|
||||
result = self.detect_single(image_path, rel_path, conf)
|
||||
results.append(result)
|
||||
|
||||
# Update progress
|
||||
if progress_callback:
|
||||
progress_callback(i, total, f"Processed {rel_path}")
|
||||
|
||||
if i % 10 == 0:
|
||||
logger.info(f"Processed {i}/{total} images")
|
||||
|
||||
logger.info(f"Batch detection complete: {total} images processed")
|
||||
return results
|
||||
|
||||
def detect_with_visualization(
|
||||
self,
|
||||
image_path: str,
|
||||
conf: float = 0.25,
|
||||
bbox_thickness: int = 2,
|
||||
bbox_colors: Optional[Dict[str, str]] = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
Detect objects and return annotated image.
|
||||
|
||||
Args:
|
||||
image_path: Path to image
|
||||
conf: Confidence threshold
|
||||
bbox_thickness: Thickness of bounding boxes
|
||||
bbox_colors: Dictionary mapping class names to hex colors
|
||||
|
||||
Returns:
|
||||
Tuple of (detections, annotated_image_array)
|
||||
"""
|
||||
try:
|
||||
detections = self.yolo.predict(image_path, conf=conf)
|
||||
|
||||
# Load image
|
||||
img = cv2.imread(image_path)
|
||||
if img is None:
|
||||
raise ValueError(f"Failed to load image: {image_path}")
|
||||
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
height, width = img.shape[:2]
|
||||
|
||||
# Default colors if not provided
|
||||
if bbox_colors is None:
|
||||
bbox_colors = {}
|
||||
default_color = self._hex_to_bgr(bbox_colors.get("default", "#00FF00"))
|
||||
|
||||
# Draw bounding boxes
|
||||
for det in detections:
|
||||
# Get absolute coordinates
|
||||
bbox_abs = det["bbox_absolute"]
|
||||
x1, y1, x2, y2 = [int(v) for v in bbox_abs]
|
||||
|
||||
# Get color for this class
|
||||
class_name = det["class_name"]
|
||||
color_hex = bbox_colors.get(
|
||||
class_name, bbox_colors.get("default", "#00FF00")
|
||||
)
|
||||
color = self._hex_to_bgr(color_hex)
|
||||
|
||||
# Draw box
|
||||
cv2.rectangle(img, (x1, y1), (x2, y2), color, bbox_thickness)
|
||||
|
||||
# Prepare label
|
||||
label = f"{class_name} {det['confidence']:.2f}"
|
||||
|
||||
# Draw label background
|
||||
(label_w, label_h), baseline = cv2.getTextSize(
|
||||
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
|
||||
)
|
||||
cv2.rectangle(
|
||||
img,
|
||||
(x1, y1 - label_h - baseline - 5),
|
||||
(x1 + label_w, y1),
|
||||
color,
|
||||
-1,
|
||||
)
|
||||
|
||||
# Draw label text
|
||||
cv2.putText(
|
||||
img,
|
||||
label,
|
||||
(x1, y1 - baseline - 5),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5,
|
||||
(255, 255, 255),
|
||||
1,
|
||||
)
|
||||
|
||||
return detections, img
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating visualization: {e}")
|
||||
# Return empty detections and original image if possible
|
||||
try:
|
||||
img = cv2.imread(image_path)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return [], img
|
||||
except:
|
||||
return [], np.zeros((480, 640, 3), dtype=np.uint8)
|
||||
|
||||
def get_detection_summary(self, detections: List[Dict]) -> Dict[str, any]:
|
||||
"""
|
||||
Generate summary statistics for detections.
|
||||
|
||||
Args:
|
||||
detections: List of detection dictionaries
|
||||
|
||||
Returns:
|
||||
Dictionary with summary statistics
|
||||
"""
|
||||
if not detections:
|
||||
return {
|
||||
"total_count": 0,
|
||||
"class_counts": {},
|
||||
"avg_confidence": 0.0,
|
||||
"confidence_range": (0.0, 0.0),
|
||||
}
|
||||
|
||||
# Count by class
|
||||
class_counts = {}
|
||||
confidences = []
|
||||
|
||||
for det in detections:
|
||||
class_name = det["class_name"]
|
||||
class_counts[class_name] = class_counts.get(class_name, 0) + 1
|
||||
confidences.append(det["confidence"])
|
||||
|
||||
return {
|
||||
"total_count": len(detections),
|
||||
"class_counts": class_counts,
|
||||
"avg_confidence": sum(confidences) / len(confidences),
|
||||
"confidence_range": (min(confidences), max(confidences)),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _hex_to_bgr(hex_color: str) -> tuple:
|
||||
"""
|
||||
Convert hex color to BGR tuple.
|
||||
|
||||
Args:
|
||||
hex_color: Hex color string (e.g., '#FF0000')
|
||||
|
||||
Returns:
|
||||
BGR tuple (B, G, R)
|
||||
"""
|
||||
hex_color = hex_color.lstrip("#")
|
||||
if len(hex_color) != 6:
|
||||
return (0, 255, 0) # Default green
|
||||
|
||||
try:
|
||||
r = int(hex_color[0:2], 16)
|
||||
g = int(hex_color[2:4], 16)
|
||||
b = int(hex_color[4:6], 16)
|
||||
return (b, g, r) # OpenCV uses BGR
|
||||
except ValueError:
|
||||
return (0, 255, 0) # Default green
|
||||
|
||||
def change_model(self, model_path: str, model_id: int) -> bool:
|
||||
"""
|
||||
Change the current model.
|
||||
|
||||
Args:
|
||||
model_path: Path to new model weights
|
||||
model_id: ID of new model in database
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
self.yolo = YOLOWrapper(model_path)
|
||||
if self.yolo.load_model():
|
||||
self.model_id = model_id
|
||||
logger.info(f"Model changed to {model_path}")
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error changing model: {e}")
|
||||
return False
|
||||
364
src/model/yolo_wrapper.py
Normal file
364
src/model/yolo_wrapper.py
Normal file
@@ -0,0 +1,364 @@
|
||||
"""
|
||||
YOLO model wrapper for the microscopy object detection application.
|
||||
Provides a clean interface to YOLOv8 for training, validation, and inference.
|
||||
"""
|
||||
|
||||
from ultralytics import YOLO
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Callable, Any
|
||||
import torch
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class YOLOWrapper:
|
||||
"""Wrapper for YOLOv8 model operations."""
|
||||
|
||||
def __init__(self, model_path: str = "yolov8s.pt"):
|
||||
"""
|
||||
Initialize YOLO model.
|
||||
|
||||
Args:
|
||||
model_path: Path to model weights (.pt file)
|
||||
"""
|
||||
self.model_path = model_path
|
||||
self.model = None
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"YOLOWrapper initialized with device: {self.device}")
|
||||
|
||||
def load_model(self) -> bool:
|
||||
"""
|
||||
Load YOLO model from path.
|
||||
|
||||
Returns:
|
||||
True if loaded successfully, False otherwise
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Loading YOLO model from {self.model_path}")
|
||||
self.model = YOLO(self.model_path)
|
||||
self.model.to(self.device)
|
||||
logger.info("Model loaded successfully")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading model: {e}")
|
||||
return False
|
||||
|
||||
def train(
|
||||
self,
|
||||
data_yaml: str,
|
||||
epochs: int = 100,
|
||||
imgsz: int = 640,
|
||||
batch: int = 16,
|
||||
patience: int = 50,
|
||||
save_dir: str = "data/models",
|
||||
name: str = "custom_model",
|
||||
resume: bool = False,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Train the YOLO model.
|
||||
|
||||
Args:
|
||||
data_yaml: Path to data.yaml configuration file
|
||||
epochs: Number of training epochs
|
||||
imgsz: Input image size
|
||||
batch: Batch size
|
||||
patience: Early stopping patience
|
||||
save_dir: Directory to save trained model
|
||||
name: Name for the training run
|
||||
resume: Resume training from last checkpoint
|
||||
**kwargs: Additional training arguments
|
||||
|
||||
Returns:
|
||||
Dictionary with training results
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
|
||||
try:
|
||||
logger.info(f"Starting training: {name}")
|
||||
logger.info(
|
||||
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
|
||||
)
|
||||
|
||||
# Train the model
|
||||
results = self.model.train(
|
||||
data=data_yaml,
|
||||
epochs=epochs,
|
||||
imgsz=imgsz,
|
||||
batch=batch,
|
||||
patience=patience,
|
||||
project=save_dir,
|
||||
name=name,
|
||||
device=self.device,
|
||||
resume=resume,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
logger.info("Training completed successfully")
|
||||
return self._format_training_results(results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during training: {e}")
|
||||
raise
|
||||
|
||||
def validate(self, data_yaml: str, split: str = "val", **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate the model.
|
||||
|
||||
Args:
|
||||
data_yaml: Path to data.yaml configuration file
|
||||
split: Dataset split to validate on ('val' or 'test')
|
||||
**kwargs: Additional validation arguments
|
||||
|
||||
Returns:
|
||||
Dictionary with validation metrics
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
|
||||
try:
|
||||
logger.info(f"Starting validation on {split} split")
|
||||
results = self.model.val(
|
||||
data=data_yaml, split=split, device=self.device, **kwargs
|
||||
)
|
||||
|
||||
logger.info("Validation completed successfully")
|
||||
return self._format_validation_results(results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during validation: {e}")
|
||||
raise
|
||||
|
||||
def predict(
|
||||
self,
|
||||
source: str,
|
||||
conf: float = 0.25,
|
||||
iou: float = 0.45,
|
||||
save: bool = False,
|
||||
save_txt: bool = False,
|
||||
save_conf: bool = False,
|
||||
**kwargs,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Perform inference on image(s).
|
||||
|
||||
Args:
|
||||
source: Path to image or directory
|
||||
conf: Confidence threshold
|
||||
iou: IoU threshold for NMS
|
||||
save: Whether to save annotated images
|
||||
save_txt: Whether to save labels to .txt files
|
||||
save_conf: Whether to save confidence in labels
|
||||
**kwargs: Additional prediction arguments
|
||||
|
||||
Returns:
|
||||
List of detection dictionaries
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
|
||||
try:
|
||||
logger.info(f"Running inference on {source}")
|
||||
results = self.model.predict(
|
||||
source=source,
|
||||
conf=conf,
|
||||
iou=iou,
|
||||
save=save,
|
||||
save_txt=save_txt,
|
||||
save_conf=save_conf,
|
||||
device=self.device,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
detections = self._format_prediction_results(results)
|
||||
logger.info(f"Inference complete: {len(detections)} detections")
|
||||
return detections
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during inference: {e}")
|
||||
raise
|
||||
|
||||
def export(
|
||||
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
|
||||
) -> str:
|
||||
"""
|
||||
Export model to different format.
|
||||
|
||||
Args:
|
||||
format: Export format (onnx, torchscript, tflite, etc.)
|
||||
output_path: Path for exported model
|
||||
**kwargs: Additional export arguments
|
||||
|
||||
Returns:
|
||||
Path to exported model
|
||||
"""
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
|
||||
try:
|
||||
logger.info(f"Exporting model to {format} format")
|
||||
export_path = self.model.export(format=format, **kwargs)
|
||||
logger.info(f"Model exported to {export_path}")
|
||||
return str(export_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error exporting model: {e}")
|
||||
raise
|
||||
|
||||
def _format_training_results(self, results) -> Dict[str, Any]:
|
||||
"""Format training results into dictionary."""
|
||||
try:
|
||||
# Get the results dict
|
||||
results_dict = (
|
||||
results.results_dict if hasattr(results, "results_dict") else {}
|
||||
)
|
||||
|
||||
formatted = {
|
||||
"success": True,
|
||||
"final_epoch": getattr(results, "epoch", 0),
|
||||
"metrics": {
|
||||
"mAP50": float(results_dict.get("metrics/mAP50(B)", 0)),
|
||||
"mAP50-95": float(results_dict.get("metrics/mAP50-95(B)", 0)),
|
||||
"precision": float(results_dict.get("metrics/precision(B)", 0)),
|
||||
"recall": float(results_dict.get("metrics/recall(B)", 0)),
|
||||
},
|
||||
"best_model_path": str(Path(results.save_dir) / "weights" / "best.pt"),
|
||||
"last_model_path": str(Path(results.save_dir) / "weights" / "last.pt"),
|
||||
"save_dir": str(results.save_dir),
|
||||
}
|
||||
|
||||
return formatted
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting training results: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
def _format_validation_results(self, results) -> Dict[str, Any]:
|
||||
"""Format validation results into dictionary."""
|
||||
try:
|
||||
box_metrics = results.box
|
||||
|
||||
formatted = {
|
||||
"success": True,
|
||||
"mAP50": float(box_metrics.map50),
|
||||
"mAP50-95": float(box_metrics.map),
|
||||
"precision": float(box_metrics.mp),
|
||||
"recall": float(box_metrics.mr),
|
||||
"fitness": (
|
||||
float(results.fitness) if hasattr(results, "fitness") else 0.0
|
||||
),
|
||||
}
|
||||
|
||||
# Add per-class metrics if available
|
||||
if hasattr(box_metrics, "ap") and hasattr(results, "names"):
|
||||
class_metrics = {}
|
||||
for idx, name in results.names.items():
|
||||
if idx < len(box_metrics.ap):
|
||||
class_metrics[name] = {
|
||||
"ap": float(box_metrics.ap[idx]),
|
||||
"ap50": (
|
||||
float(box_metrics.ap50[idx])
|
||||
if hasattr(box_metrics, "ap50")
|
||||
else 0.0
|
||||
),
|
||||
}
|
||||
formatted["class_metrics"] = class_metrics
|
||||
|
||||
return formatted
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting validation results: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
def _format_prediction_results(self, results) -> List[Dict]:
|
||||
"""Format prediction results into list of dictionaries."""
|
||||
detections = []
|
||||
|
||||
try:
|
||||
for result in results:
|
||||
boxes = result.boxes
|
||||
image_path = str(result.path)
|
||||
orig_shape = result.orig_shape # (height, width)
|
||||
|
||||
for i in range(len(boxes)):
|
||||
# Get normalized coordinates
|
||||
xyxyn = boxes.xyxyn[i].cpu().numpy() # Normalized [x1, y1, x2, y2]
|
||||
|
||||
detection = {
|
||||
"image_path": image_path,
|
||||
"class_id": int(boxes.cls[i]),
|
||||
"class_name": result.names[int(boxes.cls[i])],
|
||||
"confidence": float(boxes.conf[i]),
|
||||
"bbox_normalized": [
|
||||
float(v) for v in xyxyn
|
||||
], # [x_min, y_min, x_max, y_max]
|
||||
"bbox_absolute": [
|
||||
float(v) for v in boxes.xyxy[i].cpu().numpy()
|
||||
], # Absolute pixels
|
||||
}
|
||||
detections.append(detection)
|
||||
|
||||
return detections
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting prediction results: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def convert_bbox_format(
|
||||
bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
|
||||
) -> List[float]:
|
||||
"""
|
||||
Convert bounding box between formats.
|
||||
|
||||
Formats:
|
||||
- xywh: [x_center, y_center, width, height]
|
||||
- xyxy: [x_min, y_min, x_max, y_max]
|
||||
|
||||
Args:
|
||||
bbox: Bounding box coordinates
|
||||
format_from: Source format
|
||||
format_to: Target format
|
||||
|
||||
Returns:
|
||||
Converted bounding box
|
||||
"""
|
||||
if format_from == "xywh" and format_to == "xyxy":
|
||||
x, y, w, h = bbox
|
||||
return [x - w / 2, y - h / 2, x + w / 2, y + h / 2]
|
||||
elif format_from == "xyxy" and format_to == "xywh":
|
||||
x1, y1, x2, y2 = bbox
|
||||
return [(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1]
|
||||
else:
|
||||
return bbox
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about the loaded model.
|
||||
|
||||
Returns:
|
||||
Dictionary with model information
|
||||
"""
|
||||
if self.model is None:
|
||||
return {"error": "Model not loaded"}
|
||||
|
||||
try:
|
||||
info = {
|
||||
"model_path": self.model_path,
|
||||
"device": self.device,
|
||||
"task": getattr(self.model, "task", "unknown"),
|
||||
}
|
||||
|
||||
# Try to get class names
|
||||
if hasattr(self.model, "names"):
|
||||
info["classes"] = self.model.names
|
||||
info["num_classes"] = len(self.model.names)
|
||||
|
||||
return info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting model info: {e}")
|
||||
return {"error": str(e)}
|
||||
0
src/utils/__init__.py
Normal file
0
src/utils/__init__.py
Normal file
218
src/utils/config_manager.py
Normal file
218
src/utils/config_manager.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""
|
||||
Configuration manager for the microscopy object detection application.
|
||||
Handles loading, saving, and accessing application configuration.
|
||||
"""
|
||||
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""Manages application configuration."""
|
||||
|
||||
def __init__(self, config_path: str = "config/app_config.yaml"):
|
||||
"""
|
||||
Initialize configuration manager.
|
||||
|
||||
Args:
|
||||
config_path: Path to configuration file
|
||||
"""
|
||||
self.config_path = Path(config_path)
|
||||
self.config: Dict[str, Any] = {}
|
||||
self._load_config()
|
||||
|
||||
def _load_config(self) -> None:
|
||||
"""Load configuration from YAML file."""
|
||||
try:
|
||||
if self.config_path.exists():
|
||||
with open(self.config_path, "r") as f:
|
||||
self.config = yaml.safe_load(f) or {}
|
||||
logger.info(f"Configuration loaded from {self.config_path}")
|
||||
else:
|
||||
logger.warning(f"Configuration file not found: {self.config_path}")
|
||||
self._create_default_config()
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading configuration: {e}")
|
||||
self._create_default_config()
|
||||
|
||||
def _create_default_config(self) -> None:
|
||||
"""Create default configuration."""
|
||||
self.config = {
|
||||
"database": {"path": "data/detections.db"},
|
||||
"image_repository": {
|
||||
"base_path": "",
|
||||
"allowed_extensions": [
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".png",
|
||||
".tif",
|
||||
".tiff",
|
||||
".bmp",
|
||||
],
|
||||
},
|
||||
"models": {
|
||||
"default_base_model": "yolov8s.pt",
|
||||
"models_directory": "data/models",
|
||||
},
|
||||
"training": {
|
||||
"default_epochs": 100,
|
||||
"default_batch_size": 16,
|
||||
"default_imgsz": 640,
|
||||
"default_patience": 50,
|
||||
"default_lr0": 0.01,
|
||||
},
|
||||
"detection": {
|
||||
"default_confidence": 0.25,
|
||||
"default_iou": 0.45,
|
||||
"max_batch_size": 100,
|
||||
},
|
||||
"visualization": {
|
||||
"bbox_colors": {
|
||||
"organelle": "#FF6B6B",
|
||||
"membrane_branch": "#4ECDC4",
|
||||
"default": "#00FF00",
|
||||
},
|
||||
"bbox_thickness": 2,
|
||||
"font_size": 12,
|
||||
},
|
||||
"export": {"formats": ["csv", "json", "excel"], "default_format": "csv"},
|
||||
"logging": {
|
||||
"level": "INFO",
|
||||
"file": "logs/app.log",
|
||||
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
},
|
||||
}
|
||||
self.save_config()
|
||||
|
||||
def save_config(self) -> bool:
|
||||
"""
|
||||
Save current configuration to file.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Create directory if it doesn't exist
|
||||
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(self.config_path, "w") as f:
|
||||
yaml.dump(self.config, f, default_flow_style=False, sort_keys=False)
|
||||
|
||||
logger.info(f"Configuration saved to {self.config_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving configuration: {e}")
|
||||
return False
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Get configuration value by key.
|
||||
|
||||
Args:
|
||||
key: Configuration key (can use dot notation, e.g., 'database.path')
|
||||
default: Default value if key not found
|
||||
|
||||
Returns:
|
||||
Configuration value or default
|
||||
"""
|
||||
keys = key.split(".")
|
||||
value = self.config
|
||||
|
||||
for k in keys:
|
||||
if isinstance(value, dict) and k in value:
|
||||
value = value[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
return value
|
||||
|
||||
def set(self, key: str, value: Any) -> None:
|
||||
"""
|
||||
Set configuration value by key.
|
||||
|
||||
Args:
|
||||
key: Configuration key (can use dot notation)
|
||||
value: Value to set
|
||||
"""
|
||||
keys = key.split(".")
|
||||
config = self.config
|
||||
|
||||
# Navigate to the nested dictionary
|
||||
for k in keys[:-1]:
|
||||
if k not in config:
|
||||
config[k] = {}
|
||||
config = config[k]
|
||||
|
||||
# Set the value
|
||||
config[keys[-1]] = value
|
||||
logger.debug(f"Configuration updated: {key} = {value}")
|
||||
|
||||
def get_section(self, section: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get entire configuration section.
|
||||
|
||||
Args:
|
||||
section: Section name (e.g., 'database', 'training')
|
||||
|
||||
Returns:
|
||||
Dictionary with section configuration
|
||||
"""
|
||||
return self.config.get(section, {})
|
||||
|
||||
def update_section(self, section: str, values: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Update entire configuration section.
|
||||
|
||||
Args:
|
||||
section: Section name
|
||||
values: Dictionary with new values
|
||||
"""
|
||||
if section not in self.config:
|
||||
self.config[section] = {}
|
||||
|
||||
self.config[section].update(values)
|
||||
logger.debug(f"Configuration section updated: {section}")
|
||||
|
||||
def reload(self) -> None:
|
||||
"""Reload configuration from file."""
|
||||
self._load_config()
|
||||
|
||||
def get_database_path(self) -> str:
|
||||
"""Get database path."""
|
||||
return self.get("database.path", "data/detections.db")
|
||||
|
||||
def get_image_repository_path(self) -> str:
|
||||
"""Get image repository base path."""
|
||||
return self.get("image_repository.base_path", "")
|
||||
|
||||
def set_image_repository_path(self, path: str) -> None:
|
||||
"""Set image repository base path."""
|
||||
self.set("image_repository.base_path", path)
|
||||
self.save_config()
|
||||
|
||||
def get_models_directory(self) -> str:
|
||||
"""Get models directory path."""
|
||||
return self.get("models.models_directory", "data/models")
|
||||
|
||||
def get_default_training_params(self) -> Dict[str, Any]:
|
||||
"""Get default training parameters."""
|
||||
return self.get_section("training")
|
||||
|
||||
def get_default_detection_params(self) -> Dict[str, Any]:
|
||||
"""Get default detection parameters."""
|
||||
return self.get_section("detection")
|
||||
|
||||
def get_bbox_colors(self) -> Dict[str, str]:
|
||||
"""Get bounding box colors for different classes."""
|
||||
return self.get("visualization.bbox_colors", {})
|
||||
|
||||
def get_allowed_extensions(self) -> list:
|
||||
"""Get list of allowed image file extensions."""
|
||||
return self.get(
|
||||
"image_repository.allowed_extensions", [".jpg", ".jpeg", ".png"]
|
||||
)
|
||||
235
src/utils/file_utils.py
Normal file
235
src/utils/file_utils.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
File utility functions for the microscopy object detection application.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from src.utils.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_image_files(
|
||||
directory: str,
|
||||
allowed_extensions: Optional[List[str]] = None,
|
||||
recursive: bool = False,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get all image files in a directory.
|
||||
|
||||
Args:
|
||||
directory: Directory path to search
|
||||
allowed_extensions: List of allowed file extensions (e.g., ['.jpg', '.png'])
|
||||
recursive: Whether to search recursively
|
||||
|
||||
Returns:
|
||||
List of absolute paths to image files
|
||||
"""
|
||||
if allowed_extensions is None:
|
||||
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
|
||||
|
||||
# Normalize extensions to lowercase
|
||||
allowed_extensions = [ext.lower() for ext in allowed_extensions]
|
||||
|
||||
image_files = []
|
||||
directory_path = Path(directory)
|
||||
|
||||
if not directory_path.exists():
|
||||
logger.error(f"Directory does not exist: {directory}")
|
||||
return image_files
|
||||
|
||||
try:
|
||||
if recursive:
|
||||
# Recursive search
|
||||
for ext in allowed_extensions:
|
||||
image_files.extend(directory_path.rglob(f"*{ext}"))
|
||||
# Also search uppercase extensions
|
||||
image_files.extend(directory_path.rglob(f"*{ext.upper()}"))
|
||||
else:
|
||||
# Top-level search only
|
||||
for ext in allowed_extensions:
|
||||
image_files.extend(directory_path.glob(f"*{ext}"))
|
||||
# Also search uppercase extensions
|
||||
image_files.extend(directory_path.glob(f"*{ext.upper()}"))
|
||||
|
||||
# Convert to absolute paths and sort
|
||||
image_files = sorted([str(f.absolute()) for f in image_files])
|
||||
logger.info(f"Found {len(image_files)} image files in {directory}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching for images: {e}")
|
||||
|
||||
return image_files
|
||||
|
||||
|
||||
def ensure_directory(directory: str) -> bool:
|
||||
"""
|
||||
Ensure a directory exists, create if it doesn't.
|
||||
|
||||
Args:
|
||||
directory: Directory path
|
||||
|
||||
Returns:
|
||||
True if directory exists or was created successfully
|
||||
"""
|
||||
try:
|
||||
Path(directory).mkdir(parents=True, exist_ok=True)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating directory {directory}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_relative_path(file_path: str, base_path: str) -> str:
|
||||
"""
|
||||
Get relative path from base path.
|
||||
|
||||
Args:
|
||||
file_path: Absolute file path
|
||||
base_path: Base directory path
|
||||
|
||||
Returns:
|
||||
Relative path string
|
||||
"""
|
||||
try:
|
||||
return str(Path(file_path).relative_to(base_path))
|
||||
except ValueError:
|
||||
# If file_path is not relative to base_path, return the filename
|
||||
return Path(file_path).name
|
||||
|
||||
|
||||
def validate_file_path(file_path: str, must_exist: bool = True) -> bool:
|
||||
"""
|
||||
Validate a file path.
|
||||
|
||||
Args:
|
||||
file_path: Path to validate
|
||||
must_exist: Whether the file must exist
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
path = Path(file_path)
|
||||
|
||||
if must_exist and not path.exists():
|
||||
logger.error(f"File does not exist: {file_path}")
|
||||
return False
|
||||
|
||||
if must_exist and not path.is_file():
|
||||
logger.error(f"Path is not a file: {file_path}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_file_size(file_path: str) -> int:
|
||||
"""
|
||||
Get file size in bytes.
|
||||
|
||||
Args:
|
||||
file_path: Path to file
|
||||
|
||||
Returns:
|
||||
File size in bytes, or 0 if error
|
||||
"""
|
||||
try:
|
||||
return Path(file_path).stat().st_size
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting file size for {file_path}: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def format_file_size(size_bytes: int) -> str:
|
||||
"""
|
||||
Format file size in human-readable format.
|
||||
|
||||
Args:
|
||||
size_bytes: Size in bytes
|
||||
|
||||
Returns:
|
||||
Formatted string (e.g., "1.5 MB")
|
||||
"""
|
||||
for unit in ["B", "KB", "MB", "GB"]:
|
||||
if size_bytes < 1024.0:
|
||||
return f"{size_bytes:.1f} {unit}"
|
||||
size_bytes /= 1024.0
|
||||
return f"{size_bytes:.1f} TB"
|
||||
|
||||
|
||||
def create_unique_filename(directory: str, base_name: str, extension: str) -> str:
|
||||
"""
|
||||
Create a unique filename by adding a number suffix if file exists.
|
||||
|
||||
Args:
|
||||
directory: Directory path
|
||||
base_name: Base filename without extension
|
||||
extension: File extension (with or without dot)
|
||||
|
||||
Returns:
|
||||
Unique filename
|
||||
"""
|
||||
if not extension.startswith("."):
|
||||
extension = "." + extension
|
||||
|
||||
directory_path = Path(directory)
|
||||
filename = f"{base_name}{extension}"
|
||||
file_path = directory_path / filename
|
||||
|
||||
if not file_path.exists():
|
||||
return filename
|
||||
|
||||
# Add number suffix
|
||||
counter = 1
|
||||
while True:
|
||||
filename = f"{base_name}_{counter}{extension}"
|
||||
file_path = directory_path / filename
|
||||
if not file_path.exists():
|
||||
return filename
|
||||
counter += 1
|
||||
|
||||
|
||||
def is_image_file(
|
||||
file_path: str, allowed_extensions: Optional[List[str]] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a file is an image based on extension.
|
||||
|
||||
Args:
|
||||
file_path: Path to file
|
||||
allowed_extensions: List of allowed extensions
|
||||
|
||||
Returns:
|
||||
True if file is an image
|
||||
"""
|
||||
if allowed_extensions is None:
|
||||
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
|
||||
|
||||
extension = Path(file_path).suffix.lower()
|
||||
return extension in [ext.lower() for ext in allowed_extensions]
|
||||
|
||||
|
||||
def safe_filename(filename: str) -> str:
|
||||
"""
|
||||
Convert a string to a safe filename by removing/replacing invalid characters.
|
||||
|
||||
Args:
|
||||
filename: Original filename
|
||||
|
||||
Returns:
|
||||
Safe filename
|
||||
"""
|
||||
# Replace invalid characters
|
||||
invalid_chars = '<>:"/\\|?*'
|
||||
for char in invalid_chars:
|
||||
filename = filename.replace(char, "_")
|
||||
|
||||
# Remove leading/trailing spaces and dots
|
||||
filename = filename.strip(". ")
|
||||
|
||||
# Ensure filename is not empty
|
||||
if not filename:
|
||||
filename = "unnamed"
|
||||
|
||||
return filename
|
||||
75
src/utils/logger.py
Normal file
75
src/utils/logger.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
Logging configuration for the microscopy object detection application.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def setup_logging(
|
||||
log_file: str = "logs/app.log",
|
||||
level: str = "INFO",
|
||||
log_format: Optional[str] = None,
|
||||
) -> logging.Logger:
|
||||
"""
|
||||
Setup application logging.
|
||||
|
||||
Args:
|
||||
log_file: Path to log file
|
||||
level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||
log_format: Custom log format string
|
||||
|
||||
Returns:
|
||||
Configured logger instance
|
||||
"""
|
||||
# Create logs directory if it doesn't exist
|
||||
log_path = Path(log_file)
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Default format if none provided
|
||||
if log_format is None:
|
||||
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
|
||||
# Convert level string to logging constant
|
||||
numeric_level = getattr(logging, level.upper(), logging.INFO)
|
||||
|
||||
# Configure root logger
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(numeric_level)
|
||||
|
||||
# Remove existing handlers
|
||||
root_logger.handlers.clear()
|
||||
|
||||
# Console handler
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setLevel(numeric_level)
|
||||
console_formatter = logging.Formatter(log_format)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# File handler
|
||||
file_handler = logging.FileHandler(log_file)
|
||||
file_handler.setLevel(numeric_level)
|
||||
file_formatter = logging.Formatter(log_format)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
root_logger.addHandler(file_handler)
|
||||
|
||||
# Log initial message
|
||||
root_logger.info("Logging initialized")
|
||||
|
||||
return root_logger
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""
|
||||
Get a logger instance for a specific module.
|
||||
|
||||
Args:
|
||||
name: Logger name (typically __name__)
|
||||
|
||||
Returns:
|
||||
Logger instance
|
||||
"""
|
||||
return logging.getLogger(name)
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
Reference in New Issue
Block a user