|
@@ -1,22 +1,16 @@
|
|
-from typing import Iterator, List, Optional, Any
|
|
|
|
|
|
+from typing import Iterator
|
|
from contextlib import contextmanager
|
|
from contextlib import contextmanager
|
|
from sqlalchemy import create_engine, text
|
|
from sqlalchemy import create_engine, text
|
|
-from sqlalchemy.orm import Session, sessionmaker
|
|
|
|
-from app.config import config
|
|
|
|
-from app.logger import logger
|
|
|
|
|
|
+from sqlalchemy.orm import sessionmaker, Session
|
|
|
|
+from sqlalchemy.exc import OperationalError
|
|
|
|
+from ..config import config
|
|
|
|
+from ..logger import logger
|
|
|
|
|
|
|
|
|
|
class Database:
|
|
class Database:
|
|
- """数据库单例类"""
|
|
|
|
|
|
+ """简化版数据库单例类"""
|
|
_instance = None
|
|
_instance = None
|
|
|
|
|
|
- @staticmethod
|
|
|
|
- def initialize():
|
|
|
|
- """初始化数据库连接
|
|
|
|
- """
|
|
|
|
- # 调用单例实例会触发_init_db
|
|
|
|
- return Database()
|
|
|
|
-
|
|
|
|
def __new__(cls):
|
|
def __new__(cls):
|
|
if cls._instance is None:
|
|
if cls._instance is None:
|
|
cls._instance = super().__new__(cls)
|
|
cls._instance = super().__new__(cls)
|
|
@@ -26,36 +20,14 @@ class Database:
|
|
def _init_db(self):
|
|
def _init_db(self):
|
|
"""初始化数据库连接"""
|
|
"""初始化数据库连接"""
|
|
db_config = config.database
|
|
db_config = config.database
|
|
- try:
|
|
|
|
- self.engine = create_engine(self._get_sqlalchemy_uri(
|
|
|
|
- host=db_config.host,
|
|
|
|
- port=db_config.port,
|
|
|
|
- user=db_config.user,
|
|
|
|
- password=db_config.password,
|
|
|
|
- name=db_config.name),
|
|
|
|
- pool_pre_ping=True,
|
|
|
|
- pool_recycle=3600)
|
|
|
|
-
|
|
|
|
- # 测试连接
|
|
|
|
- with self.engine.connect() as conn:
|
|
|
|
- conn.execute(text("SELECT 1"))
|
|
|
|
-
|
|
|
|
- logger.debug("数据库连接成功")
|
|
|
|
- except Exception as e:
|
|
|
|
- logger.warning(f"数据库连接失败: {e}")
|
|
|
|
- raise
|
|
|
|
-
|
|
|
|
|
|
+ self.engine = create_engine(
|
|
|
|
+ f"mysql+pymysql://{db_config.user}:{db_config.password}@{db_config.host}:{db_config.port}/{db_config.name}?charset=utf8mb4"
|
|
|
|
+ )
|
|
self.SessionLocal = sessionmaker(autocommit=False,
|
|
self.SessionLocal = sessionmaker(autocommit=False,
|
|
autoflush=False,
|
|
autoflush=False,
|
|
bind=self.engine)
|
|
bind=self.engine)
|
|
|
|
|
|
- @staticmethod
|
|
|
|
- def _get_sqlalchemy_uri(host: str, port: int, user: str, password: str,
|
|
|
|
- name: str) -> str:
|
|
|
|
- """生成SQLAlchemy连接字符串"""
|
|
|
|
- sql = f"mysql+pymysql://{user}:{password}@{host}:{port}/{name}?charset=utf8mb4"
|
|
|
|
- # logger.debug(f"数据库连接字符串: {sql}")
|
|
|
|
- return sql
|
|
|
|
|
|
+ self.test_connection()
|
|
|
|
|
|
@contextmanager
|
|
@contextmanager
|
|
def session(self) -> Iterator[Session]:
|
|
def session(self) -> Iterator[Session]:
|
|
@@ -63,85 +35,30 @@ class Database:
|
|
db = self.SessionLocal()
|
|
db = self.SessionLocal()
|
|
try:
|
|
try:
|
|
yield db
|
|
yield db
|
|
|
|
+ db.commit()
|
|
except Exception:
|
|
except Exception:
|
|
db.rollback()
|
|
db.rollback()
|
|
raise
|
|
raise
|
|
finally:
|
|
finally:
|
|
db.close()
|
|
db.close()
|
|
|
|
|
|
- def query(self, sql: str, params: Optional[dict] = None) -> List[dict]:
|
|
|
|
- """执行查询SQL"""
|
|
|
|
- with self.session() as db:
|
|
|
|
- try:
|
|
|
|
- result = db.execute(text(sql), params or {})
|
|
|
|
- return [dict(row) for row in result.mappings()]
|
|
|
|
- except Exception as e:
|
|
|
|
- logger.error(f"查询SQL出错: [{sql}] {e}")
|
|
|
|
- raise
|
|
|
|
|
|
+ @classmethod
|
|
|
|
+ def initialize(cls):
|
|
|
|
+ """初始化数据库"""
|
|
|
|
+ return cls()
|
|
|
|
|
|
- def execute(self, sql: str, params: Optional[dict] = None) -> int:
|
|
|
|
- """执行非查询SQL"""
|
|
|
|
- with self.session() as db:
|
|
|
|
- try:
|
|
|
|
- result = db.execute(text(sql), params or {})
|
|
|
|
- db.commit()
|
|
|
|
- return result.rowcount
|
|
|
|
- except Exception as e:
|
|
|
|
- logger.error(f"执行SQL出错: [{sql}] {e}")
|
|
|
|
- raise
|
|
|
|
-
|
|
|
|
- def batch_execute(self, sql: str, params_list: List[dict]) -> int:
|
|
|
|
- """批量执行SQL
|
|
|
|
- 注意:所有操作在单个事务中执行,要么全部成功,要么全部回滚
|
|
|
|
- """
|
|
|
|
- if not params_list:
|
|
|
|
- return 0
|
|
|
|
-
|
|
|
|
- with self.session() as db:
|
|
|
|
- try:
|
|
|
|
- result = db.execute(text(sql), params_list)
|
|
|
|
- db.commit()
|
|
|
|
- return result.rowcount
|
|
|
|
- except Exception as e:
|
|
|
|
- logger.error(f"批量执行SQL出错: [{sql}] {e}")
|
|
|
|
- raise
|
|
|
|
-
|
|
|
|
- def query_one(self,
|
|
|
|
- sql: str,
|
|
|
|
- params: Optional[dict] = None) -> Optional[dict]:
|
|
|
|
- """执行查询SQL并返回单条记录"""
|
|
|
|
- with self.session() as db:
|
|
|
|
- try:
|
|
|
|
- result = db.execute(text(sql), params or {})
|
|
|
|
- row = result.mappings().first()
|
|
|
|
- return dict(row) if row else None
|
|
|
|
- except Exception as e:
|
|
|
|
- logger.error(f"查询SQL出错: [{sql}] {e}")
|
|
|
|
- raise
|
|
|
|
-
|
|
|
|
- def execute_procedure(self,
|
|
|
|
- procedure_name: str,
|
|
|
|
- params: Optional[dict] = None) -> Any:
|
|
|
|
- """执行存储过程"""
|
|
|
|
- with self.session() as db:
|
|
|
|
- try:
|
|
|
|
- result = db.execute(text(f"CALL {procedure_name}"), params
|
|
|
|
- or {})
|
|
|
|
- db.commit()
|
|
|
|
- return result.fetchall()
|
|
|
|
- except Exception as e:
|
|
|
|
- logger.error(f"执行存储过程出错: [{procedure_name}] {e}")
|
|
|
|
- raise
|
|
|
|
-
|
|
|
|
- def execute_function(self,
|
|
|
|
- function_name: str,
|
|
|
|
- params: Optional[dict] = None) -> Any:
|
|
|
|
- """执行数据库函数"""
|
|
|
|
- with self.session() as db:
|
|
|
|
- try:
|
|
|
|
- result = db.execute(text(f"SELECT {function_name}"), params
|
|
|
|
- or {})
|
|
|
|
- return result.scalar()
|
|
|
|
- except Exception as e:
|
|
|
|
- logger.error(f"执行数据库函数出错: [{function_name}] {e}")
|
|
|
|
- raise
|
|
|
|
|
|
+ def test_connection(self):
|
|
|
|
+ """测试数据库连接"""
|
|
|
|
+ try:
|
|
|
|
+ with self.engine.connect() as conn:
|
|
|
|
+ result = conn.execute(text("SELECT 1"))
|
|
|
|
+ if result.scalar() == 1:
|
|
|
|
+ # logger.info(f"数据库 [{config.database.name}] 连接成功。")
|
|
|
|
+ return True
|
|
|
|
+ return False
|
|
|
|
+ except OperationalError as e:
|
|
|
|
+ logger.error(f"数据库 [{config.database.name}] 连接失败: {e}")
|
|
|
|
+ return False
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"未知错误: {e}")
|
|
|
|
+ return False
|