Files
object-segmentation/src/utils/config_manager.py

231 lines
7.2 KiB
Python

"""
Configuration manager for the microscopy object detection application.
Handles loading, saving, and accessing application configuration.
"""
import yaml
from pathlib import Path
from typing import Any, Dict, Optional
from src.utils.logger import get_logger
from src.utils.image import Image
# Module-level logger, obtained from the project's get_logger helper and
# named after this module.
logger = get_logger(__name__)
class ConfigManager:
    """Manages application configuration backed by a YAML file.

    The configuration is held in ``self.config`` as nested dictionaries and
    individual values are addressed with dot-separated keys such as
    ``"database.path"``.

    Loading rules:
      * missing file -> built-in defaults are generated and written to disk;
      * unreadable / unparsable file, or a file whose top level is not a
        mapping -> built-in defaults are used in memory only and the file on
        disk is left as-is, so a single typo does not wipe the user's
        settings. The file is only rewritten by a later save_config() call.
    """

    def __init__(self, config_path: str = "config/app_config.yaml"):
        """
        Initialize configuration manager and load the configuration.

        Args:
            config_path: Path to configuration file
        """
        self.config_path = Path(config_path)
        self.config: Dict[str, Any] = {}
        self._load_config()

    @staticmethod
    def _default_config() -> Dict[str, Any]:
        """Return a freshly built copy of the default configuration.

        A new dict is built on every call, so mutating the result never
        affects later calls.
        """
        return {
            "database": {"path": "data/detections.db"},
            "image_repository": {
                "base_path": "",
                # Copy into a plain list so the config never aliases (and can
                # never mutate) the class-level constant, and so it always
                # serializes as an ordinary YAML list.
                "allowed_extensions": list(Image.SUPPORTED_EXTENSIONS),
            },
            "models": {
                "default_base_model": "yolov8s-seg.pt",
                "models_directory": "data/models",
                "base_model_choices": [
                    "yolov8s-seg.pt",
                    "yolov11s-seg.pt",
                ],
            },
            "training": {
                "default_epochs": 100,
                "default_batch_size": 16,
                "default_imgsz": 640,
                "default_patience": 50,
                "default_lr0": 0.01,
                "two_stage": {
                    "enabled": False,
                    "stage1": {
                        "epochs": 20,
                        "lr0": 0.0005,
                        "patience": 10,
                        "freeze": 10,
                    },
                    "stage2": {
                        "epochs": 150,
                        "lr0": 0.0003,
                        "patience": 30,
                    },
                },
            },
            "detection": {
                "default_confidence": 0.25,
                "default_iou": 0.45,
                "max_batch_size": 100,
            },
            "visualization": {
                "bbox_colors": {
                    "organelle": "#FF6B6B",
                    "membrane_branch": "#4ECDC4",
                    "default": "#00FF00",
                },
                "bbox_thickness": 2,
                "font_size": 12,
            },
            "export": {"formats": ["csv", "json", "excel"], "default_format": "csv"},
            "logging": {
                "level": "INFO",
                "file": "logs/app.log",
                "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            },
        }

    def _load_config(self) -> None:
        """Load configuration from the YAML file at ``self.config_path``.

        A missing file is replaced by freshly written defaults. A file that
        exists but cannot be read/parsed, or whose top level is not a
        mapping, yields in-memory defaults and is NOT overwritten.
        """
        try:
            if not self.config_path.exists():
                logger.warning(f"Configuration file not found: {self.config_path}")
                self._create_default_config()
                return
            with open(self.config_path, "r", encoding="utf-8") as f:
                loaded = yaml.safe_load(f)
        except Exception as e:
            logger.error(f"Error loading configuration: {e}")
            self._use_defaults_without_saving()
            return

        if loaded is None:
            # safe_load returns None for an empty document.
            loaded = {}
        if not isinstance(loaded, dict):
            # Every accessor below assumes a mapping at the top level.
            logger.error(
                f"Configuration file {self.config_path} must contain a mapping "
                f"at the top level, got {type(loaded).__name__}"
            )
            self._use_defaults_without_saving()
            return

        self.config = loaded
        logger.info(f"Configuration loaded from {self.config_path}")

    def _use_defaults_without_saving(self) -> None:
        """Switch to in-memory defaults, leaving the file on disk untouched."""
        # Persisting defaults here would silently replace a file the user may
        # only have mistyped; keep it so it can be fixed by hand.
        self.config = self._default_config()
        logger.warning(
            f"Using default configuration; {self.config_path} was not modified"
        )

    def _create_default_config(self) -> None:
        """Reset the configuration to defaults and write it to disk."""
        self.config = self._default_config()
        self.save_config()

    def save_config(self) -> bool:
        """
        Save current configuration to file.

        The configuration is serialized with ``yaml.safe_dump`` so the file
        can always be read back by the ``yaml.safe_load`` call used when
        loading. Serialization happens before the file is opened, so a value
        YAML cannot represent makes the save fail without truncating the
        existing file.

        Returns:
            True if successful, False otherwise
        """
        try:
            text = yaml.safe_dump(
                self.config, default_flow_style=False, sort_keys=False
            )
            # Create directory if it doesn't exist
            self.config_path.parent.mkdir(parents=True, exist_ok=True)
            with open(self.config_path, "w", encoding="utf-8") as f:
                f.write(text)
            logger.info(f"Configuration saved to {self.config_path}")
            return True
        except Exception as e:
            logger.error(f"Error saving configuration: {e}")
            return False

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get configuration value by key.

        Args:
            key: Configuration key (can use dot notation, e.g., 'database.path')
            default: Default value if key not found

        Returns:
            Configuration value or default. Returns ``default`` as soon as a
            path component is missing or an intermediate value is not a dict.
        """
        node: Any = self.config
        for part in key.split("."):
            if not isinstance(node, dict) or part not in node:
                return default
            node = node[part]
        return node

    def set(self, key: str, value: Any) -> None:
        """
        Set configuration value by key.

        Intermediate levels that are missing, empty (``None``, e.g. from a
        bare ``section:`` line in YAML) or not dicts are replaced by new
        dicts so the full path exists.

        Args:
            key: Configuration key (can use dot notation)
            value: Value to set
        """
        *parents, leaf = key.split(".")
        node = self.config
        # Navigate to (creating as needed) the nested dictionary
        for part in parents:
            child = node.get(part)
            if not isinstance(child, dict):
                if child is not None:
                    logger.warning(
                        f"Replacing non-mapping value at '{part}' while setting {key}"
                    )
                child = {}
                node[part] = child
            node = child
        # Set the value
        node[leaf] = value
        logger.debug(f"Configuration updated: {key} = {value}")

    def get_section(self, section: str) -> Dict[str, Any]:
        """
        Get entire configuration section.

        Args:
            section: Section name (e.g., 'database', 'training')

        Returns:
            The section's dictionary (the live object stored in the config),
            or an empty dict if the section is missing or not a mapping.
        """
        value = self.config.get(section)
        return value if isinstance(value, dict) else {}

    def update_section(self, section: str, values: Dict[str, Any]) -> None:
        """
        Update entire configuration section.

        A missing, empty or non-mapping section is replaced by a new dict
        before the update.

        Args:
            section: Section name
            values: Dictionary with new values
        """
        current = self.config.get(section)
        if not isinstance(current, dict):
            current = {}
            self.config[section] = current
        current.update(values)
        logger.debug(f"Configuration section updated: {section}")

    def reload(self) -> None:
        """Reload configuration from file."""
        self._load_config()

    def get_database_path(self) -> str:
        """Get database path."""
        return self.get("database.path", "data/detections.db")

    def get_image_repository_path(self) -> str:
        """Get image repository base path."""
        return self.get("image_repository.base_path", "")

    def set_image_repository_path(self, path: str) -> None:
        """Set image repository base path and save the configuration.

        The value is stored as ``str(path)``, so ``pathlib.Path`` objects
        are also accepted and serialize as plain YAML strings.
        """
        self.set("image_repository.base_path", str(path))
        self.save_config()

    def get_models_directory(self) -> str:
        """Get models directory path."""
        return self.get("models.models_directory", "data/models")

    def get_default_training_params(self) -> Dict[str, Any]:
        """Get default training parameters."""
        return self.get_section("training")

    def get_default_detection_params(self) -> Dict[str, Any]:
        """Get default detection parameters."""
        return self.get_section("detection")

    def get_bbox_colors(self) -> Dict[str, str]:
        """Get bounding box colors for different classes."""
        return self.get("visualization.bbox_colors", {})

    def get_allowed_extensions(self) -> list:
        """Get list of allowed image file extensions.

        The fallback is a copy of ``Image.SUPPORTED_EXTENSIONS`` so callers
        cannot mutate the class-level constant through the returned list.
        """
        return self.get(
            "image_repository.allowed_extensions", list(Image.SUPPORTED_EXTENSIONS)
        )