231 lines
7.2 KiB
Python
231 lines
7.2 KiB
Python
"""
Configuration manager for the microscopy object detection application.

Handles loading, saving, and accessing application configuration.
"""
|
|
|
|
import yaml
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Optional
|
|
from src.utils.logger import get_logger
|
|
from src.utils.image import Image
|
|
|
|
|
|
# Module-level logger for this file, obtained from the project's get_logger helper.
logger = get_logger(__name__)
|
|
|
|
|
|
class ConfigManager:
    """Manages application configuration.

    The configuration is held in memory as a nested dictionary
    (``self.config``) and persisted as YAML at ``self.config_path``.
    Nested values are addressed with dot-separated keys such as
    ``"database.path"``.
    """

    def __init__(self, config_path: str = "config/app_config.yaml"):
        """
        Initialize configuration manager and load the configuration.

        Args:
            config_path: Path to configuration file. If the file does not
                exist, a default configuration is created and saved there.
        """
        self.config_path = Path(config_path)
        self.config: Dict[str, Any] = {}
        self._load_config()

    def _load_config(self) -> None:
        """
        Load configuration from the YAML file.

        Outcomes:
            - File missing: defaults are created and written to ``config_path``.
            - File empty (or a falsy YAML document): the configuration becomes
              an empty dictionary.
            - File unreadable, not valid YAML, or its top level is not a
              mapping: defaults are used in memory only. The existing file is
              deliberately NOT overwritten, so a typo in a hand-edited config
              does not wipe out the user's settings.
        """
        try:
            if not self.config_path.exists():
                logger.warning(f"Configuration file not found: {self.config_path}")
                self._create_default_config()
                return

            with open(self.config_path, "r", encoding="utf-8") as f:
                loaded = yaml.safe_load(f) or {}
        except Exception as e:
            # Broad on purpose: YAML syntax errors, decode errors and OS errors
            # all mean "this file cannot be used". Fall back to defaults
            # without saving so the user's file is preserved for fixing.
            logger.error(f"Error loading configuration: {e}")
            self._create_default_config(save=False)
            return

        if not isinstance(loaded, dict):
            # e.g. a YAML list or scalar at top level; every accessor below
            # assumes a mapping.
            logger.error(
                f"Error loading configuration: top level of {self.config_path} "
                f"must be a mapping, got {type(loaded).__name__}"
            )
            self._create_default_config(save=False)
            return

        self.config = loaded
        logger.info(f"Configuration loaded from {self.config_path}")

    def _create_default_config(self, save: bool = True) -> None:
        """
        Replace the in-memory configuration with the built-in defaults.

        Args:
            save: When True (the default), also write the defaults to
                ``config_path`` via ``save_config``.
        """
        self.config = {
            "database": {"path": "data/detections.db"},
            "image_repository": {
                "base_path": "",
                # Copy into a plain list: avoids sharing (and later mutating)
                # the class attribute, and keeps the value YAML-safe.
                "allowed_extensions": list(Image.SUPPORTED_EXTENSIONS),
            },
            "models": {
                "default_base_model": "yolov8s-seg.pt",
                "models_directory": "data/models",
                "base_model_choices": [
                    "yolov8s-seg.pt",
                    "yolov11s-seg.pt",
                ],
            },
            "training": {
                "default_epochs": 100,
                "default_batch_size": 16,
                "default_imgsz": 640,
                "default_patience": 50,
                "default_lr0": 0.01,
                "two_stage": {
                    "enabled": False,
                    "stage1": {
                        "epochs": 20,
                        "lr0": 0.0005,
                        "patience": 10,
                        "freeze": 10,
                    },
                    "stage2": {
                        "epochs": 150,
                        "lr0": 0.0003,
                        "patience": 30,
                    },
                },
            },
            "detection": {
                "default_confidence": 0.25,
                "default_iou": 0.45,
                "max_batch_size": 100,
            },
            "visualization": {
                "bbox_colors": {
                    "organelle": "#FF6B6B",
                    "membrane_branch": "#4ECDC4",
                    "default": "#00FF00",
                },
                "bbox_thickness": 2,
                "font_size": 12,
            },
            "export": {"formats": ["csv", "json", "excel"], "default_format": "csv"},
            "logging": {
                "level": "INFO",
                "file": "logs/app.log",
                "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            },
        }

        if save:
            self.save_config()

    def save_config(self) -> bool:
        """
        Save current configuration to file.

        The YAML is first written to a sibling ``<name>.tmp`` file and then
        moved over ``config_path``, so an interruption mid-write cannot leave
        a truncated config behind. ``yaml.safe_dump`` is used so the output
        can always be read back by ``yaml.safe_load`` in ``_load_config``
        (plain ``yaml.dump`` may emit Python-specific tags, e.g. for tuples,
        that ``safe_load`` rejects).

        Returns:
            True if successful, False otherwise
        """
        tmp_path: Optional[Path] = None
        try:
            # Create directory if it doesn't exist
            self.config_path.parent.mkdir(parents=True, exist_ok=True)

            tmp_path = self.config_path.with_name(self.config_path.name + ".tmp")
            with open(tmp_path, "w", encoding="utf-8") as f:
                yaml.safe_dump(
                    self.config, f, default_flow_style=False, sort_keys=False
                )
            tmp_path.replace(self.config_path)

            logger.info(f"Configuration saved to {self.config_path}")
            return True
        except Exception as e:
            logger.error(f"Error saving configuration: {e}")
            if tmp_path is not None:
                # Best-effort removal of a partially written temp file; a
                # failure here must not mask the error already reported.
                try:
                    tmp_path.unlink()
                except OSError:
                    pass
            return False

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get configuration value by key.

        Args:
            key: Configuration key (can use dot notation, e.g., 'database.path')
            default: Default value if key not found

        Returns:
            Configuration value or default. Returns ``default`` as soon as a
            path component is missing or an intermediate value is not a dict.
        """
        keys = key.split(".")
        value = self.config

        for k in keys:
            if isinstance(value, dict) and k in value:
                value = value[k]
            else:
                return default

        return value

    def set(self, key: str, value: Any) -> None:
        """
        Set configuration value by key (in memory only; call ``save_config``
        to persist).

        Missing intermediate sections, and sections that are empty in the
        YAML file (loaded as ``None``), are created as dictionaries.

        Args:
            key: Configuration key (can use dot notation)
            value: Value to set

        Raises:
            TypeError: If an intermediate key already holds a non-mapping
                value (e.g. setting ``"a.b"`` when ``"a"`` is a string).
        """
        keys = key.split(".")
        node = self.config

        # Navigate to the nested dictionary, creating sections as needed
        for k in keys[:-1]:
            if k not in node or node[k] is None:
                node[k] = {}
            elif not isinstance(node[k], dict):
                raise TypeError(
                    f"Cannot set '{key}': '{k}' holds a "
                    f"{type(node[k]).__name__}, not a mapping"
                )
            node = node[k]

        # Set the value
        node[keys[-1]] = value
        logger.debug(f"Configuration updated: {key} = {value}")

    def get_section(self, section: str) -> Dict[str, Any]:
        """
        Get entire configuration section.

        Args:
            section: Section name (e.g., 'database', 'training')

        Returns:
            The section's dictionary (the live object, not a copy), or a new
            empty dict when the section is missing or is not a mapping (an
            empty YAML section such as ``database:`` loads as ``None``).
        """
        value = self.config.get(section)
        return value if isinstance(value, dict) else {}

    def update_section(self, section: str, values: Dict[str, Any]) -> None:
        """
        Merge values into a configuration section (shallow ``dict.update``).

        A missing section, or one that is empty in the YAML file (``None``),
        is created as a new dictionary first.

        Args:
            section: Section name
            values: Dictionary with new values

        Raises:
            TypeError: If the section exists but holds a non-mapping value.
        """
        current = self.config.get(section)
        if current is None:
            current = {}
            self.config[section] = current
        elif not isinstance(current, dict):
            raise TypeError(
                f"Cannot update section '{section}': it holds a "
                f"{type(current).__name__}, not a mapping"
            )

        current.update(values)
        logger.debug(f"Configuration section updated: {section}")

    def reload(self) -> None:
        """Reload configuration from file."""
        self._load_config()

    def get_database_path(self) -> str:
        """Get database path."""
        return self.get("database.path", "data/detections.db")

    def get_image_repository_path(self) -> str:
        """Get image repository base path."""
        return self.get("image_repository.base_path", "")

    def set_image_repository_path(self, path: str) -> None:
        """Set image repository base path and save the configuration to file."""
        self.set("image_repository.base_path", path)
        self.save_config()

    def get_models_directory(self) -> str:
        """Get models directory path."""
        return self.get("models.models_directory", "data/models")

    def get_default_training_params(self) -> Dict[str, Any]:
        """Get default training parameters (the 'training' section)."""
        return self.get_section("training")

    def get_default_detection_params(self) -> Dict[str, Any]:
        """Get default detection parameters (the 'detection' section)."""
        return self.get_section("detection")

    def get_bbox_colors(self) -> Dict[str, str]:
        """Get bounding box colors for different classes."""
        return self.get("visualization.bbox_colors", {})

    def get_allowed_extensions(self) -> list:
        """
        Get list of allowed image file extensions.

        Falls back to a list copy of ``Image.SUPPORTED_EXTENSIONS`` when the
        key is not configured.
        """
        return self.get(
            "image_repository.allowed_extensions", list(Image.SUPPORTED_EXTENSIONS)
        )
|