extract.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. import csv, json, tools.utils as utils, os
  2. from tools.stores.mysql_store import MysqlStore
  3. from tools.models.standard_model import StandardModel
  4. class ImageExtractor:
  5. def __init__(self):
  6. self._logger = utils.get_logger()
  7. self._db_store = MysqlStore()
  8. self._base_path = "./temp_files/images/output"
  9. self._complete_path=""
  10. self._ai = utils.AiHelper()
  11. self._err_files=[]
  12. self._file_name = ""
  13. self._sys_prompt = "请提取图片中的表格,用json格式输出。"
  14. self._user_prompt = """提取表格信息,要求:
  15. 1. 提取结构化信息:```typescript
  16. type item {
  17. a: string; //书号
  18. b: string; //定额编号
  19. c:string; //定额名称
  20. d: string; //工作内容
  21. e: string; //单位
  22. f: string; //基本定额
  23. g: float; //基价(元)
  24. h: float; //单重(t)
  25. i: float; //工费
  26. j: float; //料费
  27. k: float; //机费
  28. l: string; //主材
  29. }
  30. ```
  31. 2. 提取的文字中间的空格需要保留,数据没有就留空
  32. 3. 确保符号提取准确,例如 kg,m²,m³,直径符号∅等
  33. 4. 返回压缩成一行的item数组的json字符串
  34. """
  35. def extract(self,file_name: str):
  36. self._file_name = file_name
  37. self._err_files =[]
  38. path = f"{self._base_path}/img/{self._file_name}/"
  39. self._complete_path = f"{self._base_path}/img_complete/{self._file_name}/"
  40. os.makedirs(self._complete_path , exist_ok=True)
  41. try:
  42. self._logger.info(f"开始处理目录: {path}")
  43. # 确保目录存在
  44. if not os.path.exists(path):
  45. self._logger.error(f"目录不存在: {path}")
  46. return
  47. # 遍历目录下的所有文件
  48. for root, dirs, files in os.walk(path):
  49. for file in files:
  50. # 检查是否为图片文件
  51. if file.lower().endswith(('.png', '.jpg', '.jpeg')):
  52. image_path = os.path.join(root, file)
  53. self.extract_image(image_path)
  54. self._logger.info(f"目录处理完成: {path}")
  55. if len(self._err_files)>0:
  56. self._logger.error(f"----【处理图片失败】-----: {self._err_files}")
  57. except Exception as e:
  58. self._logger.error(f"处理目录失败 {path}: {e}")
  59. def extract_image(self, image_path: str) -> None:
  60. try:
  61. self._logger.info(f"开始处理图片: {image_path}")
  62. # content = self._ai.call_openai_with_image(image_path,self._sys_prompt,self._user_prompt,api_model="qwen2.5-vl-72b-instruct")
  63. api_key= utils.get_config_value("fastgpt.api_key")
  64. content = self._ai.call_fastgpt_ai_with_image(image_path,self._user_prompt,api_key)
  65. self.save_to_db(content)
  66. # 保存成功后移动文件到已处理目录
  67. os.rename(image_path, os.path.join(self._complete_path,os.path.basename(image_path)))
  68. self._logger.info(f"图片处理完成: {image_path}")
  69. except Exception as e:
  70. self._err_files.append(image_path)
  71. self._logger.error(f"处理图片失败 {image_path}: {e}")
  72. def save_to_db(self, data_list: str|list) -> None:
  73. try:
  74. self._logger.info(f"开始保存图片内到数据库:{data_list}")
  75. if isinstance(data_list,str):
  76. data_list = json.loads(data_list)
  77. for item in data_list:
  78. try :
  79. standard = StandardModel(
  80. book_number=item['a'],
  81. quota_number=item['b'],
  82. quota_name=item['c'],
  83. work_content=item['d'],
  84. unit=item['e'],
  85. basic_quota=item['f'],
  86. base_price=item['g'],
  87. unit_weight=item['h'],
  88. labor_cost=item['i'],
  89. material_cost=item['j'],
  90. machine_cost=item['k'],
  91. main_material=item['l']
  92. )
  93. if not self._db_store.insert_standard(standard):
  94. self._logger.error(f"保存数据到数据库失败: {item}")
  95. except Exception as e:
  96. self._logger.error(f"保存图片内容失败: {e}")
  97. continue
  98. except Exception as e:
  99. self._logger.error(f"保存图片内容失败: {e}")
  100. def export(self):
  101. try:
  102. self._logger.info(f"开始导出数据库数据")
  103. data = self._db_store.query_standard_group_by_book()
  104. for k, v in data.items():
  105. # 数据保存为 csv
  106. csv_file = f"{self._base_path}/csv/{k}.csv"
  107. # 确保目录存在
  108. os.makedirs(os.path.dirname(csv_file), exist_ok=True)
  109. with open(csv_file, 'w', newline='', encoding='utf-8-sig') as f:
  110. writer = csv.writer(f)
  111. writer.writerow(['书号', '定额编号', '定额名称', '工作内容', '单位', '基本定额', '基价(元)', '单重(t)', '工费', '料费', '机费', '主材'])
  112. for item in v:
  113. # 将 StandardModel 对象的属性提取出来,构造成一个列表
  114. row = [
  115. item.book_number,
  116. item.quota_number,
  117. item.quota_name,
  118. item.work_content,
  119. item.unit,
  120. item.basic_quota,
  121. item.base_price,
  122. item.unit_weight,
  123. item.labor_cost,
  124. item.material_cost,
  125. item.machine_cost,
  126. item.main_material
  127. ]
  128. writer.writerow(row)
  129. self._logger.info(f"成功导出数据库数据")
  130. return data
  131. except Exception as e:
  132. self._logger.error(f"导出数据库数据失败: {e}")