123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147 |
- from typing import Iterator, List, Optional, Any
- from contextlib import contextmanager
- from sqlalchemy import create_engine, text
- from sqlalchemy.orm import Session, sessionmaker
- from app.config import config
- from app.logger import logger
class Database:
    """Process-wide database access singleton.

    The first ``Database()`` (or ``Database.initialize()``) call builds a
    SQLAlchemy engine from ``config.database``, checks connectivity with
    ``SELECT 1`` and creates a session factory. Later calls return the same
    instance. If initialisation fails, nothing is cached, so a later call
    retries instead of receiving a half-initialised object.

    NOTE(review): instance creation is not guarded by a lock. Concurrent first
    calls from several threads could each build their own engine.
    """

    # The cached singleton; stays None until _init_db() has succeeded.
    _instance: Optional["Database"] = None

    @staticmethod
    def initialize() -> "Database":
        """Create (or return) the singleton, connecting to the database.

        Raises:
            Exception: whatever engine creation or the connectivity check
                raised (re-raised after a warning is logged).
        """
        # Instantiating the singleton triggers _init_db on first use.
        return Database()

    def __new__(cls):
        if cls._instance is None:
            instance = super().__new__(cls)
            # Cache only after _init_db succeeds. Otherwise a failed connection
            # would leave an instance without engine/SessionLocal cached forever.
            instance._init_db()
            cls._instance = instance
        return cls._instance

    def _init_db(self):
        """Build the engine, verify connectivity and create the session factory.

        Connection settings (host, port, user, password, name) are read from
        ``config.database``. The engine pings pooled connections on checkout
        (``pool_pre_ping``) and recycles them after 3600 seconds.

        Raises:
            Exception: any error from engine creation or the ``SELECT 1``
                probe, re-raised after logging a warning. An engine that was
                created but failed the probe is disposed first.
        """
        db_config = config.database
        engine = None
        try:
            engine = create_engine(
                self._get_sqlalchemy_uri(
                    host=db_config.host,
                    port=db_config.port,
                    user=db_config.user,
                    password=db_config.password,
                    name=db_config.name,
                ),
                pool_pre_ping=True,
                pool_recycle=3600,
            )
            # Fail fast: open a real connection now rather than on first query.
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
            logger.debug("数据库连接成功")
        except Exception as e:
            logger.warning(f"数据库连接失败: {e}")
            if engine is not None:
                # Release pool resources held by the unusable engine.
                engine.dispose()
            raise
        self.engine = engine
        self.SessionLocal = sessionmaker(autocommit=False,
                                         autoflush=False,
                                         bind=self.engine)

    @staticmethod
    def _get_sqlalchemy_uri(host: str, port: int, user: str, password: str,
                            name: str) -> str:
        """Build a ``mysql+pymysql`` SQLAlchemy URL using the utf8mb4 charset.

        The user name and password are percent-encoded, so characters such as
        ``@``, ``:``, ``/``, ``#``, ``%`` or spaces cannot corrupt URL parsing.
        Non-string values (e.g. a numeric password parsed from a config file)
        are converted with ``str()`` first.

        The returned string contains the password in clear text; do not log it.
        """
        from urllib.parse import quote

        safe_user = quote(str(user), safe="")
        safe_password = quote(str(password), safe="")
        return (f"mysql+pymysql://{safe_user}:{safe_password}"
                f"@{host}:{port}/{name}?charset=utf8mb4")

    @contextmanager
    def session(self) -> "Iterator[Session]":
        """Yield a new ORM session; roll back on error and always close it.

        No commit is issued here: callers that write must call ``commit()``
        on the yielded session themselves.
        """
        db = self.SessionLocal()
        try:
            yield db
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

    def query(self, sql: str, params: Optional[dict] = None) -> List[dict]:
        """Run a SQL query and return every row as a plain ``dict``.

        ``sql`` is wrapped in ``text()``, so values belong in ``params`` and
        are referenced as ``:name`` placeholders. No commit is issued.

        Raises:
            Exception: any execution error, re-raised after logging.
        """
        with self.session() as db:
            try:
                result = db.execute(text(sql), params or {})
                return [dict(row) for row in result.mappings()]
            except Exception as e:
                logger.error(f"查询SQL出错: [{sql}] {e}")
                raise

    def execute(self, sql: str, params: Optional[dict] = None) -> int:
        """Run one data-modifying SQL statement and commit it.

        Returns:
            The driver-reported ``rowcount`` for the statement.

        Raises:
            Exception: any execution or commit error, re-raised after logging
                (the session is rolled back).
        """
        with self.session() as db:
            try:
                result = db.execute(text(sql), params or {})
                # Read rowcount before commit, while the cursor is still current.
                affected = result.rowcount
                db.commit()
                return affected
            except Exception as e:
                logger.error(f"执行SQL出错: [{sql}] {e}")
                raise

    def batch_execute(self, sql: str, params_list: List[dict]) -> int:
        """Execute ``sql`` once per parameter dict (executemany) and commit.

        All rows are written in a single transaction: either everything is
        committed or, on error, everything is rolled back.

        Returns:
            0 for an empty ``params_list``; otherwise the ``rowcount`` the
            driver reports for the executemany call (its exact meaning is
            driver dependent).
        """
        if not params_list:
            return 0
        with self.session() as db:
            try:
                result = db.execute(text(sql), params_list)
                affected = result.rowcount
                db.commit()
                return affected
            except Exception as e:
                logger.error(f"批量执行SQL出错: [{sql}] {e}")
                raise

    def query_one(self,
                  sql: str,
                  params: Optional[dict] = None) -> Optional[dict]:
        """Run a SQL query and return its first row as a ``dict``, or ``None``.

        Rows after the first are ignored. No commit is issued.
        """
        with self.session() as db:
            try:
                result = db.execute(text(sql), params or {})
                row = result.mappings().first()
                return dict(row) if row else None
            except Exception as e:
                logger.error(f"查询SQL出错: [{sql}] {e}")
                raise

    def execute_procedure(self,
                          procedure_name: str,
                          params: Optional[dict] = None) -> Any:
        """Call a stored procedure, commit, and return the rows it produced.

        ``procedure_name`` is inserted verbatim after ``CALL``. It presumably
        carries the argument list with ``:name`` placeholders
        (e.g. ``"my_proc(:a, :b)"``), whose values are bound from ``params``.

        SECURITY: ``procedure_name`` is not escaped or validated, so never
        build it from untrusted input.

        Returns:
            The fetched rows, or ``[]`` when the procedure returns no
            result set.
        """
        with self.session() as db:
            try:
                result = db.execute(text(f"CALL {procedure_name}"), params
                                    or {})
                # Fetch before commit. A CALL without a result set has no rows
                # (fetchall() would raise after the work had been committed).
                rows = result.fetchall() if result.returns_rows else []
                db.commit()
                return rows
            except Exception as e:
                logger.error(f"执行存储过程出错: [{procedure_name}] {e}")
                raise

    def execute_function(self,
                         function_name: str,
                         params: Optional[dict] = None) -> Any:
        """Evaluate ``SELECT <function_name>`` and return the scalar result.

        ``function_name`` is inserted verbatim, so it presumably includes the
        call's argument list with ``:name`` placeholders bound from
        ``params``. SECURITY: it is not escaped; never build it from untrusted
        input. No commit is issued.
        """
        with self.session() as db:
            try:
                result = db.execute(text(f"SELECT {function_name}"), params
                                    or {})
                return result.scalar()
            except Exception as e:
                logger.error(f"执行数据库函数出错: [{function_name}] {e}")
                raise
|