from contextlib import contextmanager
from typing import Iterator, List, Optional, Any
from urllib.parse import quote

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session, sessionmaker

from app.config import config
from app.logger import logger


class Database:
    """Database singleton.

    The first instantiation (``Database()`` or ``Database.initialize()``) builds
    the SQLAlchemy engine from ``config.database`` and runs a ``SELECT 1``
    connectivity probe; subsequent instantiations return the same object.
    """

    _instance = None

    @staticmethod
    def initialize():
        """Initialise the database connection.

        Instantiating the singleton triggers ``_init_db`` the first time.
        """
        return Database()

    def __new__(cls):
        if cls._instance is None:
            instance = super().__new__(cls)
            # Cache the instance only after initialisation succeeds; otherwise a
            # failed first attempt would leave a half-built singleton (without
            # ``engine`` / ``SessionLocal``) that every later caller receives.
            instance._init_db()
            cls._instance = instance
        return cls._instance

    def _init_db(self):
        """Create the engine and session factory and verify connectivity.

        Raises:
            Exception: anything raised while creating the engine or running the
                ``SELECT 1`` probe; it is logged and re-raised.
        """
        db_config = config.database
        engine = None
        try:
            engine = create_engine(
                self._get_sqlalchemy_uri(
                    host=db_config.host,
                    port=db_config.port,
                    user=db_config.user,
                    password=db_config.password,
                    name=db_config.name,
                ),
                pool_pre_ping=True,  # check pooled connections are alive before use
                pool_recycle=3600,  # recycle pooled connections after 3600 seconds
            )
            # Connectivity probe
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
            logger.debug("数据库连接成功")
        except Exception as e:
            logger.warning(f"数据库连接失败: {e}")
            if engine is not None:
                # Release any pooled connections from the failed attempt.
                engine.dispose()
            raise
        self.engine = engine
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)

    @staticmethod
    def _get_sqlalchemy_uri(host: str, port: int, user: str, password: str, name: str) -> str:
        """Build the SQLAlchemy URL for MySQL via PyMySQL (charset utf8mb4).

        ``user`` and ``password`` are percent-encoded so that characters such as
        ``@ : / # % ?`` or spaces in credentials cannot corrupt the URL.
        The returned string contains the password in clear text -- never log it.
        """
        # str() keeps the previous behaviour for non-str config values (e.g. a
        # numeric password), which quote() would otherwise reject.
        safe_user = quote(str(user), safe="")
        safe_password = quote(str(password), safe="")
        return f"mysql+pymysql://{safe_user}:{safe_password}@{host}:{port}/{name}?charset=utf8mb4"

    @contextmanager
    def session(self) -> Iterator[Session]:
        """Yield a new ORM session.

        The session is rolled back if the ``with`` body raises an ``Exception``
        and is always closed on exit. Committing is left to the caller.
        """
        db = self.SessionLocal()
        try:
            yield db
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

    def query(self, sql: str, params: Optional[dict] = None) -> List[dict]:
        """Run a query and return every row as a dict keyed by column name.

        Args:
            sql: SQL text; use ``:name`` placeholders for bound values.
            params: Bind parameters (defaults to none).
        """
        with self.session() as db:
            try:
                result = db.execute(text(sql), params or {})
                return [dict(row) for row in result.mappings()]
            except Exception as e:
                logger.error(f"查询SQL出错: [{sql}] {e}")
                raise

    def execute(self, sql: str, params: Optional[dict] = None) -> int:
        """Run a non-query statement, commit, and return the driver's rowcount."""
        with self.session() as db:
            try:
                result = db.execute(text(sql), params or {})
                db.commit()
                return result.rowcount
            except Exception as e:
                logger.error(f"执行SQL出错: [{sql}] {e}")
                raise

    def batch_execute(self, sql: str, params_list: List[dict]) -> int:
        """Execute one statement for every parameter dict in ``params_list``.

        All executions run in a single transaction: either all succeed and are
        committed, or the session is rolled back. Returns 0 without touching the
        database when ``params_list`` is empty; otherwise the driver's rowcount.
        """
        if not params_list:
            return 0
        with self.session() as db:
            try:
                result = db.execute(text(sql), params_list)
                db.commit()
                return result.rowcount
            except Exception as e:
                logger.error(f"批量执行SQL出错: [{sql}] {e}")
                raise

    def query_one(self, sql: str, params: Optional[dict] = None) -> Optional[dict]:
        """Run a query and return the first row as a dict, or ``None`` if empty."""
        with self.session() as db:
            try:
                result = db.execute(text(sql), params or {})
                row = result.mappings().first()
                return dict(row) if row else None
            except Exception as e:
                logger.error(f"查询SQL出错: [{sql}] {e}")
                raise

    def execute_procedure(self, procedure_name: str, params: Optional[dict] = None) -> Any:
        """Call a stored procedure, commit, and return the rows it produced.

        Args:
            procedure_name: Text placed verbatim after ``CALL`` (e.g.
                ``"my_proc(:a, :b)"``). It is interpolated into the SQL, so it
                must never come from untrusted input; pass values via ``params``.
            params: Bind parameters (defaults to none).

        Returns:
            The fetched rows, or ``[]`` when the call produces no result set.
        """
        with self.session() as db:
            try:
                result = db.execute(text(f"CALL {procedure_name}"), params or {})
                # Fetch before committing, while the transaction still holds the
                # connection, and skip fetchall() for statements without a result
                # set -- it would raise after the commit had already succeeded.
                rows = result.fetchall() if result.returns_rows else []
                db.commit()
                return rows
            except Exception as e:
                logger.error(f"执行存储过程出错: [{procedure_name}] {e}")
                raise

    def execute_function(self, function_name: str, params: Optional[dict] = None) -> Any:
        """Evaluate ``SELECT <function_name>`` and return the scalar result.

        ``function_name`` is interpolated into the SQL (e.g. ``"my_func(:x)"``),
        so it must never come from untrusted input; pass values via ``params``.
        No commit is issued.
        """
        with self.session() as db:
            try:
                result = db.execute(text(f"SELECT {function_name}"), params or {})
                return result.scalar()
            except Exception as e:
                logger.error(f"执行数据库函数出错: [{function_name}] {e}")
                raise