|
@@ -0,0 +1,147 @@
|
|
|
+from typing import Iterator, List, Optional, Any
|
|
|
+from contextlib import contextmanager
|
|
|
+from sqlalchemy import create_engine, text
|
|
|
+from sqlalchemy.orm import Session, sessionmaker
|
|
|
+from app.config import config
|
|
|
+from app.logger import logger
|
|
|
+
|
|
|
+
|
|
|
+class Database:
|
|
|
+ """数据库单例类"""
|
|
|
+ _instance = None
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def initialize():
|
|
|
+ """初始化数据库连接
|
|
|
+ """
|
|
|
+ # 调用单例实例会触发_init_db
|
|
|
+ return Database()
|
|
|
+
|
|
|
+ def __new__(cls):
|
|
|
+ if cls._instance is None:
|
|
|
+ cls._instance = super().__new__(cls)
|
|
|
+ cls._instance._init_db()
|
|
|
+ return cls._instance
|
|
|
+
|
|
|
+ def _init_db(self):
|
|
|
+ """初始化数据库连接"""
|
|
|
+ db_config = config.database
|
|
|
+ try:
|
|
|
+ self.engine = create_engine(self._get_sqlalchemy_uri(
|
|
|
+ host=db_config.host,
|
|
|
+ port=db_config.port,
|
|
|
+ user=db_config.user,
|
|
|
+ password=db_config.password,
|
|
|
+ name=db_config.name),
|
|
|
+ pool_pre_ping=True,
|
|
|
+ pool_recycle=3600)
|
|
|
+
|
|
|
+ # 测试连接
|
|
|
+ with self.engine.connect() as conn:
|
|
|
+ conn.execute(text("SELECT 1"))
|
|
|
+
|
|
|
+ logger.debug("数据库连接成功")
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"数据库连接失败: {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+ self.SessionLocal = sessionmaker(autocommit=False,
|
|
|
+ autoflush=False,
|
|
|
+ bind=self.engine)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _get_sqlalchemy_uri(host: str, port: int, user: str, password: str,
|
|
|
+ name: str) -> str:
|
|
|
+ """生成SQLAlchemy连接字符串"""
|
|
|
+ sql = f"mysql+pymysql://{user}:{password}@{host}:{port}/{name}?charset=utf8mb4"
|
|
|
+ logger.debug(f"数据库连接字符串: {sql}")
|
|
|
+ return sql
|
|
|
+
|
|
|
+ @contextmanager
|
|
|
+ def session(self) -> Iterator[Session]:
|
|
|
+ """获取数据库会话"""
|
|
|
+ db = self.SessionLocal()
|
|
|
+ try:
|
|
|
+ yield db
|
|
|
+ except Exception:
|
|
|
+ db.rollback()
|
|
|
+ raise
|
|
|
+ finally:
|
|
|
+ db.close()
|
|
|
+
|
|
|
+ def query(self, sql: str, params: Optional[dict] = None) -> List[dict]:
|
|
|
+ """执行查询SQL"""
|
|
|
+ with self.session() as db:
|
|
|
+ try:
|
|
|
+ result = db.execute(text(sql), params or {})
|
|
|
+ return [dict(row) for row in result.mappings()]
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"查询SQL出错: [{sql}] {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+ def execute(self, sql: str, params: Optional[dict] = None) -> int:
|
|
|
+ """执行非查询SQL"""
|
|
|
+ with self.session() as db:
|
|
|
+ try:
|
|
|
+ result = db.execute(text(sql), params or {})
|
|
|
+ db.commit()
|
|
|
+ return result.rowcount
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"执行SQL出错: [{sql}] {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+ def batch_execute(self, sql: str, params_list: List[dict]) -> int:
|
|
|
+ """批量执行SQL
|
|
|
+ 注意:所有操作在单个事务中执行,要么全部成功,要么全部回滚
|
|
|
+ """
|
|
|
+ if not params_list:
|
|
|
+ return 0
|
|
|
+
|
|
|
+ with self.session() as db:
|
|
|
+ try:
|
|
|
+ result = db.execute(text(sql), params_list)
|
|
|
+ db.commit()
|
|
|
+ return result.rowcount
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"批量执行SQL出错: [{sql}] {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+ def query_one(self,
|
|
|
+ sql: str,
|
|
|
+ params: Optional[dict] = None) -> Optional[dict]:
|
|
|
+ """执行查询SQL并返回单条记录"""
|
|
|
+ with self.session() as db:
|
|
|
+ try:
|
|
|
+ result = db.execute(text(sql), params or {})
|
|
|
+ row = result.mappings().first()
|
|
|
+ return dict(row) if row else None
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"查询SQL出错: [{sql}] {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+ def execute_procedure(self,
|
|
|
+ procedure_name: str,
|
|
|
+ params: Optional[dict] = None) -> Any:
|
|
|
+ """执行存储过程"""
|
|
|
+ with self.session() as db:
|
|
|
+ try:
|
|
|
+ result = db.execute(text(f"CALL {procedure_name}"), params
|
|
|
+ or {})
|
|
|
+ db.commit()
|
|
|
+ return result.fetchall()
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"执行存储过程出错: [{procedure_name}] {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+ def execute_function(self,
|
|
|
+ function_name: str,
|
|
|
+ params: Optional[dict] = None) -> Any:
|
|
|
+ """执行数据库函数"""
|
|
|
+ with self.session() as db:
|
|
|
+ try:
|
|
|
+ result = db.execute(text(f"SELECT {function_name}"), params
|
|
|
+ or {})
|
|
|
+ return result.scalar()
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"执行数据库函数出错: [{function_name}] {e}")
|
|
|
+ raise
|