ai_helper.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. import json, re
  2. import tools.utils as utils
  3. from openai import OpenAI
  4. from pathlib import Path
  5. class AiHelper:
  6. _ai_api_key = None
  7. _ai_api_url = None
  8. _ai_max_tokens = 150
  9. def __init__(self, api_url: str = None, api_key: str = None, api_model: str = None):
  10. self._ai_api_url = api_url if api_url else utils.get_config_value("ai.url")
  11. self._ai_api_key = api_key if api_key else utils.get_config_value("ai.key")
  12. self._api_model = api_model if api_model else utils.get_config_value("ai.model")
  13. max_tokens = utils.get_config_value("ai.max_tokens")
  14. if max_tokens:
  15. self._ai_max_tokens = int(max_tokens)
  16. def call_openai(
  17. self,
  18. system_prompt: str,
  19. user_prompt: str,
  20. api_url: str = None,
  21. api_key: str = None,
  22. api_model: str = None,
  23. ) -> json:
  24. self.check_api(api_key, api_model, api_url)
  25. utils.get_logger().info(
  26. f"调用AI API ==> Url:{self._ai_api_url},Model:{self._api_model}"
  27. )
  28. client = OpenAI(api_key=self._ai_api_key, base_url=self._ai_api_url)
  29. completion = client.chat.completions.create(
  30. model=self._api_model,
  31. messages=[
  32. {
  33. "role": "system",
  34. "content": system_prompt,
  35. },
  36. {
  37. "role": "user",
  38. "content": user_prompt,
  39. },
  40. ],
  41. stream=False,
  42. temperature=0.7,
  43. response_format={"type": "json_object"},
  44. # max_tokens=self._ai_max_tokens,
  45. )
  46. try:
  47. response = completion.model_dump_json()
  48. result = {}
  49. response_json = json.loads(response)
  50. res_str = self._extract_message_content(response_json)
  51. result_data = self._parse_response(res_str, True)
  52. if result_data:
  53. result["data"] = result_data
  54. usage = response_json["usage"]
  55. result["completion_tokens"] = usage.get("completion_tokens", 0)
  56. result["prompt_tokens"] = usage.get("prompt_tokens", 0)
  57. result["total_tokens"] = usage.get("total_tokens", 0)
  58. utils.get_logger().info(f"AI Process JSON: {result}")
  59. else:
  60. utils.get_logger().info(f"AI Response: {response}")
  61. return result
  62. except Exception as e:
  63. raise Exception(f"解析 AI 响应错误: {e}")
  64. def check_api(self, api_key, api_model, api_url):
  65. if api_url:
  66. self._ai_api_url = api_url
  67. if api_key:
  68. self._ai_api_key = api_key
  69. if api_model:
  70. self._api_model = api_model
  71. if self._ai_api_key is None:
  72. raise Exception("AI API key 没有配置")
  73. if self._ai_api_url is None:
  74. raise Exception("AI API url 没有配置")
  75. if self._api_model is None:
  76. raise Exception("AI API model 没有配置")
  77. @staticmethod
  78. def _extract_message_content(response_json: dict) -> str:
  79. utils.get_logger().info(f"AI Response JSON: {response_json}")
  80. if "choices" in response_json and len(response_json["choices"]) > 0:
  81. choice = response_json["choices"][0]
  82. message_content = choice.get("message", {}).get("content", "")
  83. elif "message" in response_json:
  84. message_content = response_json["message"].get("content", "")
  85. else:
  86. raise Exception("AI 响应中未找到有效的 choices 或 message 数据")
  87. # 移除多余的 ```json 和 ```
  88. if message_content.startswith("```json") and message_content.endswith("```"):
  89. message_content = message_content[6:-3]
  90. # 去除开头的 'n' 字符
  91. if message_content.startswith("n"):
  92. message_content = message_content[1:]
  93. # 移除无效的转义字符和时间戳前缀
  94. message_content = re.sub(
  95. r"\\[0-9]{2}", "", message_content
  96. ) # 移除 \32 等无效转义字符
  97. message_content = re.sub(
  98. r"\d{4}-\d{2}-\dT\d{2}:\d{2}:\d{2}\.\d+Z", "", message_content
  99. ) # 移除时间戳
  100. message_content = message_content.strip() # 去除首尾空白字符
  101. # 替换所有的反斜杠
  102. message_content = message_content.replace("\\", "")
  103. return message_content
  104. def _parse_response(self, response: str, first=True) -> json:
  105. # utils.get_logger().info(f"AI Response JSON STR: {response}")
  106. try:
  107. data = json.loads(response)
  108. return data
  109. except json.JSONDecodeError as e:
  110. if first:
  111. utils.get_logger().error(
  112. f"JSON 解析错误,去除部分特殊字符重新解析一次: {e}"
  113. )
  114. # 替换中文引号为空
  115. message_content = re.sub(r"[“”]", "", response) # 替换双引号
  116. message_content = re.sub(r"[‘’]", "", message_content) # 替换单引号
  117. return self._parse_response(message_content, False)
  118. else:
  119. raise Exception(f"解析 AI 响应错误: {response} {e}")
  120. def call_openai_with_image(
  121. self,
  122. image_path,
  123. system_prompt: str,
  124. user_prompt: str,
  125. api_url: str = None,
  126. api_key: str = None,
  127. api_model: str = None,
  128. ) -> json:
  129. pass
  130. def call_openai_with_file(
  131. self,
  132. file_path,
  133. system_prompt: str,
  134. user_prompt: str,
  135. api_url: str = None,
  136. api_key: str = None,
  137. api_model: str = None,
  138. ) -> json:
  139. self.check_api(api_key, api_model, api_url)
  140. utils.get_logger().info(
  141. f"调用AI API File==> Url:{self._ai_api_url},Model:{self._api_model}"
  142. )
  143. client = OpenAI(api_key=self._ai_api_key, base_url=self._ai_api_url)
  144. file_object = client.files.create(
  145. file=Path(file_path),
  146. purpose="file-extract",
  147. )
  148. completion = client.chat.completions.create(
  149. model=self._api_model,
  150. messages=[
  151. {
  152. "role": "system",
  153. # "content": system_prompt,
  154. "content": f"fileid://{file_object.id}",
  155. },
  156. {
  157. "role": "user",
  158. "content": user_prompt,
  159. },
  160. ],
  161. stream=False,
  162. temperature=0.7,
  163. response_format={"type": "json_object"},
  164. # max_tokens=self._ai_max_tokens,
  165. )
  166. try:
  167. response = completion.model_dump_json()
  168. result = {}
  169. response_json = json.loads(response)
  170. res_str = self._extract_message_content(response_json)
  171. result_data = self._parse_response(res_str, True)
  172. if result_data:
  173. result["data"] = result_data
  174. usage = response_json["usage"]
  175. result["completion_tokens"] = usage.get("completion_tokens", 0)
  176. result["prompt_tokens"] = usage.get("prompt_tokens", 0)
  177. result["total_tokens"] = usage.get("total_tokens", 0)
  178. utils.get_logger().info(f"AI Process JSON: {result}")
  179. else:
  180. utils.get_logger().info(f"AI Response: {response}")
  181. return result
  182. except Exception as e:
  183. raise Exception(f"解析 AI 响应错误: {e}")
  184. pass