import logging
import os
from typing import Dict, Literal

from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings

logger = logging.getLogger(__name__)


def _env_int(name: str, default: int) -> int:
    """Return environment variable *name* parsed as an int, or *default* if unset.

    Raises:
        ValueError: If the variable is set but is not a valid integer. The
            message names the offending variable.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        raise ValueError(f"{name} must be an integer, got {raw!r}") from None


def _env_bool(name: str, default: bool) -> bool:
    """Return True only when *name* (or, if unset, *default*) spells 'true'.

    The comparison is case-insensitive; any other value yields False.
    """
    return os.environ.get(name, str(default)).lower() == 'true'


class DBConfig(BaseModel):
    """Connection settings for a single database."""

    type: Literal['mysql', 'postgresql']
    host: str
    port: int
    username: str
    password: str
    database: str
    echo: bool = False
    # Pool tuning values; presumably forwarded to the engine/connection pool
    # by whoever consumes this config -- confirm against the caller.
    max_overflow: int = 10
    pool_size: int = 50
    pool_recycle: int = 3600
    pool_timeout: int = 30


class DBSettings(BaseSettings):
    """Settings object exposing every database configured via ``DB_*`` variables.

    ``databases`` maps a database name to its :class:`DBConfig`. The key
    ``'default'`` is built from the un-prefixed variables (``DB_HOST``,
    ``DB_PORT``, ...); each additional database ``<NAME>`` is built from
    ``DB_<NAME>_HOST``, ``DB_<NAME>_PORT``, ... and is registered under both
    ``<NAME>`` and its lowercase form.
    """

    # Newly added field: name -> config, filled in by __init__.
    databases: Dict[str, DBConfig] = Field(default_factory=dict)

    def __init__(self, **kwargs):
        """Initialise the settings, then populate ``databases`` from the environment.

        Note that ``databases`` is always replaced by the environment-derived
        mapping, even if a value was supplied through ``kwargs``.
        """
        super().__init__(**kwargs)
        self.databases = self._load_databases()

    @staticmethod
    def _load_databases() -> Dict[str, DBConfig]:
        """Build the name -> :class:`DBConfig` mapping from ``os.environ``.

        Returns:
            A dict that always contains ``'default'`` plus one entry (and a
            lowercase alias) per additional database that loaded successfully.

        Raises:
            ValueError: If ``DB_HOST`` is missing/empty, if the port is falsy,
                if an integer variable of the default database is malformed,
                or if the default configuration fails model validation.
        """
        db_dict: Dict[str, DBConfig] = {}

        # Default database - loaded directly from the un-prefixed variables.
        default_env_vars = {
            'type': os.environ.get('DB_TYPE', 'mysql'),
            'host': os.environ.get('DB_HOST'),
            'port': _env_int('DB_PORT', 3306),
            'username': os.environ.get('DB_USERNAME'),
            'password': os.environ.get('DB_PASSWORD'),
            'database': os.environ.get('DB_DATABASE'),
            'echo': _env_bool('DB_ECHO', False),
            'max_overflow': _env_int('DB_MAX_OVERFLOW', 10),
            'pool_size': _env_int('DB_POOL_SIZE', 50),
            'pool_recycle': _env_int('DB_POOL_RECYCLE', 3600),
            'pool_timeout': _env_int('DB_POOL_TIMEOUT', 30),
        }

        # Validate required fields. Because DB_PORT falls back to 3306, the
        # port half of this check only fires for an explicit port of 0.
        if not default_env_vars['host'] or not default_env_vars['port']:
            raise ValueError("Default database requires DB_HOST and DB_PORT")

        default_db = DBConfig(**default_env_vars)
        db_dict['default'] = default_db

        # Multi-database support: discover prefixes from DB_<PREFIX>_<FIELD>.
        # Sorting makes the outcome independent of environment ordering.
        attempted = set()
        for key in sorted(os.environ):
            if not key.startswith('DB_'):
                continue

            prefix, separator, _ = key[3:].partition('_')
            if not separator:
                continue  # no underscore after DB_: treated as a global setting
            if not prefix:
                continue  # empty prefix, e.g. "DB__HOST"
            if prefix in attempted or prefix in db_dict:
                continue  # already loaded (or already tried) this database
            # Remember the prefix even if loading fails, so a broken database
            # is reported once rather than once per variable.
            attempted.add(prefix)

            host_key = f"DB_{prefix}_HOST"
            port_key = f"DB_{prefix}_PORT"
            host = os.environ.get(host_key)
            port = os.environ.get(port_key)
            if not host or not port:
                continue  # required fields missing, skip

            try:
                # Every field other than host/port falls back to the default
                # database's value when its prefixed variable is unset.
                config = DBConfig(
                    type=os.environ.get(f"DB_{prefix}_TYPE", default_db.type),
                    host=host,
                    # port is known to be set here; the fallback is never used.
                    port=_env_int(port_key, default_db.port),
                    username=os.environ.get(f"DB_{prefix}_USERNAME", default_db.username),
                    password=os.environ.get(f"DB_{prefix}_PASSWORD", default_db.password),
                    database=os.environ.get(f"DB_{prefix}_DATABASE", default_db.database),
                    echo=_env_bool(f"DB_{prefix}_ECHO", default_db.echo),
                    max_overflow=_env_int(f"DB_{prefix}_MAX_OVERFLOW", default_db.max_overflow),
                    pool_size=_env_int(f"DB_{prefix}_POOL_SIZE", default_db.pool_size),
                    pool_recycle=_env_int(f"DB_{prefix}_POOL_RECYCLE", default_db.pool_recycle),
                    pool_timeout=_env_int(f"DB_{prefix}_POOL_TIMEOUT", default_db.pool_timeout),
                )
            except (ValueError, TypeError) as exc:
                # Best effort: a broken secondary database must not prevent
                # the others (or the default) from loading.
                logger.warning("Failed to load DB config for %s: %s", prefix, exc)
                continue

            db_dict[prefix] = config
            # Lowercase alias for convenient lookup. setdefault keeps an
            # existing entry intact, so DB_DEFAULT_* can no longer overwrite
            # the real 'default' database.
            db_dict.setdefault(prefix.lower(), config)

        return db_dict