database.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. from typing import Iterator, List, Optional, Any
  2. from contextlib import contextmanager
  3. from sqlalchemy import create_engine, text
  4. from sqlalchemy.orm import Session, sessionmaker
  5. from app.config import config
  6. from app.logger import logger
  7. class Database:
  8. """数据库单例类"""
  9. _instance = None
  10. @staticmethod
  11. def initialize():
  12. """初始化数据库连接
  13. """
  14. # 调用单例实例会触发_init_db
  15. return Database()
  16. def __new__(cls):
  17. if cls._instance is None:
  18. cls._instance = super().__new__(cls)
  19. cls._instance._init_db()
  20. return cls._instance
  21. def _init_db(self):
  22. """初始化数据库连接"""
  23. db_config = config.database
  24. try:
  25. self.engine = create_engine(self._get_sqlalchemy_uri(
  26. host=db_config.host,
  27. port=db_config.port,
  28. user=db_config.user,
  29. password=db_config.password,
  30. name=db_config.name),
  31. pool_pre_ping=True,
  32. pool_recycle=3600)
  33. # 测试连接
  34. with self.engine.connect() as conn:
  35. conn.execute(text("SELECT 1"))
  36. logger.debug("数据库连接成功")
  37. except Exception as e:
  38. logger.warning(f"数据库连接失败: {e}")
  39. raise
  40. self.SessionLocal = sessionmaker(autocommit=False,
  41. autoflush=False,
  42. bind=self.engine)
  43. @staticmethod
  44. def _get_sqlalchemy_uri(host: str, port: int, user: str, password: str,
  45. name: str) -> str:
  46. """生成SQLAlchemy连接字符串"""
  47. sql = f"mysql+pymysql://{user}:{password}@{host}:{port}/{name}?charset=utf8mb4"
  48. # logger.debug(f"数据库连接字符串: {sql}")
  49. return sql
  50. @contextmanager
  51. def session(self) -> Iterator[Session]:
  52. """获取数据库会话"""
  53. db = self.SessionLocal()
  54. try:
  55. yield db
  56. except Exception:
  57. db.rollback()
  58. raise
  59. finally:
  60. db.close()
  61. def query(self, sql: str, params: Optional[dict] = None) -> List[dict]:
  62. """执行查询SQL"""
  63. with self.session() as db:
  64. try:
  65. result = db.execute(text(sql), params or {})
  66. return [dict(row) for row in result.mappings()]
  67. except Exception as e:
  68. logger.error(f"查询SQL出错: [{sql}] {e}")
  69. raise
  70. def execute(self, sql: str, params: Optional[dict] = None) -> int:
  71. """执行非查询SQL"""
  72. with self.session() as db:
  73. try:
  74. result = db.execute(text(sql), params or {})
  75. db.commit()
  76. return result.rowcount
  77. except Exception as e:
  78. logger.error(f"执行SQL出错: [{sql}] {e}")
  79. raise
  80. def batch_execute(self, sql: str, params_list: List[dict]) -> int:
  81. """批量执行SQL
  82. 注意:所有操作在单个事务中执行,要么全部成功,要么全部回滚
  83. """
  84. if not params_list:
  85. return 0
  86. with self.session() as db:
  87. try:
  88. result = db.execute(text(sql), params_list)
  89. db.commit()
  90. return result.rowcount
  91. except Exception as e:
  92. logger.error(f"批量执行SQL出错: [{sql}] {e}")
  93. raise
  94. def query_one(self,
  95. sql: str,
  96. params: Optional[dict] = None) -> Optional[dict]:
  97. """执行查询SQL并返回单条记录"""
  98. with self.session() as db:
  99. try:
  100. result = db.execute(text(sql), params or {})
  101. row = result.mappings().first()
  102. return dict(row) if row else None
  103. except Exception as e:
  104. logger.error(f"查询SQL出错: [{sql}] {e}")
  105. raise
  106. def execute_procedure(self,
  107. procedure_name: str,
  108. params: Optional[dict] = None) -> Any:
  109. """执行存储过程"""
  110. with self.session() as db:
  111. try:
  112. result = db.execute(text(f"CALL {procedure_name}"), params
  113. or {})
  114. db.commit()
  115. return result.fetchall()
  116. except Exception as e:
  117. logger.error(f"执行存储过程出错: [{procedure_name}] {e}")
  118. raise
  119. def execute_function(self,
  120. function_name: str,
  121. params: Optional[dict] = None) -> Any:
  122. """执行数据库函数"""
  123. with self.session() as db:
  124. try:
  125. result = db.execute(text(f"SELECT {function_name}"), params
  126. or {})
  127. return result.scalar()
  128. except Exception as e:
  129. logger.error(f"执行数据库函数出错: [{function_name}] {e}")
  130. raise