12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- from typing import Iterator
- from contextlib import contextmanager
- from sqlalchemy import create_engine, text
- from sqlalchemy.orm import sessionmaker, Session
- from sqlalchemy.exc import OperationalError
- from ..config import config
- from ..logger import logger
class Database:
    """Simplified database singleton.

    The first instantiation builds a SQLAlchemy engine for the MySQL database
    described by ``config.database`` plus a session factory bound to it;
    every later instantiation returns that same object.
    """

    # Shared instance; assigned only after ``_init_db`` has completed.
    _instance: "Database | None" = None

    def __new__(cls):
        """Return the shared instance, creating and initialising it on first use.

        The new object is stored on ``cls._instance`` only after ``_init_db``
        returns. If initialisation raises, the exception propagates and the
        next call retries, instead of every later call receiving a
        half-initialised object that lacks ``engine`` / ``SessionLocal``.

        NOTE(review): there is no locking, so concurrent first calls from
        several threads may each build an engine.
        """
        if cls._instance is None:
            instance = super().__new__(cls)
            instance._init_db()
            cls._instance = instance
        return cls._instance

    @staticmethod
    def _build_url(db_config) -> str:
        """Build the ``mysql+pymysql`` connection URL from *db_config*.

        *db_config* must expose ``user``, ``password``, ``host``, ``port`` and
        ``name`` attributes. User and password are percent-encoded so that
        characters such as ``@``, ``:``, ``/``, ``#``, ``%`` or spaces cannot
        corrupt URL parsing. ``str()`` is applied first so non-string values
        (e.g. a numeric password from a config file) behave as they did with
        plain f-string interpolation.
        """
        from urllib.parse import quote

        user = quote(str(db_config.user), safe="")
        password = quote(str(db_config.password), safe="")
        return (
            f"mysql+pymysql://{user}:{password}"
            f"@{db_config.host}:{db_config.port}/{db_config.name}?charset=utf8mb4"
        )

    def _init_db(self):
        """Create the engine and session factory, then probe the connection.

        The probe result is not acted on. A failed connection is only logged
        by ``test_connection``, and initialisation still completes.
        """
        db_config = config.database
        self.engine = create_engine(
            self._build_url(db_config),
            # Check pooled connections on checkout so ones the server has
            # closed (e.g. after MySQL's wait_timeout) are transparently
            # replaced instead of failing the next query.
            pool_pre_ping=True,
        )
        self.SessionLocal = sessionmaker(
            autocommit=False,
            autoflush=False,
            bind=self.engine,
        )
        self.test_connection()

    @contextmanager
    def session(self) -> Iterator["Session"]:
        """Yield a database session scoped to a ``with`` block.

        The session is committed when the block exits normally. Any
        ``Exception`` raised in the block, or by the commit itself, triggers
        a rollback and is re-raised. The session is always closed.
        """
        db = self.SessionLocal()
        try:
            yield db
            db.commit()
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

    @classmethod
    def initialize(cls) -> "Database":
        """Create (or fetch) the singleton instance; equivalent to ``cls()``."""
        return cls()

    def test_connection(self) -> bool:
        """Check connectivity by executing ``SELECT 1``.

        Returns True when the query yields 1, otherwise False. Connection
        errors (``OperationalError``) and any other ``Exception`` are logged
        and reported as False rather than raised.
        """
        try:
            with self.engine.connect() as conn:
                return conn.execute(text("SELECT 1")).scalar() == 1
        except OperationalError as e:
            logger.error(f"数据库 [{config.database.name}] 连接失败: {e}")
            return False
        except Exception as e:
            logger.error(f"未知错误: {e}")
            return False
|