from typing import Iterator
from contextlib import contextmanager
from urllib.parse import quote

from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.exc import OperationalError

from ..config import config
from ..logger import logger


class Database:
    """Simplified database singleton.

    The first instantiation builds the SQLAlchemy engine and session factory
    from ``config.database`` and runs a connectivity check. Every later
    instantiation returns that same object.
    """

    # The single shared instance; ``None`` until the first successful init.
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            instance = super().__new__(cls)
            # Initialise BEFORE publishing the instance: if _init_db raises
            # (e.g. bad config), the next Database() call retries instead of
            # returning a half-built object that lacks ``engine``.
            instance._init_db()
            cls._instance = instance
        return cls._instance

    def _init_db(self):
        """Create the engine and session factory, then test the connection.

        The boolean returned by :meth:`test_connection` is not enforced here;
        a failed check is only logged by that method.
        """
        db_config = config.database
        # Percent-encode the credentials so characters such as '@', ':', '/',
        # '#' or '%' in the user name or password cannot corrupt the URL.
        # SQLAlchemy unquotes these components when it parses the URL.
        user = quote(str(db_config.user), safe="")
        password = quote(str(db_config.password), safe="")
        self.engine = create_engine(
            f"mysql+pymysql://{user}:{password}@{db_config.host}:{db_config.port}/{db_config.name}?charset=utf8mb4"
        )
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
        self.test_connection()

    @contextmanager
    def session(self) -> Iterator[Session]:
        """Yield a database session scoped to a ``with`` block.

        The session is committed if the block exits normally. It is rolled
        back and the exception re-raised if the block raises an ``Exception``.
        It is always closed afterwards.
        """
        db = self.SessionLocal()
        try:
            yield db
            db.commit()
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

    @classmethod
    def initialize(cls):
        """Initialise the database and return the shared ``Database`` instance."""
        return cls()

    def test_connection(self):
        """Check connectivity by executing ``SELECT 1``.

        Returns:
            True if the query returns 1, otherwise False. Connection errors
            are logged and reported as False rather than raised.
        """
        try:
            with self.engine.connect() as conn:
                result = conn.execute(text("SELECT 1"))
                return result.scalar() == 1
        except OperationalError as e:
            logger.error(f"数据库 [{config.database.name}] 连接失败: {e}")
            return False
        except Exception as e:
            logger.error(f"未知错误: {e}")
            return False