| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
import logging
import os
from typing import Dict, Literal

from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
class DBConfig(BaseModel):
    """Connection and pool parameters for a single database.

    The fields without a default (``type``, ``host``, ``port``, ``username``,
    ``password``, ``database``) are required; the remaining fields fall back
    to the defaults declared below.
    """

    # Database flavour; only these two values pass validation.
    type: Literal['mysql', 'postgresql']
    host: str
    port: int
    username: str
    # Excluded from repr()/str() so that printing or logging a config object
    # does not expose the credential. The field is still a required ``str``.
    password: str = Field(repr=False)
    database: str
    # Boolean flag named after the DB_ECHO environment variable.
    # NOTE(review): presumably forwarded to the SQL engine's statement
    # logging switch -- confirm against the code that consumes this model.
    echo: bool = False
    # Pool tuning values.
    # NOTE(review): the names mirror SQLAlchemy's connection-pool options;
    # units (seconds for recycle/timeout) are assumed, confirm with caller.
    max_overflow: int = 10
    pool_size: int = 50
    pool_recycle: int = 3600
    pool_timeout: int = 30
class DBSettings(BaseSettings):
    """Settings object that collects database configs from ``DB_*`` variables.

    The default database is read from un-prefixed variables (``DB_HOST``,
    ``DB_PORT``, ...) and stored under the key ``'default'``. Additional
    databases are discovered from variables of the form ``DB_<NAME>_<FIELD>``
    where ``<NAME>`` contains no underscore; each one needs at least
    ``DB_<NAME>_HOST`` and ``DB_<NAME>_PORT`` and inherits every other value
    from the default database when its own variable is absent.
    """

    # Database name -> configuration. Always rebuilt from the environment in
    # __init__, so a ``databases=`` keyword argument is overwritten.
    databases: Dict[str, DBConfig] = Field(default_factory=dict)

    def __init__(self, **kwargs):
        """Initialise the settings, then populate ``databases`` from the environment.

        Raises:
            ValueError: propagated from ``_load_databases`` when the default
                database cannot be configured.
        """
        super().__init__(**kwargs)
        self.databases = self._load_databases()

    @staticmethod
    def _env_int(name: str, default: int) -> int:
        """Return environment variable ``name`` as an int, or ``default`` if unset.

        Raises:
            ValueError: if the variable is set but is not a valid integer;
                the message names the offending variable.
        """
        raw = os.environ.get(name)
        if raw is None:
            return default
        try:
            return int(raw)
        except ValueError as exc:
            raise ValueError(
                f"Environment variable {name} must be an integer, got {raw!r}"
            ) from exc

    @staticmethod
    def _load_default_database() -> DBConfig:
        """Build the default database config from the un-prefixed ``DB_*`` variables.

        Raises:
            ValueError: if ``DB_HOST`` is missing/empty, ``DB_PORT`` is 0, a
                numeric variable is not an integer, or ``DBConfig`` rejects
                the collected values.
        """
        values = {
            'type': os.environ.get('DB_TYPE', 'mysql'),
            'host': os.environ.get('DB_HOST'),
            'port': DBSettings._env_int('DB_PORT', 3306),
            'username': os.environ.get('DB_USERNAME'),
            'password': os.environ.get('DB_PASSWORD'),
            'database': os.environ.get('DB_DATABASE'),
            'echo': os.environ.get('DB_ECHO', 'false').lower() == 'true',
            'max_overflow': DBSettings._env_int('DB_MAX_OVERFLOW', 10),
            'pool_size': DBSettings._env_int('DB_POOL_SIZE', 50),
            'pool_recycle': DBSettings._env_int('DB_POOL_RECYCLE', 3600),
            'pool_timeout': DBSettings._env_int('DB_POOL_TIMEOUT', 30),
        }

        # Check the mandatory fields before handing the values to DBConfig.
        if not values['host'] or not values['port']:
            raise ValueError("Default database requires DB_HOST and DB_PORT")

        return DBConfig(**values)

    @staticmethod
    def _build_named_database(prefix: str, host: str, port: str, default_db: DBConfig) -> DBConfig:
        """Build the config for database ``prefix``, inheriting gaps from ``default_db``.

        ``host`` and ``port`` are the already-fetched, non-empty values of
        ``DB_<prefix>_HOST`` and ``DB_<prefix>_PORT``.

        Raises:
            ValueError: if a numeric variable is not an integer or
                ``DBConfig`` rejects the collected values.
        """
        env = os.environ
        return DBConfig(
            type=env.get(f"DB_{prefix}_TYPE", default_db.type),
            host=host,
            port=int(port),
            username=env.get(f"DB_{prefix}_USERNAME", default_db.username),
            password=env.get(f"DB_{prefix}_PASSWORD", default_db.password),
            database=env.get(f"DB_{prefix}_DATABASE", default_db.database),
            echo=env.get(f"DB_{prefix}_ECHO", str(default_db.echo)).lower() == 'true',
            max_overflow=DBSettings._env_int(f"DB_{prefix}_MAX_OVERFLOW", default_db.max_overflow),
            pool_size=DBSettings._env_int(f"DB_{prefix}_POOL_SIZE", default_db.pool_size),
            pool_recycle=DBSettings._env_int(f"DB_{prefix}_POOL_RECYCLE", default_db.pool_recycle),
            pool_timeout=DBSettings._env_int(f"DB_{prefix}_POOL_TIMEOUT", default_db.pool_timeout),
        )

    @staticmethod
    def _load_databases() -> Dict[str, DBConfig]:
        """Scan the environment and return every database that can be configured.

        Returns:
            A dict containing ``'default'`` plus, for each additional database
            found, its prefix exactly as written in the environment and, when
            that key is still free, its lower-cased alias.

        Raises:
            ValueError: if the default database cannot be built. Failures of
                additional databases are logged as warnings and skipped.
        """
        logger = logging.getLogger(__name__)

        default_db = DBSettings._load_default_database()
        db_dict: Dict[str, DBConfig] = {'default': default_db}

        # Prefixes already inspected. Without this, a prefix that is missing
        # host/port or fails validation would be retried (and reported) once
        # for every DB_<prefix>_* variable present.
        inspected = set()

        for key in os.environ:
            if not key.startswith('DB_'):
                continue

            # DB_<prefix>_<field>: no underscore after "DB_" means the key is
            # a setting of the default database, not a named one.
            prefix, separator, _ = key[3:].partition('_')
            if not separator or not prefix:
                continue

            if prefix in inspected or prefix in db_dict:
                continue
            inspected.add(prefix)

            host = os.environ.get(f"DB_{prefix}_HOST")
            port = os.environ.get(f"DB_{prefix}_PORT")
            if not host or not port:
                # Also filters default-database keys such as DB_POOL_SIZE,
                # whose first segment ("POOL") is not a database name.
                continue

            try:
                config = DBSettings._build_named_database(prefix, host, port, default_db)
            except (ValueError, TypeError) as exc:
                # Best effort: one broken extra database must not prevent the
                # others (or the default) from loading.
                logger.warning("Failed to load DB config for %s: %s", prefix, exc)
                continue

            db_dict[prefix] = config
            # Register the lower-case alias only if that key is unused, so a
            # database literally named DEFAULT cannot replace the real
            # 'default' entry and case variants cannot clobber each other.
            db_dict.setdefault(prefix.lower(), config)

        return db_dict
|