response_middleware.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. import json
  2. from typing import Callable, Optional, Any
  3. from fastapi import Request, Response
  4. from fastapi.responses import JSONResponse, StreamingResponse
  5. from starlette.responses import (
  6. HTMLResponse,
  7. PlainTextResponse,
  8. RedirectResponse,
  9. FileResponse,
  10. )
  11. from core.constant import CommonConstant
  12. from utils import logger
  13. from utils.response_util import ResponseUtil
  # ---- special response types ----
# Response classes that the wrapping middleware returns unchanged
# (checked with isinstance before any body inspection).
SPECIAL_RESPONSE_TYPES = (StreamingResponse, HTMLResponse, PlainTextResponse,
                          RedirectResponse, FileResponse)
  # ---- envelope marker fields ----
# A dict payload that contains ALL of these keys is treated as already
# wrapped in the unified response envelope and is not wrapped again.
WRAPPED_RESPONSE_FIELDS = (
    "success",
    "code",
    "message",
)
  # ---- file-like Content-Type prefixes ----
# Content-Type prefixes that mark a response as a file/binary payload.
# They are matched as prefixes, so "image/" covers every image subtype and
# "application/x-rar" also covers e.g. "application/x-rar-compressed".
FILE_CONTENT_TYPES = (
    # media
    "image/", "audio/", "video/",
    # documents and opaque binaries
    "application/pdf", "application/octet-stream",
    # archives
    "application/zip", "application/x-rar", "application/x-tar", "application/x-7z",
)
  # ---- response-wrapping middleware ----
def add_response_middleware(app):
    """
    Register the response-wrapping middleware on ``app``.

    HTTP 200 responses whose payload is not already in the unified envelope
    (a dict containing every key of WRAPPED_RESPONSE_FIELDS) are wrapped with
    ``ResponseUtil.success``; already-wrapped dict payloads are re-emitted via
    ``ResponseUtil.json``. Non-200 responses, HTML, server-sent events, file or
    binary bodies, SPECIAL_RESPONSE_TYPES instances, and endpoints flagged with
    ``CommonConstant.NotWrapperFieldName`` are returned untouched.

    Must be registered after all other middlewares to guarantee execution order.

    :param app: the FastAPI application to register the middleware on.
    """
    # Sentinel meaning "leave the original response untouched" (distinct from a
    # legitimate None/falsy payload, which is still wrapped).
    _PASS_THROUGH = object()
    # Content types that are genuinely streamed and must never be buffered.
    _STREAM_CONTENT_TYPES = ("text/event-stream",)
    # Headers that describe the original body and are invalid for the new one.
    _NON_TRANSFERABLE_HEADERS = {
        b"content-length",
        b"content-type",
        b"content-encoding",
        b"transfer-encoding",
    }

    def _get_content_type(response: Response) -> Optional[str]:
        """Return the lower-cased Content-Type header value, or None if absent."""
        for header, value in getattr(response, "raw_headers", None) or ():
            if header.lower() == b"content-type":
                return value.decode("utf-8", errors="ignore").lower()
        return None

    def _is_file_response(response: Response) -> bool:
        """Return True if the response looks like a file download / binary payload."""
        # 1. By class name, or a ``path`` attribute as FileResponse-like classes have.
        if "FileResponse" in type(response).__name__ or hasattr(response, "path"):
            return True
        # 2. By Content-Type prefix.
        content_type = _get_content_type(response)
        if content_type and content_type.startswith(FILE_CONTENT_TYPES):
            return True
        # 3. By a Content-Disposition header announcing an attachment or filename.
        for header, value in getattr(response, "raw_headers", None) or ():
            if header.lower() == b"content-disposition":
                disposition = value.decode("utf-8", errors="ignore").lower()
                if "attachment" in disposition or "filename=" in disposition:
                    return True
        return False

    def _is_html_response(response: Response) -> bool:
        """Return True if the Content-Type is text/html."""
        content_type = _get_content_type(response)
        return bool(content_type) and content_type.startswith("text/html")

    def _is_event_stream(response: Response) -> bool:
        """Return True for streamed media types (e.g. SSE) that must not be buffered."""
        content_type = _get_content_type(response)
        return bool(content_type) and content_type.startswith(_STREAM_CONTENT_TYPES)

    def _needs_wrapping(data: Any) -> bool:
        """Return True unless ``data`` is a dict already carrying every envelope field."""
        if isinstance(data, JSONResponse):
            data = json.loads(data.body.decode("utf-8"))
        if not isinstance(data, dict):
            return True
        return not all(field in data for field in WRAPPED_RESPONSE_FIELDS)

    async def _replay(chunks: list):
        """Yield previously consumed body chunks again, unchanged."""
        for chunk in chunks:
            yield chunk

    def _decode_chunks(chunks: list) -> Any:
        """
        Interpret a fully buffered body.

        :return: the parsed JSON value; the decoded text if it is not JSON; None for
            an empty body; ``_PASS_THROUGH`` if the body is not UTF-8 text or holds
            chunk types that cannot be joined.
        """
        # Compatibility: an iterator that yields a ready-made dict.
        if chunks and isinstance(chunks[0], dict):
            return chunks[0]
        parts = []
        for chunk in chunks:
            if isinstance(chunk, str):
                parts.append(chunk.encode("utf-8"))
            elif isinstance(chunk, (bytes, bytearray, memoryview)):
                parts.append(chunk)
            else:
                return _PASS_THROUGH
        raw = b"".join(parts)
        if not raw:
            return None
        try:
            text = raw.decode("utf-8")
        except UnicodeDecodeError:
            # Binary payload: keep it instead of replacing it with null.
            return _PASS_THROUGH
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return text

    async def _handle_streaming_response(response: Response) -> Any:
        """
        Buffer the whole streamed body and interpret it.

        The original only read the first chunk, which truncated multi-chunk
        bodies. The iterator is re-armed with the buffered chunks so that the
        original response can still be sent unchanged on pass-through.
        """
        body_iterator = getattr(response, "body_iterator", None)
        if isinstance(body_iterator, dict):
            return body_iterator
        if not hasattr(body_iterator, "__aiter__"):
            return _PASS_THROUGH
        chunks = []
        try:
            async for chunk in body_iterator:
                chunks.append(chunk)
        except Exception as e:
            # Best effort, as before. NOTE(review): a broken stream is still
            # reported as success(data=None); consider re-raising instead.
            logger.warn(f"迭代器数据提取失败: {e}")
            return None
        response.body_iterator = _replay(chunks)
        return _decode_chunks(chunks)

    def _extract_response_data(response: Response) -> Any:
        """Interpret a non-streaming response's ``body`` (or ``content`` attribute)."""
        body = getattr(response, "body", None)
        if body:
            try:
                text = body.decode("utf-8")
            except UnicodeDecodeError:
                # Binary payload: leave the response untouched.
                return _PASS_THROUGH
            try:
                return json.loads(text)
            except json.JSONDecodeError:
                return text
        return getattr(response, "content", None)

    def _carry_over_headers(source: Response, target: Response) -> Response:
        """Copy headers (e.g. Set-Cookie) from ``source`` onto the newly built ``target``."""
        # NOTE(review): assumes ResponseUtil returns a Starlette-style Response
        # exposing a list ``raw_headers``; otherwise the target is returned as-is.
        target_headers = getattr(target, "raw_headers", None)
        if not isinstance(target_headers, list):
            return target
        target_headers.extend(
            (name, value)
            for name, value in getattr(source, "raw_headers", None) or ()
            if name.lower() not in _NON_TRANSFERABLE_HEADERS
        )
        return target

    @app.middleware("http")
    async def middleware(request: Request, call_next: Callable) -> Response:
        """Wrap successful responses in the unified response envelope."""
        path = request.url.path
        response = await call_next(request)
        # Routing happens inside call_next: the router stores the matched
        # endpoint into scope["endpoint"] in place, so the opt-out flag can only
        # be read reliably after the downstream app has run.
        endpoint = request.scope.get("endpoint")
        if getattr(endpoint, CommonConstant.NotWrapperFieldName, False):
            logger.debug(f"跳过响应包装: {path}")
            return response
        # Only successful responses are wrapped.
        if response.status_code != 200:
            return response
        # 1. Special responses pass through untouched.
        if (
            _is_html_response(response)
            or _is_event_stream(response)
            or isinstance(response, SPECIAL_RESPONSE_TYPES)
            or _is_file_response(response)
        ):
            return response
        # 2. Streamed bodies. Presumably this includes the wrapper type that
        #    call_next returns; verify against the installed Starlette version.
        if type(response).__name__.endswith("StreamingResponse"):
            response_data = await _handle_streaming_response(response)
        # 3. JSONResponse: the body is already serialized JSON.
        elif isinstance(response, JSONResponse):
            response_data = json.loads(response.body.decode("utf-8"))
        # 4. Any other response type.
        else:
            response_data = _extract_response_data(response)
        if response_data is _PASS_THROUGH:
            content_type = _get_content_type(response)
            logger.debug(f"未处理的响应类型: {content_type}")
            return response
        if _needs_wrapping(response_data):
            wrapped = ResponseUtil.success(data=response_data)
        else:
            wrapped = ResponseUtil.json(response_data)
        return _carry_over_headers(response, wrapped)