import json
from typing import Callable, Optional, Any

from fastapi import Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.responses import (
    HTMLResponse,
    PlainTextResponse,
    RedirectResponse,
    FileResponse,
)

from core.constant import CommonConstant
from utils import logger
from utils.response_util import ResponseUtil

# Response classes that are passed through without wrapping
SPECIAL_RESPONSE_TYPES = (
    StreamingResponse,
    HTMLResponse,
    PlainTextResponse,
    RedirectResponse,
    FileResponse,
)

# Keys whose joint presence marks a payload as already wrapped
WRAPPED_RESPONSE_FIELDS = ("success", "code", "message")

# Content-Type prefixes that are treated as file payloads
FILE_CONTENT_TYPES = (
    "image/",
    "audio/",
    "video/",
    "application/pdf",
    "application/octet-stream",
    "application/zip",
    "application/x-rar",
    "application/x-tar",
    "application/x-7z",
)


def add_response_middleware(app):
    """
    Register the response-wrapping HTTP middleware on ``app``.

    Per the original note, this must be registered after all other
    middleware so that the execution order is as intended.

    Args:
        app: Application object exposing the ``middleware("http")`` decorator.
    """

    def _is_file_response(response: Response) -> bool:
        """Return True when the response looks like a file download."""
        # 1. Decide by class name or by the presence of a ``path`` attribute
        response_type = type(response).__name__
        if "FileResponse" in response_type or hasattr(response, "path"):
            return True

        # 2. Decide by Content-Type (str.startswith accepts a tuple of prefixes)
        content_type = _get_content_type(response)
        if content_type and content_type.startswith(FILE_CONTENT_TYPES):
            return True

        # 3. Decide by Content-Disposition
        if hasattr(response, "raw_headers"):
            for header, value in response.raw_headers:
                if header.lower() == b"content-disposition":
                    header_value = value.decode("utf-8", errors="ignore").lower()
                    if "attachment" in header_value or "filename=" in header_value:
                        return True

        return False

    def _is_html_response(response: Response) -> bool:
        """Return True when the response Content-Type starts with text/html."""
        content_type = _get_content_type(response)
        # bool() so a missing/empty Content-Type yields False, not None or ""
        return bool(content_type and content_type.startswith("text/html"))

    def _get_content_type(response: Response) -> Optional[str]:
        """Return the lower-cased Content-Type header value, or None if absent."""
        if not hasattr(response, "raw_headers"):
            return None

        for header, value in response.raw_headers:
            if header.lower() == b"content-type":
                return value.decode("utf-8", errors="ignore").lower()
        return None

    def _is_not_warp_response(data) -> bool:
        """
        Return True when ``data`` still needs to be wrapped.

        ``None`` and any non-dict value need wrapping. A ``JSONResponse`` is
        judged by its decoded JSON body. A dict counts as already wrapped only
        when it contains every key in ``WRAPPED_RESPONSE_FIELDS``.
        """
        if data is None:
            return True
        if isinstance(data, JSONResponse):
            body = data.body.decode("utf-8")
            return _is_not_warp_response(json.loads(body))
        if not isinstance(data, dict):
            return True
        return not all(field in data for field in WRAPPED_RESPONSE_FIELDS)

    async def _handle_streaming_response(response: Response) -> Optional[Any]:
        """
        Try to extract a payload from a streaming-style response.

        Returns the extracted payload, or None when nothing could be
        extracted. Any exception is logged and results in None.
        """
        try:
            # 1. Read from body_iterator
            if hasattr(response, "body_iterator"):
                data = await _extract_from_iterator(response.body_iterator)
                # Compare against None so that falsy payloads such as
                # [], {}, 0, False or "" are preserved instead of dropped.
                if data is not None:
                    return data

            # 2. For JSON responses, fall back to a dict-valued attribute
            content_type = _get_content_type(response)
            if content_type == "application/json" and hasattr(response, "__dict__"):
                for attr_name, attr_value in response.__dict__.items():
                    if attr_name not in [
                        "body_iterator",
                        "raw_headers",
                        "headers",
                    ] and isinstance(attr_value, dict):
                        return attr_value
            elif content_type:
                logger.debug(f"未处理的响应类型: {content_type}")

        except Exception as e:
            logger.warn(f"特殊响应类型处理失败: {e}")

        return None

    async def _extract_from_iterator(body_iter) -> Optional[Any]:
        """
        Extract a payload from the first item of an async iterator.

        A dict is returned as-is; an object without ``__aiter__`` yields None.
        Bytes and str items are parsed as JSON, falling back to the decoded
        text when parsing fails. Bytes that cannot be decoded yield None.

        NOTE(review): only the first item is read, so a body delivered in
        several chunks is seen only partially - confirm this is intended.
        """
        if isinstance(body_iter, dict):
            return body_iter

        if not hasattr(body_iter, "__aiter__"):
            return None

        try:
            _aiter = body_iter.__aiter__()
            first_item = await _aiter.__anext__()

            if not first_item:
                return None

            # Bytes chunk
            if isinstance(first_item, bytes):
                try:
                    decoded_data = first_item.decode()
                    try:
                        return json.loads(decoded_data)
                    except json.JSONDecodeError:
                        return decoded_data
                except UnicodeDecodeError:
                    # Not decodable text: fall through and return None
                    pass

            # String chunk
            elif isinstance(first_item, str):
                try:
                    return json.loads(first_item)
                except json.JSONDecodeError:
                    return first_item

            # Dict chunk
            elif isinstance(first_item, dict):
                return first_item

        except StopAsyncIteration:
            # An exhausted (empty) iterator is not an error; there is simply
            # no data, so do not log a failure for it.
            return None
        except Exception as e:
            logger.warn(f"迭代器数据提取失败: {e}")

        return None

    async def _extract_response_data(response: Response) -> Optional[Any]:
        """
        Try to extract a payload from a non-streaming response.

        A non-empty ``body`` is decoded as UTF-8 and parsed as JSON, falling
        back to the decoded text. An undecodable body yields None. Without a
        usable body, a non-None ``content`` attribute is returned.
        """
        # 1. Try the response body
        if hasattr(response, "body") and response.body:
            try:
                body_content = response.body.decode("utf-8")
                try:
                    return json.loads(body_content)
                except json.JSONDecodeError:
                    return body_content
            except UnicodeDecodeError:
                # Binary data: nothing to extract
                return None

        # 2. Try the content attribute
        elif hasattr(response, "content") and response.content is not None:
            return response.content

        return None

    @app.middleware("http")
    async def middleware(request: Request, call_next: Callable) -> Response:
        """
        Wrap successful responses in the unified response format.

        Endpoints flagged with ``CommonConstant.NotWrapperFieldName``,
        non-200 responses, and HTML / special-type / file responses are
        returned unchanged. Otherwise the payload is extracted and passed to
        ``ResponseUtil.success`` if it still needs wrapping, or to
        ``ResponseUtil.json`` if it is already wrapped.
        """
        # Resolve the route handler
        endpoint = request.scope.get("endpoint")
        path = request.url.path

        # Honour the per-endpoint opt-out flag
        skip_wrapper = getattr(endpoint, CommonConstant.NotWrapperFieldName, False)
        if skip_wrapper:
            logger.debug(f"跳过响应包装: {path}")
            return await call_next(request)

        # Run the downstream handlers
        response = await call_next(request)

        # Only successful responses are wrapped
        if response.status_code != 200:
            return response

        response_type = type(response).__name__

        # 1. Special response types are passed through untouched
        if (
            _is_html_response(response)
            or isinstance(response, SPECIAL_RESPONSE_TYPES)
            or _is_file_response(response)
        ):
            return response

        # 2. Classes whose name ends with "StreamingResponse"
        if response_type.endswith("StreamingResponse"):
            response_data = await _handle_streaming_response(response)
        # 3. JSONResponse: decode the body
        elif isinstance(response, JSONResponse):
            content = response.body.decode("utf-8")
            response_data = json.loads(content)
        # 4. Any other response type
        else:
            response_data = await _extract_response_data(response)

        if _is_not_warp_response(response_data):
            return ResponseUtil.success(data=response_data)
        else:
            return ResponseUtil.json(response_data)