| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124 |
- from typing import Dict, Any
- from urllib.parse import quote_plus
- from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
- from core.settings import db_settings
- class DBUtil:
- """
- 数据库连接管理器
- 支持多数据库配置,通过db_name参数切换
- """
- _engines = {} # 缓存异步引擎
- _session_makers = {} # 缓存会话工厂
- @classmethod
- def get_db(cls, db_name: str | None = None):
- """
- 获取数据库会话
- :param db_name: 数据库名称,如果为None则使用默认数据库
- :return: 数据库会话
- """
- key = db_name or "default"
- db_config = DBConfigLoader.get_db_config(db_name)
- # 获取或创建异步引擎
- if key in cls._engines:
- async_engine = cls._engines[key]
- else:
- async_engine = cls._create_async_engine(db_config)
- cls._engines[key] = async_engine
- # 获取或创建会话工厂
- if key in cls._session_makers:
- async_session_local = cls._session_makers[key]
- else:
- async_session_local = async_sessionmaker(
- autocommit=False, autoflush=False, bind=async_engine
- )
- cls._session_makers[key] = async_session_local
- db = async_session_local()
- try:
- yield db
- finally:
- db.close()
- @staticmethod
- def _create_async_engine(config):
- """根据配置创建异步引擎
- :param config: 数据库配置字典
- :return: 异步引擎实例
- """
- # 使用配置加载器构建连接URL
- url = DBConfigLoader.build_connection_url(config)
- # 获取引擎参数
- engine_params = DBConfigLoader.get_engine_params(config)
- # 创建异步引擎
- return create_async_engine(url, **engine_params)
- def __call__(self, *args, **kwargs):
- return self.get_db(*args, **kwargs)
- class DBConfigLoader:
- """
- 数据库配置加载器
- 负责加载和处理数据库配置,与数据库连接管理分离
- """
- @classmethod
- def get_db_config(cls, db_name: str | None = None) -> Dict[str, Any]:
- """
- 获取指定数据库的配置
- :param db_name: 数据库名称,如果为None则返回默认数据库配置
- :return: 数据库配置字典
- """
- return (
- db_settings.databases[db_name]
- if db_name
- else db_settings.databases["default"]
- )
- @classmethod
- def build_connection_url(cls, config: Dict[str, Any]) -> str:
- """
- 根据配置构建数据库连接URL
- :param config: 数据库配置字典
- :return: 连接URL字符串
- """
- db_type = config["type"]
- driver = "asyncmy" if db_type == "mysql" else "asyncpg"
- # 构建数据库连接URL
- url = (
- f"{db_type}+{driver}://{config['username']}:{quote_plus(config['password'])}@"
- f"{config['host']}:{config['port']}/{config['database']}"
- )
- return url
- @classmethod
- def get_engine_params(cls, config: Dict[str, Any]) -> Dict[str, Any]:
- """
- 获取创建引擎所需的参数
- :param config: 数据库配置字典
- :return: 引擎参数字典
- """
- return {
- "echo": config["echo"],
- "max_overflow": config["max_overflow"],
- "pool_size": config["pool_size"],
- "pool_recycle": config["pool_recycle"],
- "pool_timeout": config["pool_timeout"],
- }
|