ai_helper.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import json
  2. import re
  3. from openai import OpenAI
  4. import utils
  5. class AiHelper:
  6. _ai_api_key = None
  7. _ai_api_url = None
  8. _ai_max_tokens = 150
  9. def __init__(self):
  10. self._ai_api_key = utils.get_config_value("ai.key")
  11. self._ai_api_url = utils.get_config_value("ai.url")
  12. self._api_model = utils.get_config_value("ai.model")
  13. max_tokens = utils.get_config_value("ai.max_tokens")
  14. if max_tokens:
  15. self._ai_max_tokens = int(max_tokens)
  16. def call_openai(self, system_prompt: str, user_prompt: str) -> json:
  17. utils.get_logger().info("调用AI API")
  18. if self._ai_api_key is None:
  19. raise Exception("AI API key 没有配置")
  20. if self._ai_api_url is None:
  21. raise Exception("AI API url 没有配置")
  22. if self._api_model is None:
  23. raise Exception("AI API model 没有配置")
  24. client = OpenAI(api_key=self._ai_api_key, base_url=self._ai_api_url)
  25. completion = client.chat.completions.create(
  26. model=self._api_model,
  27. messages=[
  28. {
  29. "role": "system",
  30. "content": system_prompt,
  31. },
  32. {
  33. "role": "user",
  34. "content": user_prompt,
  35. },
  36. ],
  37. stream=False,
  38. temperature=0.7,
  39. )
  40. try:
  41. response = completion.model_dump_json()
  42. response_json = json.loads(response)
  43. res_str = self._extract_message_content(response_json)
  44. result = self._parse_response(res_str, True)
  45. if result:
  46. usage = response_json["usage"]
  47. result["completion_tokens"] = usage.get("completion_tokens", 0)
  48. result["prompt_tokens"] = usage.get("prompt_tokens", 0)
  49. result["total_tokens"] = usage.get("total_tokens", 0)
  50. # utils.get_logger().info(f"AI Process JSON: {result}")
  51. else:
  52. utils.get_logger().info(f"AI Response: {response}")
  53. return result
  54. except Exception as e:
  55. raise Exception(f"解析 AI 响应错误: {e}")
  56. @staticmethod
  57. def _extract_message_content(response_json: dict) -> str:
  58. if "choices" in response_json and len(response_json["choices"]) > 0:
  59. choice = response_json["choices"][0]
  60. message_content = choice.get("message", {}).get("content", "")
  61. elif "message" in response_json:
  62. message_content = response_json["message"].get("content", "")
  63. else:
  64. raise Exception("AI 响应中未找到有效的 choices 或 message 数据")
  65. # 移除多余的 ```json 和 ```
  66. if message_content.startswith("```json") and message_content.endswith("```"):
  67. message_content = message_content[6:-3]
  68. # 去除开头的 'n' 字符
  69. if message_content.startswith("n"):
  70. message_content = message_content[1:]
  71. # 移除无效的转义字符和时间戳前缀
  72. message_content = re.sub(
  73. r"\\[0-9]{2}", "", message_content
  74. ) # 移除 \32 等无效转义字符
  75. message_content = re.sub(
  76. r"\d{4}-\d{2}-\dT\d{2}:\d{2}:\d{2}\.\d+Z", "", message_content
  77. ) # 移除时间戳
  78. message_content = message_content.strip() # 去除首尾空白字符
  79. # 替换所有的反斜杠
  80. message_content = message_content.replace("\\", "")
  81. return message_content
  82. def _parse_response(self, response: str, first=True) -> json:
  83. # utils.get_logger().info(f"AI Response JSON STR: {response}")
  84. try:
  85. data = json.loads(response)
  86. return data
  87. except json.JSONDecodeError as e:
  88. if first:
  89. utils.get_logger().error(
  90. f"JSON 解析错误,去除部分特殊字符重新解析一次: {e}"
  91. )
  92. # 替换中文引号为空
  93. message_content = re.sub(r"[“”]", "", response) # 替换双引号
  94. message_content = re.sub(r"[‘’]", "", message_content) # 替换单引号
  95. return self._parse_response(message_content, False)
  96. else:
  97. raise Exception(f"解析 AI 响应错误: {response} {e}")