database_util.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. from typing import Dict, Any
  2. from urllib.parse import quote_plus
  3. from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
  4. from core.settings import db_settings
  5. class DBUtil:
  6. """
  7. 数据库连接管理器
  8. 支持多数据库配置,通过db_name参数切换
  9. """
  10. _engines = {} # 缓存异步引擎
  11. _session_makers = {} # 缓存会话工厂
  12. @classmethod
  13. def get_db(cls, db_name: str | None = None):
  14. """
  15. 获取数据库会话
  16. :param db_name: 数据库名称,如果为None则使用默认数据库
  17. :return: 数据库会话
  18. """
  19. key = db_name or "default"
  20. db_config = DBConfigLoader.get_db_config(db_name)
  21. # 获取或创建异步引擎
  22. if key in cls._engines:
  23. async_engine = cls._engines[key]
  24. else:
  25. async_engine = cls._create_async_engine(db_config)
  26. cls._engines[key] = async_engine
  27. # 获取或创建会话工厂
  28. if key in cls._session_makers:
  29. async_session_local = cls._session_makers[key]
  30. else:
  31. async_session_local = async_sessionmaker(
  32. autocommit=False, autoflush=False, bind=async_engine
  33. )
  34. cls._session_makers[key] = async_session_local
  35. db = async_session_local()
  36. try:
  37. yield db
  38. finally:
  39. db.close()
  40. @staticmethod
  41. def _create_async_engine(config):
  42. """根据配置创建异步引擎
  43. :param config: 数据库配置字典
  44. :return: 异步引擎实例
  45. """
  46. # 使用配置加载器构建连接URL
  47. url = DBConfigLoader.build_connection_url(config)
  48. # 获取引擎参数
  49. engine_params = DBConfigLoader.get_engine_params(config)
  50. # 创建异步引擎
  51. return create_async_engine(url, **engine_params)
  52. def __call__(self, *args, **kwargs):
  53. return self.get_db(*args, **kwargs)
  54. class DBConfigLoader:
  55. """
  56. 数据库配置加载器
  57. 负责加载和处理数据库配置,与数据库连接管理分离
  58. """
  59. @classmethod
  60. def get_db_config(cls, db_name: str | None = None) -> Dict[str, Any]:
  61. """
  62. 获取指定数据库的配置
  63. :param db_name: 数据库名称,如果为None则返回默认数据库配置
  64. :return: 数据库配置字典
  65. """
  66. return (
  67. db_settings.databases[db_name]
  68. if db_name
  69. else db_settings.databases["default"]
  70. )
  71. @classmethod
  72. def build_connection_url(cls, config: Dict[str, Any]) -> str:
  73. """
  74. 根据配置构建数据库连接URL
  75. :param config: 数据库配置字典
  76. :return: 连接URL字符串
  77. """
  78. db_type = config["type"]
  79. driver = "asyncmy" if db_type == "mysql" else "asyncpg"
  80. # 构建数据库连接URL
  81. url = (
  82. f"{db_type}+{driver}://{config['username']}:{quote_plus(config['password'])}@"
  83. f"{config['host']}:{config['port']}/{config['database']}"
  84. )
  85. return url
  86. @classmethod
  87. def get_engine_params(cls, config: Dict[str, Any]) -> Dict[str, Any]:
  88. """
  89. 获取创建引擎所需的参数
  90. :param config: 数据库配置字典
  91. :return: 引擎参数字典
  92. """
  93. return {
  94. "echo": config["echo"],
  95. "max_overflow": config["max_overflow"],
  96. "pool_size": config["pool_size"],
  97. "pool_recycle": config["pool_recycle"],
  98. "pool_timeout": config["pool_timeout"],
  99. }