| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229 |
- import json
- from typing import Callable, Optional, Any
- from fastapi import Request, Response
- from fastapi.responses import JSONResponse, StreamingResponse
- from starlette.responses import (
- HTMLResponse,
- PlainTextResponse,
- RedirectResponse,
- FileResponse,
- )
- from core.constant import CommonConstant
- from utils import logger
- from utils.response_util import ResponseUtil
# Response classes that are always handed back to the client without wrapping.
# NOTE(review): Starlette's BaseHTTPMiddleware returns ``call_next`` results as
# its own internal streaming wrapper, so isinstance checks against these
# classes may rarely match inside an ``@app.middleware("http")`` — confirm.
SPECIAL_RESPONSE_TYPES = (
    StreamingResponse, HTMLResponse, PlainTextResponse,
    RedirectResponse, FileResponse,
)
# Keys whose joint presence marks a payload as already wrapped.
WRAPPED_RESPONSE_FIELDS = ("success", "code", "message")

# Content-Type prefixes that identify file / binary downloads
# (whole media families first, then individual document and archive types).
FILE_CONTENT_TYPES = ("image/", "audio/", "video/") + (
    "application/pdf",
    "application/octet-stream",
    "application/zip",
    "application/x-rar",
    "application/x-tar",
    "application/x-7z",
)
def add_response_middleware(app):
    """Register the response-wrapping HTTP middleware on ``app``.

    For every request the middleware lets the application build its response
    and then:

    * returns it unchanged when the matched endpoint has a truthy
      ``CommonConstant.NotWrapperFieldName`` attribute, when the status code is
      not 200, or when it is an HTML, server-sent-event, file or other special
      response;
    * otherwise extracts the payload and returns ``ResponseUtil.json(payload)``
      if the payload already contains every key in ``WRAPPED_RESPONSE_FIELDS``,
      or ``ResponseUtil.success(data=payload)`` if it does not.

    Bodies that are not valid UTF-8 (binary content) are passed through with
    their original bytes, status code and headers.

    Per the original author, register this after all other middleware so that
    it executes in the intended order.

    NOTE(review): the wrapped response is built by ``ResponseUtil`` from the
    payload only; headers set on the original response (e.g. ``Set-Cookie``)
    are not copied over — confirm this is acceptable.

    :param app: application exposing ``app.middleware("http")``
        (FastAPI / Starlette).
    """

    class _PassThrough:
        """Marker returned by extractors: send ``response`` to the client as-is."""

        __slots__ = ("response",)

        def __init__(self, response: Response) -> None:
            self.response = response

    def _get_content_type(response: Response) -> Optional[str]:
        """Return the lower-cased ``Content-Type`` header, or None if absent."""
        for header, value in getattr(response, "raw_headers", ()):
            if header.lower() == b"content-type":
                return value.decode("utf-8", errors="ignore").lower()
        return None

    def _content_type_startswith(response: Response, prefix: str) -> bool:
        """Return True if the response's Content-Type starts with ``prefix``."""
        content_type = _get_content_type(response)
        return content_type is not None and content_type.startswith(prefix)

    def _is_html_response(response: Response) -> bool:
        """Return True for ``text/html`` responses (always a real bool)."""
        return _content_type_startswith(response, "text/html")

    def _is_event_stream_response(response: Response) -> bool:
        """Return True for server-sent-event streams.

        These are open-ended: buffering one would block until the stream ends,
        so they must be passed through untouched.
        """
        return _content_type_startswith(response, "text/event-stream")

    def _is_file_response(response: Response) -> bool:
        """Return True if the response looks like a file download."""
        # 1. By class name, or by a ``path`` attribute (as FileResponse has).
        if "FileResponse" in type(response).__name__ or hasattr(response, "path"):
            return True
        # 2. By a media/binary Content-Type (str.startswith accepts a tuple).
        content_type = _get_content_type(response)
        if content_type and content_type.startswith(FILE_CONTENT_TYPES):
            return True
        # 3. By a Content-Disposition header announcing an attachment/filename.
        for header, value in getattr(response, "raw_headers", ()):
            if header.lower() == b"content-disposition":
                disposition = value.decode("utf-8", errors="ignore").lower()
                if "attachment" in disposition or "filename=" in disposition:
                    return True
        return False

    def _needs_wrapping(data: Any) -> bool:
        """Return True unless ``data`` is already in the wrapped shape.

        A payload counts as wrapped when it is a dict containing every key in
        ``WRAPPED_RESPONSE_FIELDS``; a ``JSONResponse`` is judged by its
        decoded body.
        """
        if isinstance(data, JSONResponse):
            data = json.loads(data.body.decode("utf-8"))
        return not (
            isinstance(data, dict)
            and all(field in data for field in WRAPPED_RESPONSE_FIELDS)
        )

    def _decode_body(body: bytes) -> Any:
        """Decode a UTF-8 body: parsed JSON if valid JSON, otherwise the text.

        :raises UnicodeDecodeError: if ``body`` is not valid UTF-8.
        """
        text = body.decode("utf-8")
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return text

    async def _read_body_iterator(body_iter) -> Optional[bytes]:
        """Drain an async body iterator and return all of its bytes.

        Returns None when ``body_iter`` is not async-iterable. ``str`` chunks
        are UTF-8 encoded.

        :raises TypeError: on a chunk that is neither bytes-like nor ``str``.
        """
        if not hasattr(body_iter, "__aiter__"):
            return None
        chunks = []
        async for chunk in body_iter:
            if isinstance(chunk, str):
                chunk = chunk.encode("utf-8")
            elif not isinstance(chunk, (bytes, bytearray, memoryview)):
                raise TypeError(
                    f"unsupported body chunk type: {type(chunk).__name__}"
                )
            chunks.append(bytes(chunk))
        return b"".join(chunks)

    async def _handle_streaming_response(response: Response) -> Any:
        """Buffer a streamed response and extract its payload.

        The whole body is read. Reading only the first chunk truncated
        multi-chunk bodies, and the old truthiness check dropped falsy payloads
        such as ``[]`` or ``0``.

        Returns the decoded payload; a ``_PassThrough`` holding a response
        rebuilt from the original bytes when the body is not UTF-8; or None
        when nothing could be extracted.
        """
        try:
            # 1. Drain the body iterator completely.
            body_iter = getattr(response, "body_iterator", None)
            if isinstance(body_iter, dict):
                return body_iter
            body = await _read_body_iterator(body_iter)
            if body:
                try:
                    return _decode_body(body)
                except UnicodeDecodeError:
                    # Binary payload. The iterator is already consumed, so
                    # rebuild an equivalent response from the buffered bytes,
                    # keeping the original status and (possibly repeated) headers.
                    passthrough = Response(
                        content=body, status_code=response.status_code
                    )
                    passthrough.raw_headers = list(response.raw_headers)
                    return _PassThrough(passthrough)
            # 2. Fallback for JSON responses: a dict-valued attribute.
            content_type = _get_content_type(response)
            if content_type == "application/json" and hasattr(response, "__dict__"):
                ignored = ("body_iterator", "raw_headers", "headers")
                for attr_name, attr_value in vars(response).items():
                    if attr_name not in ignored and isinstance(attr_value, dict):
                        return attr_value
            elif content_type:
                logger.debug(f"未处理的响应类型: {content_type}")
        except Exception as e:
            # Best effort: log and let the caller wrap ``None``.
            logger.warning(f"特殊响应类型处理失败: {e}")
        return None

    async def _extract_response_data(response: Response) -> Any:
        """Extract the payload of a non-streaming response.

        Returns parsed JSON or text decoded from ``response.body``. For a
        non-UTF-8 body it returns a ``_PassThrough``: the response is still
        intact, so it is sent unchanged. Otherwise it returns the ``content``
        attribute if set, else None.
        """
        body = getattr(response, "body", None)
        if body:
            try:
                return _decode_body(body)
            except UnicodeDecodeError:
                return _PassThrough(response)
        return getattr(response, "content", None)

    @app.middleware("http")
    async def middleware(request: Request, call_next: Callable) -> Response:
        """Wrap successful responses in the unified response envelope."""
        path = request.url.path
        response = await call_next(request)

        # The router records the matched endpoint in the (shared) request scope
        # while handling the request, so it is only available after call_next.
        # Reading it beforehand always gave None and the opt-out never applied.
        endpoint = request.scope.get("endpoint")
        if getattr(endpoint, CommonConstant.NotWrapperFieldName, False):
            logger.debug(f"跳过响应包装: {path}")
            return response

        # Only successful (HTTP 200) responses are wrapped.
        if response.status_code != 200:
            return response

        # 1. Special responses are passed through untouched.
        if (
            _is_html_response(response)
            or _is_event_stream_response(response)
            or isinstance(response, SPECIAL_RESPONSE_TYPES)
            or _is_file_response(response)
        ):
            return response

        # 2. Streamed responses (including the wrapper call_next returns),
        # 3. JSONResponse, 4. anything else.
        if type(response).__name__.endswith("StreamingResponse"):
            response_data = await _handle_streaming_response(response)
        elif isinstance(response, JSONResponse):
            response_data = json.loads(response.body.decode("utf-8"))
        else:
            response_data = await _extract_response_data(response)

        if isinstance(response_data, _PassThrough):
            return response_data.response
        if _needs_wrapping(response_data):
            return ResponseUtil.success(data=response_data)
        return ResponseUtil.json(response_data)
|