# db_settings.py (3.3 KB)
import logging
import os
from typing import Dict, Literal

from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
  5. class DBConfig(BaseModel):
  6. type: Literal['mysql', 'postgresql']
  7. host: str
  8. port: int
  9. username: str
  10. password: str
  11. database: str
  12. echo: bool = False
  13. max_overflow: int = 10
  14. pool_size: int = 50
  15. pool_recycle: int = 3600
  16. pool_timeout: int = 30
  17. class DBSettings(BaseSettings):
  18. databases: Dict[str, DBConfig] = Field(default_factory=dict) # 新增字段
  19. def __init__(self, **kwargs):
  20. super().__init__(**kwargs)
  21. self.databases = self._load_databases()
  22. @staticmethod
  23. def _load_databases() -> Dict[str, DBConfig]:
  24. db_dict = {}
  25. # 默认数据库 - 直接从环境变量加载
  26. default_env_vars = {
  27. 'type': os.environ.get('DB_TYPE', 'mysql'),
  28. 'host': os.environ.get('DB_HOST'),
  29. 'port': int(os.environ.get('DB_PORT', 3306)),
  30. 'username': os.environ.get('DB_USERNAME'),
  31. 'password': os.environ.get('DB_PASSWORD'),
  32. 'database': os.environ.get('DB_DATABASE'),
  33. 'echo': os.environ.get('DB_ECHO', 'false').lower() == 'true',
  34. 'max_overflow': int(os.environ.get('DB_MAX_OVERFLOW', 10)),
  35. 'pool_size': int(os.environ.get('DB_POOL_SIZE', 50)),
  36. 'pool_recycle': int(os.environ.get('DB_POOL_RECYCLE', 3600)),
  37. 'pool_timeout': int(os.environ.get('DB_POOL_TIMEOUT', 30)),
  38. }
  39. # 验证必要字段
  40. if not default_env_vars['host'] or not default_env_vars['port']:
  41. raise ValueError("Default database requires DB_HOST and DB_PORT")
  42. default_db = DBConfig(**default_env_vars)
  43. db_dict['default'] = default_db
  44. # 多数据库支持
  45. for key in os.environ:
  46. if not key.startswith('DB_'):
  47. continue
  48. suffix = key[3:]
  49. if '_' not in suffix:
  50. continue # 不含下划线,视为全局配置
  51. prefix = suffix.split('_', 1)[0]
  52. if not prefix: # 空前缀跳过
  53. continue
  54. if prefix in db_dict:
  55. continue # 已加载过该数据库
  56. host_key = f"DB_{prefix}_HOST"
  57. port_key = f"DB_{prefix}_PORT"
  58. host = os.environ.get(host_key)
  59. port = os.environ.get(port_key)
  60. if not host or not port:
  61. continue # 必要字段缺失,跳过
  62. try:
  63. db_dict[prefix] = DBConfig(
  64. type=os.environ.get(f"DB_{prefix}_TYPE", default_db.type),
  65. host=host,
  66. port=int(port),
  67. username=os.environ.get(f"DB_{prefix}_USERNAME", default_db.username),
  68. password=os.environ.get(f"DB_{prefix}_PASSWORD", default_db.password),
  69. database=os.environ.get(f"DB_{prefix}_DATABASE", default_db.database),
  70. echo=os.environ.get(f"DB_{prefix}_ECHO", str(default_db.echo)).lower() == 'true',
  71. max_overflow=int(os.environ.get(f"DB_{prefix}_MAX_OVERFLOW", default_db.max_overflow)),
  72. pool_size=int(os.environ.get(f"DB_{prefix}_POOL_SIZE", default_db.pool_size)),
  73. pool_recycle=int(os.environ.get(f"DB_{prefix}_POOL_RECYCLE", default_db.pool_recycle)),
  74. pool_timeout=int(os.environ.get(f"DB_{prefix}_POOL_TIMEOUT", default_db.pool_timeout)),
  75. )
  76. db_dict[prefix.lower()] = db_dict[prefix]
  77. except Exception as e:
  78. print(f"Failed to load DB config for {prefix}: {e}")
  79. return db_dict