import asyncio
import concurrent.futures
import inspect
from datetime import datetime

from sqlalchemy import event
from sqlalchemy.orm import Query

from core.constant import CommonConstant
from core.current_user import CurrentUser, CurrentUserProvider
from core.decorators.data_scope_decorators import DataScopeConfig
from core.enums import DataScopeTypeEnum, DataScopeEnum
from domain.models.base_model import (
    BaseModel,
    CreateModelBase,
    UpdateModelBase,
    SoftDeleteModelBase,
)
from utils import logger


class SqlalchemyEventRegister:
    """Registers SQLAlchemy listeners for audit fields and data-permission filtering.

    - Mapper events (``before_insert`` / ``before_update`` / ``before_delete``)
      on ``BaseModel`` fill the creation/update/deletion timestamps (the acting
      user is not resolved yet, see the TODOs below).
    - The legacy ``Query`` ``before_compile`` event appends data-scope filters
      when a function/method on the call stack carries a ``DataScopeConfig`` mark.
    """

    def __init__(self):
        pass

    @classmethod
    def register(cls):
        """Attach all listeners. Intended to be called once at start-up."""
        # propagate=True: listeners registered on a base class only reach the
        # mappers of its subclasses when propagation is requested.
        event.listen(BaseModel, "before_insert", cls.before_insert, propagate=True)
        event.listen(BaseModel, "before_update", cls.before_update, propagate=True)
        event.listen(BaseModel, "before_delete", cls.before_delete, propagate=True)
        # retval=True: Query is generative, so the filtered query returned by
        # before_query must be handed back to SQLAlchemy; otherwise it is discarded.
        event.listen(Query, "before_compile", cls.before_query, retval=True)
        logger.info("SqlalchemyEvent 注册完成")

    @staticmethod
    def before_insert(_mapper, _connection, instance):
        """Fill the creator and creation time before a row is inserted.

        Args:
            _mapper: mapper of the instance (unused).
            _connection: connection used for the INSERT (unused).
            instance: the model instance being inserted.
        """
        if not isinstance(instance, CreateModelBase):
            return
        if not instance.created_by:
            # TODO: the acting user is not resolved here yet (the former
            # UserContext lookup was removed). Mapper events can fire outside a
            # request context, so the user has to come from the request context
            # or be passed explicitly; until then created_by is left untouched.
            current_user = None
            if current_user:
                instance.created_by = current_user
        # Only set the creation time if it was not already provided.
        if not instance.created_time:
            instance.created_time = datetime.now()

    @staticmethod
    def before_update(_mapper, _connection, instance):
        """Fill the updater and update time before a row is updated.

        Args:
            _mapper: mapper of the instance (unused).
            _connection: connection used for the UPDATE (unused).
            instance: the model instance being updated.
        """
        if not isinstance(instance, UpdateModelBase):
            return
        # TODO: resolve the acting user (see before_insert); currently a no-op.
        current_user = None
        if current_user:
            instance.updated_by = current_user
        # Always refresh the update time.
        instance.updated_time = datetime.now()

    @staticmethod
    def before_delete(_mapper, _connection, instance):
        """Fill deletion metadata for soft-deleted rows before a delete.

        Args:
            _mapper: mapper of the instance (unused).
            _connection: connection used for the DELETE (unused).
            instance: the model instance being deleted.
        """
        if (
            isinstance(instance, SoftDeleteModelBase)
            and hasattr(instance, "is_del")
            and instance.is_del == 1
        ):
            # NOTE(review): before_delete fires for session.delete(), i.e. a
            # physical DELETE, so values set here are presumably not persisted;
            # confirm whether soft deletes go through an UPDATE instead.
            # TODO: resolve the acting user (see before_insert); currently a no-op.
            current_user = None
            if current_user:
                instance.deleted_by = current_user
            instance.deleted_time = datetime.now()

    @classmethod
    def before_query(cls, query: Query, *_args, **_kwargs):
        """Apply data-permission filters to a legacy ``Query`` before compilation.

        Registered on ``Query`` ``before_compile`` with ``retval=True``.

        Args:
            query: the query about to be compiled.

        Returns:
            Query: the filtered query, or ``query`` unchanged when no data scope
            applies.
        """
        if not query.column_descriptions:
            return query

        # Data scope is decided by the primary (first) entity of the query.
        entity_class = query.column_descriptions[0].get("type")
        if not entity_class:
            return query

        data_scope_config = cls._check_decorator_marks()
        if not data_scope_config or data_scope_config.disable_scope:
            return query

        if data_scope_config.scope_type == DataScopeTypeEnum.DEPT:
            handler = cls._dept_data_scope
        elif data_scope_config.scope_type == DataScopeTypeEnum.CUSTOM:
            handler = cls._custom_data_scope
        else:
            return query

        # Only look the user up once a data scope actually applies.
        # NOTE(review): current_user is assumed non-None here; a None user fails
        # in the handlers with AttributeError - confirm whether no-user contexts
        # should be rejected or skipped explicitly.
        current_user = CurrentUserProvider.get_current_user()
        # Query.filter() refuses to run on a query that already has LIMIT/OFFSET
        # (e.g. paginated lists) unless assertions are disabled.
        query = query.enable_assertions(False)
        return handler(query, entity_class, data_scope_config, current_user)

    @staticmethod
    def _check_decorator_marks():
        """Walk up the call stack to find a data-scope mark.

        A frame matches when the module-level function of the same name, or the
        method of the same name on the frame's ``self`` class, has the attribute
        ``CommonConstant.DataScopeConfigFieldName``. The nearest match wins.

        Returns:
            DataScopeConfig | None: the config found, or ``None``.
        """
        data_scope_config: DataScopeConfig | None = None
        frame = inspect.currentframe()
        try:
            # Skip this function's frame and the before_query frame.
            frame = frame.f_back.f_back if frame and frame.f_back else None
            while frame:
                if frame.f_code:
                    func_globals = frame.f_globals
                    func_name = frame.f_code.co_name
                    # Module-level function with this name?
                    if func_name in func_globals:
                        func = func_globals[func_name]
                        if hasattr(func, CommonConstant.DataScopeConfigFieldName):
                            data_scope_config = getattr(
                                func, CommonConstant.DataScopeConfigFieldName
                            )
                            break
                    # Method with this name on the class of the frame's ``self``?
                    if "self" in frame.f_locals:
                        instance = frame.f_locals.get("self")
                        if instance:
                            method = getattr(instance.__class__, func_name, None)
                            if method and hasattr(
                                method, CommonConstant.DataScopeConfigFieldName
                            ):
                                data_scope_config = getattr(
                                    method, CommonConstant.DataScopeConfigFieldName
                                )
                                break
                frame = frame.f_back
        finally:
            # Drop the frame reference to avoid a reference cycle.
            del frame
        return data_scope_config

    @staticmethod
    def _run_coroutine_blocking(coro):
        """Run ``coro`` to completion from synchronous code and return its result.

        ``asyncio.run`` raises ``RuntimeError`` when the current thread already
        has a running event loop (e.g. a query compiled inside an async handler).
        In that case the coroutine runs on a fresh loop in a worker thread and
        this call blocks until it finishes.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)
        # NOTE(review): in this path the coroutine runs on another thread/loop;
        # context variables and loop-bound resources (e.g. an async DB connection
        # of the caller) are not shared - confirm CommonService opens its own.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(asyncio.run, coro).result()

    @staticmethod
    def _to_dept_ids(raw_ids):
        """Normalize department ids (ints or decimal strings) to a list of ints."""
        dept_ids = []
        for dept_id in raw_ids or ():
            if isinstance(dept_id, bool):
                continue
            if isinstance(dept_id, int):
                dept_ids.append(dept_id)
            elif isinstance(dept_id, str) and dept_id.isdecimal():
                dept_ids.append(int(dept_id))
        return dept_ids

    @classmethod
    def _dept_data_scope(
        cls,
        query: Query,
        entity_class,
        data_scope_config: DataScopeConfig,
        current_user: CurrentUser,
    ):
        """Department-based data scope, dispatched on ``current_user.data_scope``.

        Entities lacking either configured field are left unfiltered.
        """
        if not hasattr(entity_class, data_scope_config.dept_id_field) or not hasattr(
            entity_class, data_scope_config.user_name_field
        ):
            return query

        if current_user.data_scope == DataScopeEnum.ALL:
            return query

        scope_filters = {
            DataScopeEnum.CUSTOM: cls._apply_custom_dept_scope_filter,
            DataScopeEnum.DEPT: cls._apply_dept_scope_filter,
            DataScopeEnum.DEPT_AND_CHILD: cls._apply_dept_and_child_scope_filter,
            DataScopeEnum.SELF: cls._apply_self_scope_filter,
        }
        # Unknown scopes fall back to "own data only".
        filter_func = scope_filters.get(
            current_user.data_scope, cls._apply_self_scope_filter
        )
        return filter_func(query, entity_class, data_scope_config, current_user)

    @staticmethod
    def _apply_custom_dept_scope_filter(
        query,
        entity_class,
        data_scope_config: DataScopeConfig,
        current_user: CurrentUser,
    ):
        """Restrict to the departments bound to the user's roles."""
        from domain.services import CommonService

        role_dept_list = SqlalchemyEventRegister._run_coroutine_blocking(
            CommonService().get_user_role_dept_ids(current_user.user_id)
        )
        dept_ids = SqlalchemyEventRegister._to_dept_ids(role_dept_list)
        if dept_ids:
            return query.filter(
                getattr(entity_class, data_scope_config.dept_id_field).in_(dept_ids)
            )
        # No department configured for the roles: own data only.
        return query.filter(
            getattr(entity_class, data_scope_config.user_name_field)
            == current_user.username
        )

    @staticmethod
    def _apply_dept_scope_filter(
        query,
        entity_class,
        data_scope_config: DataScopeConfig,
        current_user: CurrentUser,
    ):
        """Restrict to the user's own department, or own data without one."""
        if current_user.dept_id:
            return query.filter(
                getattr(entity_class, data_scope_config.dept_id_field)
                == current_user.dept_id
            )
        return query.filter(
            getattr(entity_class, data_scope_config.user_name_field)
            == current_user.username
        )

    @staticmethod
    def _apply_dept_and_child_scope_filter(
        query,
        entity_class,
        data_scope_config: DataScopeConfig,
        current_user: CurrentUser,
    ):
        """Restrict to the ids from ``get_dept_all_children_ids(user dept)``.

        Falls back to own data when the user has no department or the lookup
        returns nothing.
        """
        own_data = (
            getattr(entity_class, data_scope_config.user_name_field)
            == current_user.username
        )
        if not current_user.dept_id:
            return query.filter(own_data)

        from domain.services import CommonService

        dept_ids = SqlalchemyEventRegister._run_coroutine_blocking(
            CommonService().get_dept_all_children_ids(current_user.dept_id)
        )
        if not dept_ids:
            return query.filter(own_data)
        return query.filter(
            getattr(entity_class, data_scope_config.dept_id_field).in_(dept_ids)
        )

    @staticmethod
    def _apply_self_scope_filter(
        query,
        entity_class,
        data_scope_config: DataScopeConfig,
        current_user: CurrentUser,
    ):
        """Restrict to rows owned by the current user (matched on user name)."""
        return query.filter(
            getattr(entity_class, data_scope_config.user_name_field)
            == current_user.username
        )

    @staticmethod
    def _custom_data_scope(
        query,
        entity_class,
        data_scope_config: DataScopeConfig,
        current_user: CurrentUser,
    ):
        """Restrict to the user's department, or own data without one.

        Entities lacking either configured field are left unfiltered, matching
        ``_dept_data_scope``.
        """
        if not hasattr(entity_class, data_scope_config.dept_id_field) or not hasattr(
            entity_class, data_scope_config.user_name_field
        ):
            return query
        if current_user and current_user.dept_id:
            # Compare the mapped column, not the field-name string.
            return query.filter(
                getattr(entity_class, data_scope_config.dept_id_field)
                == current_user.dept_id
            )
        return query.filter(
            getattr(entity_class, data_scope_config.user_name_field)
            == current_user.username
        )