| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325 |
- import asyncio
- import inspect
- from datetime import datetime
- from sqlalchemy import event
- from sqlalchemy.orm import Query
- from core.constant import CommonConstant
- from core.current_user import CurrentUser, CurrentUserProvider
- from core.decorators.data_scope_decorators import DataScopeConfig
- from core.enums import DataScopeTypeEnum, DataScopeEnum
- from domain.models.base_model import (
- BaseModel,
- CreateModelBase,
- UpdateModelBase,
- SoftDeleteModelBase,
- )
- from utils import logger
class SqlalchemyEventRegister:
    """Registers SQLAlchemy listeners for audit fields and data-scope filtering.

    The mapper listeners (``before_insert`` / ``before_update`` /
    ``before_delete``) stamp audit columns on model instances. The
    ``before_compile`` listener on ``Query`` narrows a query according to the
    ``DataScopeConfig`` found on a decorated caller and the current user's
    data scope.
    """

    def __init__(self):
        pass

    @classmethod
    def register(cls):
        """Attach every listener defined on this class to SQLAlchemy."""
        # NOTE(review): if BaseModel is an unmapped declarative base, mapper
        # events only reach subclasses when listen() is given propagate=True —
        # confirm against the BaseModel definition.
        event.listen(BaseModel, "before_insert", cls.before_insert)
        event.listen(BaseModel, "before_update", cls.before_update)
        event.listen(BaseModel, "before_delete", cls.before_delete)
        # before_query is an instance method: it must be registered bound to an
        # instance, otherwise SQLAlchemy's single ``query`` argument lands in
        # ``self`` and the call fails. retval=True is required for the query
        # returned by the listener to replace the one being compiled.
        event.listen(Query, "before_compile", cls().before_query, retval=True)
        logger.info("SqlalchemyEvent 注册完成")

    @staticmethod
    def before_insert(_mapper, _connection, instance):
        """Fill in the creator and creation time before a row is inserted.

        Args:
            _mapper: The mapper (unused).
            _connection: The database connection (unused).
            instance: The model instance about to be inserted.
        """
        if isinstance(instance, CreateModelBase):
            if not instance.created_by:
                # TODO: the current user is intentionally not resolved here.
                # Event handlers may fire outside a request context, so the
                # user has to be supplied through the request context or
                # passed explicitly; until then created_by is left untouched.
                current_user = None
                if current_user:
                    instance.created_by = current_user
            # Only set the creation time when nothing else has set it.
            if not instance.created_time:
                instance.created_time = datetime.now()

    @staticmethod
    def before_update(_mapper, _connection, instance):
        """Fill in the updater and refresh the update time before an update.

        Args:
            _mapper: The mapper (unused).
            _connection: The database connection (unused).
            instance: The model instance about to be updated.
        """
        if isinstance(instance, UpdateModelBase):
            # TODO: the current user is not resolved here yet (see
            # before_insert); updated_by stays untouched until it is.
            current_user = None
            if current_user:
                instance.updated_by = current_user
            # The update time is always overwritten.
            instance.updated_time = datetime.now()

    @staticmethod
    def before_delete(_mapper, _connection, instance):
        """Stamp soft-delete information before a delete is flushed.

        Only acts on ``SoftDeleteModelBase`` instances whose ``is_del``
        attribute equals 1.

        Args:
            _mapper: The mapper (unused).
            _connection: The database connection (unused).
            instance: The model instance about to be deleted.
        """
        if (
            isinstance(instance, SoftDeleteModelBase)
            and hasattr(instance, "is_del")
            and instance.is_del == 1
        ):
            # TODO: the current user is not resolved here yet (see
            # before_insert); deleted_by stays untouched until it is.
            current_user = None
            if current_user:
                instance.deleted_by = current_user
            instance.deleted_time = datetime.now()

    def before_query(self, query: Query, *_args, **_kwargs):
        """Apply data-scope filtering to a query before it is compiled.

        Args:
            query: The SQLAlchemy query being compiled.

        Returns:
            The query itself when no scoping applies, a filtered query when a
            scope rule matched, or None when the department handler decides
            that no filtering is needed (SQLAlchemy keeps the original query
            in that case).
        """
        if not query.column_descriptions:
            return query

        # The first column description is treated as the primary entity.
        entity_info = query.column_descriptions[0]
        if "type" not in entity_info:
            return query
        entity_class = entity_info["type"]
        if not entity_class:
            return query

        data_scope_config = self._check_decorator_marks()
        current_user = CurrentUserProvider.get_current_user()
        if not data_scope_config or data_scope_config.disable_scope:
            return query

        # NOTE(review): the handlers below dereference current_user; if
        # get_current_user() can return None here they raise AttributeError —
        # confirm the provider's contract.
        if data_scope_config.scope_type == DataScopeTypeEnum.DEPT:
            return self._dept_data_scope(
                query, entity_class, data_scope_config, current_user
            )
        if data_scope_config.scope_type == DataScopeTypeEnum.CUSTOM:
            return self._custom_data_scope(
                query, entity_class, data_scope_config, current_user
            )
        return query

    @staticmethod
    def _check_decorator_marks():
        """Walk up the call stack looking for a data-scope configuration.

        For every frame above ``before_query`` the function named by the frame
        is looked up first in the frame's globals and then on the class of the
        frame's ``self`` local; the first one carrying the attribute named by
        ``CommonConstant.DataScopeConfigFieldName`` wins.

        Returns:
            The configuration found on the nearest marked caller, or None.
        """
        marker = CommonConstant.DataScopeConfigFieldName
        data_scope_config: DataScopeConfig | None = None
        frame = inspect.currentframe()
        try:
            # Skip this function's frame and the before_query frame.
            frame = frame.f_back.f_back if frame and frame.f_back else None
            while frame:
                if frame.f_code:
                    func_name = frame.f_code.co_name

                    # Module-level function with the same name as the frame.
                    func = frame.f_globals.get(func_name)
                    if func is not None and hasattr(func, marker):
                        data_scope_config = getattr(func, marker)
                        break

                    # Method looked up on the class of the frame's ``self``.
                    # Compare against None instead of truth-testing: foreign
                    # objects on the stack may define __bool__/__len__.
                    instance = frame.f_locals.get("self")
                    if instance is not None:
                        method = getattr(instance.__class__, func_name, None)
                        if method is not None and hasattr(method, marker):
                            data_scope_config = getattr(method, marker)
                            break
                frame = frame.f_back
        finally:
            del frame  # avoid a reference cycle through the frame object
        return data_scope_config

    def _dept_data_scope(
        self,
        query: Query,
        entity_class,
        data_scope_config: DataScopeConfig,
        current_user: CurrentUser,
    ):
        """Filter a query according to the user's department data scope.

        Returns the query unchanged when the entity lacks either configured
        field, None when the user's scope is ``ALL``, and otherwise the result
        of the handler matching ``current_user.data_scope`` (self-only when
        the scope is not recognised).
        """
        if not hasattr(entity_class, data_scope_config.dept_id_field) or not hasattr(
            entity_class, data_scope_config.user_name_field
        ):
            return query
        if current_user.data_scope == DataScopeEnum.ALL:
            return None

        # Dispatch table from data scope to the handler that applies it.
        scope_filters = {
            DataScopeEnum.CUSTOM: self._apply_custom_dept_scope_filter,
            DataScopeEnum.DEPT: self._apply_dept_scope_filter,
            DataScopeEnum.DEPT_AND_CHILD: self._apply_dept_and_child_scope_filter,
            DataScopeEnum.SELF: self._apply_self_scope_filter,
        }
        # Unknown scopes fall back to the most restrictive rule: own data only.
        filter_func = scope_filters.get(
            current_user.data_scope, self._apply_self_scope_filter
        )
        return filter_func(query, entity_class, data_scope_config, current_user)

    @staticmethod
    def _apply_custom_dept_scope_filter(
        query,
        entity_class,
        data_scope_config: DataScopeConfig,
        current_user: CurrentUser,
    ):
        """Restrict the query to the departments granted through the user's roles.

        Falls back to the user's own rows when no usable department id is
        returned.
        """
        from domain.services import CommonService

        # NOTE(review): asyncio.run() raises RuntimeError when an event loop is
        # already running in this thread — confirm this listener only fires
        # from synchronous code.
        role_dept_list = asyncio.run(
            CommonService().get_user_role_dept_ids(current_user.user_id)
        )
        if role_dept_list:
            # Only purely numeric string ids are kept; everything else is dropped.
            dept_ids = [
                int(dept_id)
                for dept_id in role_dept_list
                if isinstance(dept_id, str) and dept_id.isdigit()
            ]
            if dept_ids:
                return query.filter(
                    getattr(entity_class, data_scope_config.dept_id_field).in_(dept_ids)
                )
        # No departments configured: the user may only see their own data.
        return query.filter(
            getattr(entity_class, data_scope_config.user_name_field)
            == current_user.username
        )

    @staticmethod
    def _apply_dept_scope_filter(
        query,
        entity_class,
        data_scope_config: DataScopeConfig,
        current_user: CurrentUser,
    ):
        """Restrict the query to the user's own department.

        Falls back to the user's own rows when the user has no department.
        """
        if current_user and current_user.dept_id:
            return query.filter(
                getattr(entity_class, data_scope_config.dept_id_field)
                == current_user.dept_id
            )
        return query.filter(
            getattr(entity_class, data_scope_config.user_name_field)
            == current_user.username
        )

    @staticmethod
    def _apply_dept_and_child_scope_filter(
        query,
        entity_class,
        data_scope_config: DataScopeConfig,
        current_user: CurrentUser,
    ):
        """Restrict the query to the user's department and its descendants.

        Falls back to the user's own rows when the user has no department or
        no department ids are returned.
        """
        if current_user and not current_user.dept_id:
            return query.filter(
                getattr(entity_class, data_scope_config.user_name_field)
                == current_user.username
            )
        if current_user.dept_id:
            from domain.services import CommonService

            # NOTE(review): asyncio.run() raises RuntimeError when an event
            # loop is already running in this thread — confirm this listener
            # only fires from synchronous code.
            dept_ids = asyncio.run(
                CommonService().get_dept_all_children_ids(current_user.dept_id)
            )
            if dept_ids:
                return query.filter(
                    getattr(entity_class, data_scope_config.dept_id_field).in_(dept_ids)
                )
            return query.filter(
                getattr(entity_class, data_scope_config.user_name_field)
                == current_user.username
            )
        return query

    @staticmethod
    def _apply_self_scope_filter(
        query,
        entity_class,
        data_scope_config: DataScopeConfig,
        current_user: CurrentUser,
    ):
        """Restrict the query to rows whose user-name field matches the user."""
        return query.filter(
            getattr(entity_class, data_scope_config.user_name_field)
            == current_user.username
        )

    @staticmethod
    def _custom_data_scope(
        query,
        entity_class,
        data_scope_config: DataScopeConfig,
        current_user: CurrentUser,
    ):
        """Apply the custom scope type.

        Filters by the user's department when they have one, otherwise by
        their own user name.
        """
        if current_user and current_user.dept_id:
            # Compare the mapped column, not the field-name string: comparing
            # the string itself yields a plain Python bool instead of a SQL
            # expression.
            return query.filter(
                getattr(entity_class, data_scope_config.dept_id_field)
                == current_user.dept_id
            )
        return query.filter(
            getattr(entity_class, data_scope_config.user_name_field)
            == current_user.username
        )
|