sqlalchemy_events.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. import asyncio
  2. import inspect
  3. from datetime import datetime
  4. from sqlalchemy import event
  5. from sqlalchemy.orm import Query
  6. from core.constant import CommonConstant
  7. from core.current_user import CurrentUser, CurrentUserProvider
  8. from core.decorators.data_scope_decorators import DataScopeConfig
  9. from core.enums import DataScopeTypeEnum, DataScopeEnum
  10. from domain.models.base_model import (
  11. BaseModel,
  12. CreateModelBase,
  13. UpdateModelBase,
  14. SoftDeleteModelBase,
  15. )
  16. from utils import logger
  17. class SqlalchemyEventRegister:
  18. def __init__(self):
  19. pass
  20. @classmethod
  21. def register(cls):
  22. event.listen(BaseModel, "before_insert", cls.before_insert)
  23. event.listen(BaseModel, "before_update", cls.before_update)
  24. event.listen(BaseModel, "before_delete", cls.before_delete)
  25. event.listen(Query, "before_compile", cls.before_query)
  26. logger.info("SqlalchemyEvent 注册完成")
  27. pass
  28. @staticmethod
  29. def before_insert(_mapper, _connection, instance):
  30. """在插入记录前自动填充创建人和创建时间
  31. Args:
  32. _mapper: 映射器
  33. _connection: 数据库连接
  34. instance: 模型实例
  35. """
  36. if isinstance(instance, CreateModelBase):
  37. # 设置创建人
  38. if not instance.created_by:
  39. # 不再直接使用UserContext获取当前用户
  40. # 此处应通过依赖注入方式获取当前用户
  41. # 由于事件处理可能在不同上下文中触发,这里需要通过其他方式获取用户信息
  42. # 在实际应用中,应该通过请求上下文或显式传递方式获取
  43. current_user = None
  44. if current_user:
  45. instance.created_by = current_user
  46. # 设置创建时间(如果没有自动设置)
  47. if not instance.created_time:
  48. instance.created_time = datetime.now()
  49. @staticmethod
  50. def before_update(_mapper, _connection, instance):
  51. """在更新记录前自动填充更新人和更新时间
  52. Args:
  53. _mapper: 映射器
  54. _connection: 数据库连接
  55. instance: 模型实例
  56. """
  57. if isinstance(instance, UpdateModelBase):
  58. # 设置更新人
  59. # 不再直接使用UserContext获取当前用户
  60. # 此处应通过依赖注入方式获取当前用户
  61. current_user = None
  62. if current_user:
  63. instance.updated_by = current_user
  64. # 设置更新时间(如果没有自动设置)
  65. instance.updated_time = datetime.now()
  66. @staticmethod
  67. def before_delete(_mapper, _connection, instance):
  68. """在删除记录前处理软删除信息
  69. Args:
  70. _mapper: 映射器
  71. _connection: 数据库连接
  72. instance: 模型实例
  73. """
  74. if (
  75. isinstance(instance, SoftDeleteModelBase)
  76. and hasattr(instance, "is_del")
  77. and instance.is_del == 1
  78. ):
  79. # 软删除情况下设置删除人和删除时间
  80. # 不再直接使用UserContext获取当前用户
  81. # 此处应通过依赖注入方式获取当前用户
  82. current_user = None
  83. if current_user:
  84. instance.deleted_by = current_user
  85. instance.deleted_time = datetime.now()
  86. def before_query(self, query: Query, *_args, **_kwargs):
  87. """
  88. 查询前事件处理器,用于处理数据权限过滤
  89. Args:
  90. query: SQLAlchemy查询对象
  91. Returns:
  92. Query: 过滤后的查询对象或None
  93. """
  94. # 获取查询的实体类
  95. if not query.column_descriptions:
  96. return query
  97. # 获取查询的主实体
  98. entity_info = query.column_descriptions[0]
  99. if "type" not in entity_info:
  100. return query
  101. entity_class = entity_info["type"]
  102. if not entity_class:
  103. return query
  104. data_scope_config = self._check_decorator_marks()
  105. current_user = CurrentUserProvider.get_current_user()
  106. if not data_scope_config or data_scope_config.disable_scope:
  107. return query
  108. if data_scope_config.scope_type == DataScopeTypeEnum.DEPT:
  109. return self._dept_data_scope(
  110. query, entity_class, data_scope_config, current_user
  111. )
  112. elif data_scope_config.scope_type == DataScopeTypeEnum.CUSTOM:
  113. return self._custom_data_scope(
  114. query, entity_class, data_scope_config, current_user
  115. )
  116. else:
  117. return query
  118. @staticmethod
  119. def _check_decorator_marks():
  120. data_scope_config: DataScopeConfig | None = None
  121. frame = inspect.currentframe()
  122. try:
  123. # 跳过当前函数和before_query函数的栈帧
  124. frame = frame.f_back.f_back if frame and frame.f_back else None
  125. while frame:
  126. if frame.f_code:
  127. func_globals = frame.f_globals
  128. func_name = frame.f_code.co_name
  129. # 检查全局函数
  130. if func_name in func_globals:
  131. func = func_globals[func_name]
  132. # 检查是否有数据权限字段标记
  133. if hasattr(func, CommonConstant.DataScopeConfigFieldName):
  134. data_scope_config = getattr(
  135. func, CommonConstant.DataScopeConfigFieldName
  136. )
  137. break
  138. # 检查类方法
  139. if "self" in frame.f_locals:
  140. instance = frame.f_locals.get("self")
  141. if instance:
  142. # 获取实例的类
  143. cls = instance.__class__
  144. # 获取当前方法
  145. method = getattr(cls, func_name, None)
  146. if method:
  147. # 检查方法是否有数据权限字段标记
  148. if hasattr(
  149. method, CommonConstant.DataScopeConfigFieldName
  150. ):
  151. data_scope_config = getattr(
  152. method, CommonConstant.DataScopeConfigFieldName
  153. )
  154. break
  155. frame = frame.f_back
  156. finally:
  157. del frame # 避免循环引用
  158. return data_scope_config
  159. def _dept_data_scope(
  160. self,
  161. query: Query,
  162. entity_class,
  163. data_scope_config: DataScopeConfig,
  164. current_user: CurrentUser,
  165. ):
  166. """
  167. 部门数据权限过滤
  168. """
  169. if not hasattr(entity_class, data_scope_config.dept_id_field) or not hasattr(
  170. entity_class, data_scope_config.user_name_field
  171. ):
  172. return query
  173. if current_user.data_scope == DataScopeEnum.ALL:
  174. return None
  175. # 使用字典映射数据权限范围到对应的过滤函数
  176. scope_filters = {
  177. DataScopeEnum.CUSTOM: lambda q, ec, dg, cu: self._apply_custom_dept_scope_filter(
  178. q, ec, dg, cu
  179. ),
  180. DataScopeEnum.DEPT: lambda q, ec, dg, cu: self._apply_dept_scope_filter(
  181. q, ec, dg, cu
  182. ),
  183. DataScopeEnum.DEPT_AND_CHILD: lambda q, ec, dg, cu: self._apply_dept_and_child_scope_filter(
  184. q, ec, dg, cu
  185. ),
  186. # 默认为仅本人数据权限
  187. DataScopeEnum.SELF: lambda q, ec, dg, cu: self._apply_self_scope_filter(
  188. q, ec, dg, cu
  189. ),
  190. }
  191. # 获取对应的过滤函数,如果没有找到则使用仅本人数据权限过滤
  192. filter_func = scope_filters.get(
  193. current_user.data_scope, scope_filters[DataScopeEnum.SELF]
  194. )
  195. # 调用对应的过滤函数
  196. return filter_func(query, entity_class, data_scope_config, current_user)
  197. @staticmethod
  198. def _apply_custom_dept_scope_filter(
  199. query,
  200. entity_class,
  201. data_scope_config: DataScopeConfig,
  202. current_user: CurrentUser,
  203. ):
  204. from domain.services import CommonService
  205. role_dept_list = asyncio.run(
  206. CommonService().get_user_role_dept_ids(current_user.user_id)
  207. )
  208. if role_dept_list:
  209. dept_ids = [
  210. int(dept_id)
  211. for dept_id in role_dept_list
  212. if isinstance(dept_id, str) and dept_id.isdigit()
  213. ]
  214. if dept_ids:
  215. return query.filter(
  216. getattr(entity_class, data_scope_config.dept_id_field).in_(dept_ids)
  217. )
  218. # 如果没有配置组织机构ID,则只能查看自己的数据
  219. return query.filter(
  220. getattr(entity_class, data_scope_config.user_name_field)
  221. == current_user.username
  222. )
  223. @staticmethod
  224. def _apply_dept_scope_filter(
  225. query,
  226. entity_class,
  227. data_scope_config: DataScopeConfig,
  228. current_user: CurrentUser,
  229. ):
  230. if current_user and current_user.dept_id:
  231. return query.filter(
  232. getattr(entity_class, data_scope_config.dept_id_field)
  233. == current_user.dept_id
  234. )
  235. # 如果用户没有组织机构,则只能查看自己的数据
  236. return query.filter(
  237. getattr(entity_class, data_scope_config.user_name_field)
  238. == current_user.username
  239. )
  240. @staticmethod
  241. def _apply_dept_and_child_scope_filter(
  242. query,
  243. entity_class,
  244. data_scope_config: DataScopeConfig,
  245. current_user: CurrentUser,
  246. ):
  247. if current_user and not current_user.dept_id:
  248. return query.filter(
  249. getattr(entity_class, data_scope_config.user_name_field)
  250. == current_user.username
  251. )
  252. if current_user.dept_id:
  253. from domain.services import CommonService
  254. dept_ids = asyncio.run(
  255. CommonService().get_dept_all_children_ids(current_user.dept_id)
  256. )
  257. if dept_ids:
  258. return query.filter(
  259. getattr(entity_class, data_scope_config.dept_id_field).in_(dept_ids)
  260. )
  261. else:
  262. return query.filter(
  263. getattr(entity_class, data_scope_config.user_name_field)
  264. == current_user.username
  265. )
  266. return query
  267. @staticmethod
  268. def _apply_self_scope_filter(
  269. query,
  270. entity_class,
  271. data_scope_config: DataScopeConfig,
  272. current_user: CurrentUser,
  273. ):
  274. return query.filter(
  275. getattr(entity_class, data_scope_config.user_name_field)
  276. == current_user.username
  277. )
  278. @staticmethod
  279. def _custom_data_scope(
  280. query,
  281. entity_class,
  282. data_scope_config: DataScopeConfig,
  283. current_user: CurrentUser,
  284. ):
  285. if current_user and current_user.dept_id:
  286. return query.filter(data_scope_config.dept_id_field == current_user.dept_id)
  287. else:
  288. return query.filter(
  289. getattr(entity_class, data_scope_config.user_name_field)
  290. == current_user.username
  291. )