|
@@ -0,0 +1,122 @@
|
|
|
+from typing import Dict, Optional, Any, List, Tuple
|
|
|
+
|
|
|
+from sqlalchemy import create_engine, text
|
|
|
+from sqlalchemy.engine import Engine
|
|
|
+from sqlalchemy.orm import sessionmaker
|
|
|
+
|
|
|
+from .base import DBHelper
|
|
|
+
|
|
|
+
|
|
|
+class SQLServerHelper(DBHelper):
|
|
|
+ def __init__(self):
|
|
|
+ super().__init__()
|
|
|
+ self._engines: Dict[str, Engine] = {}
|
|
|
+ self._session_makers: Dict[str, sessionmaker] = {}
|
|
|
+ self._default_config = {
|
|
|
+ 'driver': 'ODBC Driver 17 for SQL Server',
|
|
|
+ 'server': 'localhost',
|
|
|
+ 'username': '',
|
|
|
+ 'password': '',
|
|
|
+ 'trusted_connection': 'yes'
|
|
|
+ }
|
|
|
+ self._pool_config = {
|
|
|
+ 'pool_size': 5, # 减少初始连接数以降低资源占用
|
|
|
+ 'max_overflow': 10, # 适当减少最大溢出连接数
|
|
|
+ 'pool_timeout': 60, # 增加池等待超时时间
|
|
|
+ 'pool_recycle': 1800, # 每30分钟回收连接
|
|
|
+ 'pool_pre_ping': True, # 启用连接健康检查
|
|
|
+ 'connect_args': {
|
|
|
+ 'timeout': 60, # 连接超时时间
|
|
|
+ 'driver_connects_timeout': 60, # 驱动连接超时
|
|
|
+ 'connect_timeout': 60, # ODBC连接超时
|
|
|
+ 'connect_retries': 3, # 连接重试次数
|
|
|
+ 'connect_retry_interval': 10, # 重试间隔增加到10秒
|
|
|
+ 'connection_timeout': 60 # 额外的连接超时设置
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ self._main_database_name = "sqlserver_mian_2024"
|
|
|
+ def _build_connection_string(self, database: str, config: Optional[Dict[str, str]] = None) -> str:
|
|
|
+ """构建连接字符串"""
|
|
|
+ conn_config = self._default_config.copy()
|
|
|
+ db_config = self.get_config_for_database(database)
|
|
|
+ conn_config.update(db_config)
|
|
|
+ if config:
|
|
|
+ conn_config.update(config)
|
|
|
+
|
|
|
+ # 构建认证字符串
|
|
|
+ auth_params = []
|
|
|
+ if conn_config.get('trusted_connection', True):
|
|
|
+ auth_params.append("Trusted_Connection=yes")
|
|
|
+ else:
|
|
|
+ auth_params.extend([
|
|
|
+ f"UID={conn_config['username']}",
|
|
|
+ f"PWD={conn_config['password']}"
|
|
|
+ ])
|
|
|
+
|
|
|
+ # 构建ODBC连接字符串
|
|
|
+ conn_parts = [
|
|
|
+ f"DRIVER={conn_config['driver']}",
|
|
|
+ f"SERVER={conn_config['server']}",
|
|
|
+ f"DATABASE={conn_config['database'] if 'database' in conn_config else database}",
|
|
|
+ "CHARSET=UTF-8"
|
|
|
+ ]
|
|
|
+ conn_parts.extend(auth_params)
|
|
|
+
|
|
|
+ # 构建SQLAlchemy连接URL
|
|
|
+ conn_str = ";".join(conn_parts)
|
|
|
+ conn_url = f"mssql+pyodbc:///?odbc_connect={conn_str}"
|
|
|
+
|
|
|
+ return conn_url
|
|
|
+
|
|
|
+ def get_engine(self, database: str="", config: Optional[Dict[str, str]] = None) -> Engine:
|
|
|
+ database = database or self._main_database_name
|
|
|
+ """获取或创建数据库引擎"""
|
|
|
+ if database not in self._engines:
|
|
|
+ conn_str = self._build_connection_string(database, config)
|
|
|
+ engine = create_engine(conn_str, **self._pool_config)
|
|
|
+ # 预热连接池
|
|
|
+ with engine.connect() as conn:
|
|
|
+ conn.execute(text("SELECT 1"))
|
|
|
+ self._engines[database] = engine
|
|
|
+ return self._engines[database]
|
|
|
+
|
|
|
+ def execute_query(self, database: str, query: str, params: Optional[Dict[str, Any]] = None) -> List[Tuple]:
|
|
|
+ """执行查询并返回结果"""
|
|
|
+ with self.session_scope(database) as session:
|
|
|
+ result = session.execute(text(query), params or {})
|
|
|
+ return [tuple(row) for row in result.fetchall()]
|
|
|
+
|
|
|
+ def execute_non_query(self, database: str, query: str, params: Optional[Dict[str, Any]] = None) -> int:
|
|
|
+ """执行非查询操作(如INSERT, UPDATE, DELETE)"""
|
|
|
+ with self.session_scope(database) as session:
|
|
|
+ result = session.execute(text(query), params or {})
|
|
|
+ return result.rowcount
|
|
|
+
|
|
|
+ def execute_scalar(self, database: str, query: str, params: Optional[Dict[str, Any]] = None) -> Any:
|
|
|
+ """执行查询并返回第一行第一列的值"""
|
|
|
+ with self.session_scope(database) as session:
|
|
|
+ result = session.execute(text(query), params or {})
|
|
|
+ row = result.fetchone()
|
|
|
+ return row[0] if row else None
|
|
|
+
|
|
|
+ def execute_procedure(self, database: str, procedure_name: str, params: Optional[Dict[str, Any]] = None) -> List[Tuple]:
|
|
|
+ """执行存储过程"""
|
|
|
+ params = params or {}
|
|
|
+ param_str = ", ".join([f"@{key}=:{key}" for key in params.keys()])
|
|
|
+ query = f"EXEC {procedure_name} {param_str}"
|
|
|
+
|
|
|
+ with self.session_scope(database) as session:
|
|
|
+ result = session.execute(text(query), params)
|
|
|
+ return [tuple(row) for row in result.fetchall()]
|
|
|
+
|
|
|
+ def dispose_all(self) -> None:
|
|
|
+ """释放所有数据库引擎资源"""
|
|
|
+ for engine in self._engines.values():
|
|
|
+ engine.dispose()
|
|
|
+ self._engines.clear()
|
|
|
+ self._session_makers.clear()
|
|
|
+
|
|
|
+ def __del__(self):
|
|
|
+ """析构函数,确保所有引擎资源被释放"""
|
|
|
+ self.dispose_all()
|