database.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. from typing import Iterator
  2. from contextlib import contextmanager
  3. from sqlalchemy import create_engine, text
  4. from sqlalchemy.orm import sessionmaker, Session
  5. from sqlalchemy.exc import OperationalError
  6. from ..config import config
  7. from ..logger import logger
  8. class Database:
  9. """简化版数据库单例类"""
  10. _instance = None
  11. def __new__(cls):
  12. if cls._instance is None:
  13. cls._instance = super().__new__(cls)
  14. cls._instance._init_db()
  15. return cls._instance
  16. def _init_db(self):
  17. """初始化数据库连接"""
  18. db_config = config.database
  19. self.engine = create_engine(
  20. f"mysql+pymysql://{db_config.user}:{db_config.password}@{db_config.host}:{db_config.port}/{db_config.name}?charset=utf8mb4"
  21. )
  22. self.SessionLocal = sessionmaker(autocommit=False,
  23. autoflush=False,
  24. bind=self.engine)
  25. self.test_connection()
  26. @contextmanager
  27. def session(self) -> Iterator[Session]:
  28. """获取数据库会话"""
  29. db = self.SessionLocal()
  30. try:
  31. yield db
  32. db.commit()
  33. except Exception:
  34. db.rollback()
  35. raise
  36. finally:
  37. db.close()
  38. @classmethod
  39. def initialize(cls):
  40. """初始化数据库"""
  41. return cls()
  42. def test_connection(self):
  43. """测试数据库连接"""
  44. try:
  45. with self.engine.connect() as conn:
  46. result = conn.execute(text("SELECT 1"))
  47. if result.scalar() == 1:
  48. # logger.info(f"数据库 [{config.database.name}] 连接成功。")
  49. return True
  50. return False
  51. except OperationalError as e:
  52. logger.error(f"数据库 [{config.database.name}] 连接失败: {e}")
  53. return False
  54. except Exception as e:
  55. logger.error(f"未知错误: {e}")
  56. return False