openai.py 7.4 KB


  1. import re, json
  2. from openai import OpenAI
  3. from pathlib import Path
  4. import tools.utils as utils, core.configs as configs
  class OpenAi:
      """Client for OpenAI-compatible chat-completion endpoints that reply in JSON.

      Connection settings (base URL, API key, model) default to ``configs.ai``
      and can be overridden per instance (constructor) or per call (method
      arguments). Per-call overrides apply to that call only.

      Successful calls return a dict shaped like::

          {"data": <parsed JSON>, "completion_tokens": ...,
           "prompt_tokens": ..., "total_tokens": ...}

      or ``{}`` when the model returned no usable (truthy) JSON payload.
      """

      _api_key = None
      _api_url = None
      # Completion-token budget; overridden by configs.ai.max_tokens when set.
      # NOTE(review): not currently sent with requests (see _create_completion).
      _max_tokens = 150
      _api_model = None

      # Optional Markdown code fence around the payload: ```json ... ``` or ``` ... ```.
      _FENCE_RE = re.compile(r"^```(?:json)?\s*(.*?)\s*```$", re.DOTALL | re.IGNORECASE)
      # ISO-8601 UTC timestamp (e.g. 2024-01-31T08:00:00.123Z) prefixed to the content.
      # Anchored to the start so timestamps that are real JSON values stay intact.
      _TIMESTAMP_PREFIX_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+Z")
      # Backslash followed by two digits (e.g. "\32"): not a valid JSON escape.
      _BAD_ESCAPE_RE = re.compile(r"\\[0-9]{2}")
      # Curly quotation marks, common in Chinese text, that break JSON as delimiters.
      _CJK_QUOTES_RE = re.compile(r"[“”‘’]")

      def __init__(self, api_url: str = None, api_key: str = None, api_model: str = None):
          """Create a client; omitted or empty settings fall back to ``configs.ai``.

          Args:
              api_url: Base URL of the OpenAI-compatible API.
              api_key: API key.
              api_model: Model name.
          """
          self._api_url = api_url or configs.ai.api_url
          self._api_key = api_key or configs.ai.api_key
          self._api_model = api_model or configs.ai.model
          max_tokens = configs.ai.max_tokens
          if max_tokens:
              self._max_tokens = int(max_tokens)

      def call_openai(
          self,
          system_prompt: str,
          user_prompt: str,
          api_url: str = None,
          api_key: str = None,
          api_model: str = None,
      ) -> dict:
          """Send a system + user prompt and return the parsed JSON reply.

          Args:
              system_prompt: Content of the system message.
              user_prompt: Content of the user message.
              api_url, api_key, api_model: Optional overrides for this call only.

          Returns:
              The result dict described on the class (``{}`` if no data).

          Raises:
              Exception: if a setting is missing, or the reply cannot be parsed.
                  Errors raised by the OpenAI client itself propagate unchanged.
          """
          key, model, url = self._resolve_api(api_key, api_model, api_url)
          utils.get_logger().info(f"调用AI API ==> Url:{url},Model:{model}")
          client = OpenAI(api_key=key, base_url=url)
          messages = [
              {"role": "system", "content": system_prompt},
              {"role": "user", "content": user_prompt},
          ]
          completion = self._create_completion(client, model, messages)
          return self._build_result(completion)

      def check_api(self, api_key, api_model, api_url):
          """Apply non-empty overrides to this instance, then validate the settings.

          Overrides set here persist on the instance. The ``call_*`` methods use
          ``_resolve_api`` instead and leave the instance untouched.

          Raises:
              Exception: if the API key, URL or model is missing or empty
                  (the instance is then left unchanged).
          """
          self._api_key, self._api_model, self._api_url = self._resolve_api(
              api_key, api_model, api_url
          )

      def _resolve_api(self, api_key=None, api_model=None, api_url=None):
          """Return the effective ``(api_key, api_model, api_url)`` for one call.

          Non-empty arguments take precedence over the instance settings; the
          instance itself is not modified.

          Raises:
              Exception: if any of the three settings is missing or empty.
          """
          key = api_key or self._api_key
          model = api_model or self._api_model
          url = api_url or self._api_url
          if not key:
              raise Exception("AI API key 没有配置")
          if not url:
              raise Exception("AI API url 没有配置")
          if not model:
              raise Exception("AI API model 没有配置")
          return key, model, url

      @staticmethod
      def _create_completion(client, model: str, messages: list):
          """Run one non-streaming chat completion in JSON-object mode."""
          # NOTE(review): max_tokens (self._max_tokens, default 150) is not passed,
          # so the server default applies; a small budget could truncate the JSON
          # reply. TODO confirm the desired limit before enabling it.
          return client.chat.completions.create(
              model=model,
              messages=messages,
              stream=False,
              temperature=0.7,
              response_format={"type": "json_object"},
          )

      def _build_result(self, completion) -> dict:
          """Turn a completion object into the result dict described on the class.

          Raises:
              Exception: "解析 AI 响应错误: ..." wrapping any processing error.
          """
          try:
              response = completion.model_dump_json()
              response_json = json.loads(response)
              res_str = self._extract_message_content(response_json)
              result_data = self._parse_response(res_str, True)
              result = {}
              if result_data:
                  # "usage" may be absent or null; report zero counts in that case.
                  usage = response_json.get("usage") or {}
                  result["data"] = result_data
                  result["completion_tokens"] = usage.get("completion_tokens", 0)
                  result["prompt_tokens"] = usage.get("prompt_tokens", 0)
                  result["total_tokens"] = usage.get("total_tokens", 0)
                  utils.get_logger().info(f"AI Process JSON: {result}")
              else:
                  utils.get_logger().info(f"AI Response: {response}")
              return result
          except Exception as e:
              raise Exception(f"解析 AI 响应错误: {e}") from e

      @staticmethod
      def _extract_message_content(response_json: dict) -> str:
          """Log the decoded response and return its cleaned message text.

          Raises:
              Exception: if the response has neither ``choices`` nor ``message``.
          """
          utils.get_logger().info(f"AI Response JSON: {response_json}")
          content = OpenAi._pick_message_content(response_json)
          return OpenAi._clean_message_content(content)

      @staticmethod
      def _pick_message_content(response_json: dict) -> str:
          """Return the first choice's message text, or top-level ``message`` text.

          A missing/null message or content yields "".

          Raises:
              Exception: if the response has neither ``choices`` nor ``message``.
          """
          choices = response_json.get("choices")
          if choices:
              message = choices[0].get("message") or {}
          elif "message" in response_json:
              message = response_json["message"] or {}
          else:
              raise Exception("AI 响应中未找到有效的 choices 或 message 数据")
          return message.get("content") or ""

      @classmethod
      def _clean_message_content(cls, text: str) -> str:
          """Strip surrounding whitespace, a leading timestamp and a Markdown fence.

          The JSON body itself is not modified, so valid escapes survive.
          """
          text = cls._TIMESTAMP_PREFIX_RE.sub("", text.strip()).strip()
          fenced = cls._FENCE_RE.match(text)
          if fenced:
              text = fenced.group(1)
          return text.strip()

      @classmethod
      def _strip_backslashes(cls, text: str) -> str:
          """Drop "\\NN" sequences, then every remaining backslash (fallback repair)."""
          return cls._BAD_ESCAPE_RE.sub("", text).replace("\\", "")

      @staticmethod
      def _loads(text: str) -> object:
          """``json.loads`` that also unwraps a JSON string holding an object/array."""
          data = json.loads(text)
          if isinstance(data, str):
              try:
                  inner = json.loads(data)
              except json.JSONDecodeError:
                  return data
              if isinstance(inner, (dict, list)):
                  return inner
          return data

      def _parse_response(self, response: str, first=True) -> object:
          """Decode the model's JSON text, falling back to sanitised copies.

          Attempts, in order:
          1. the text as-is;
          2. the text with stray escapes and backslashes removed;
          3. that text with curly quotes also removed.
          Text that is already valid JSON is never altered, so escaped quotes,
          newlines and unicode escapes inside strings are preserved.

          Args:
              response: JSON text to decode.
              first: When False, decode only as-is (no fallbacks).

          Raises:
              Exception: if every attempt fails.
          """
          try:
              return self._loads(response)
          except json.JSONDecodeError as e:
              if not first:
                  raise Exception(f"解析 AI 响应错误: {response} {e}") from e
              utils.get_logger().error(
                  f"JSON 解析错误,去除部分特殊字符重新解析一次: {e}"
              )
          unescaped = self._strip_backslashes(response)
          try:
              return self._loads(unescaped)
          except json.JSONDecodeError:
              pass
          return self._parse_response(self._CJK_QUOTES_RE.sub("", unescaped), False)

      def call_openai_with_image(
          self,
          image_path,
          system_prompt: str,
          user_prompt: str,
          api_url: str = None,
          api_key: str = None,
          api_model: str = None,
      ) -> None:
          """Placeholder for image input: not implemented, always returns ``None``."""
          # TODO: implement (e.g. send the image as a base64 data-URL content part).
          return None

      def call_openai_with_file(
          self,
          file_path,
          system_prompt: str,
          user_prompt: str,
          api_url: str = None,
          api_key: str = None,
          api_model: str = None,
      ) -> dict:
          """Upload a file for server-side extraction and ask a question about it.

          The file is uploaded with ``purpose="file-extract"``. It is referenced
          from a ``fileid://<id>`` system message. NOTE(review): this is the
          DashScope Qwen-Long convention; confirm support on other providers.
          The upload is deleted afterwards on a best-effort basis.

          Args:
              file_path: Path of the local file to upload.
              system_prompt: Instructions sent as a separate system message before
                  the file reference (omitted when empty).
              user_prompt: Content of the user message.
              api_url, api_key, api_model: Optional overrides for this call only.

          Returns:
              The result dict described on the class (``{}`` if no data).

          Raises:
              Exception: if a setting is missing, or the reply cannot be parsed.
                  Errors raised by the OpenAI client itself propagate unchanged.
          """
          key, model, url = self._resolve_api(api_key, api_model, api_url)
          utils.get_logger().info(f"调用AI API File==> Url:{url},Model:{model}")
          client = OpenAI(api_key=key, base_url=url)
          file_object = client.files.create(
              file=Path(file_path),
              purpose="file-extract",
          )
          messages = []
          if system_prompt:
              messages.append({"role": "system", "content": system_prompt})
          # The file reference goes in its own system message.
          messages.append({"role": "system", "content": f"fileid://{file_object.id}"})
          messages.append({"role": "user", "content": user_prompt})
          try:
              completion = self._create_completion(client, model, messages)
          finally:
              # The file id is not returned to callers, so nothing can reuse it.
              self._delete_uploaded_file(client, file_object.id)
          return self._build_result(completion)

      @staticmethod
      def _delete_uploaded_file(client, file_id) -> None:
          """Best-effort delete of an uploaded file; failures are only logged."""
          try:
              client.files.delete(file_id)
          except Exception as e:
              utils.get_logger().error(f"Failed to delete uploaded file {file_id}: {e}")