yue 3 mesi fa
parent
commit
a3b21362e7

+ 52 - 19
SourceCode/IntelligentRailwayCosting/app/ai/openai.py

@@ -1,16 +1,17 @@
-import re ,json, os
+import re, json
 from openai import OpenAI
 from pathlib import Path
 
 import tools.utils as utils, core.configs as configs
 
+
 class OpenAi:
     _api_key = None
     _api_url = None
     _max_tokens = 150
-    _api_model =None
+    _api_model = None
 
-    def __init__(self, api_url: str=None, api_key: str=None, api_model: str=None):
+    def __init__(self, api_url: str = None, api_key: str = None, api_model: str = None):
         self._api_url = api_url if api_url else configs.ai.api_url
         self._api_key = api_key if api_key else configs.ai.api_key
         self._api_model = api_model if api_model else configs.ai.model
@@ -18,9 +19,18 @@ class OpenAi:
         if max_tokens:
             self._max_tokens = int(max_tokens)
 
-    def call_openai(self, system_prompt: str, user_prompt: str,api_url: str=None,api_key: str=None,api_model: str=None) -> json:
+    def call_openai(
+            self,
+            system_prompt: str,
+            user_prompt: str,
+            api_url: str = None,
+            api_key: str = None,
+            api_model: str = None,
+    ) -> json:
         self.check_api(api_key, api_model, api_url)
-        utils.get_logger().info(f"调用AI API ==> Url:{self._api_url},Model:{self._api_model}")
+        utils.get_logger().info(
+            f"调用AI API ==> Url:{self._api_url},Model:{self._api_model}"
+        )
 
         client = OpenAI(api_key=self._api_key, base_url=self._api_url)
         completion = client.chat.completions.create(
@@ -85,18 +95,19 @@ class OpenAi:
             raise Exception("AI 响应中未找到有效的 choices 或 message 数据")
 
         # 移除多余的 ```json 和 ```
-        if message_content.startswith("```json") and message_content.endswith(
-                "```"):
+        if message_content.startswith("```json") and message_content.endswith("```"):
             message_content = message_content[6:-3]
 
         # 去除开头的 'n' 字符
         if message_content.startswith("n"):
             message_content = message_content[1:]
         # 移除无效的转义字符和时间戳前缀
-        message_content = re.sub(r"\\[0-9]{2}", "",
-                                 message_content)  # 移除 \32 等无效转义字符
-        message_content = re.sub(r"\d{4}-\d{2}-\dT\d{2}:\d{2}:\d{2}\.\d+Z", "",
-                                 message_content)  # 移除时间戳
+        message_content = re.sub(
+            r"\\[0-9]{2}", "", message_content
+        )  # 移除 \32 等无效转义字符
+        message_content = re.sub(
+            r"\d{4}-\d{2}-\dT\d{2}:\d{2}:\d{2}\.\d+Z", "", message_content
+        )  # 移除时间戳
         message_content = message_content.strip()  # 去除首尾空白字符
 
         # 替换所有的反斜杠
@@ -112,7 +123,9 @@ class OpenAi:
 
         except json.JSONDecodeError as e:
             if first:
-                utils.get_logger().error(f"JSON 解析错误,去除部分特殊字符重新解析一次: {e}")
+                utils.get_logger().error(
+                    f"JSON 解析错误,去除部分特殊字符重新解析一次: {e}"
+                )
                 # 替换中文引号为空
                 message_content = re.sub(r"[“”]", "", response)  # 替换双引号
                 message_content = re.sub(r"[‘’]", "", message_content)  # 替换单引号
@@ -120,23 +133,43 @@ class OpenAi:
             else:
                 raise Exception(f"解析 AI 响应错误: {response} {e}")
 
-    def call_openai_with_image(self, image_path,system_prompt: str, user_prompt: str, api_url: str=None,api_key: str=None,api_model: str=None) -> json:
+    def call_openai_with_image(
+            self,
+            image_path,
+            system_prompt: str,
+            user_prompt: str,
+            api_url: str = None,
+            api_key: str = None,
+            api_model: str = None,
+    ) -> json:
         pass
 
-    
-    def call_openai_with_file(self, file_path,system_prompt: str, user_prompt: str, api_url: str=None,api_key: str=None,api_model: str=None)->json:
+    def call_openai_with_file(
+            self,
+            file_path,
+            system_prompt: str,
+            user_prompt: str,
+            api_url: str = None,
+            api_key: str = None,
+            api_model: str = None,
+    ) -> json:
         self.check_api(api_key, api_model, api_url)
-        utils.get_logger().info(f"调用AI API File==> Url:{self._api_url},Model:{self._api_model}")
+        utils.get_logger().info(
+            f"调用AI API File==> Url:{self._api_url},Model:{self._api_model}"
+        )
 
         client = OpenAI(api_key=self._api_key, base_url=self._api_url)
-        file_object = client.files.create( file=Path(file_path),purpose='file-extract',)
+        file_object = client.files.create(
+            file=Path(file_path),
+            purpose="file-extract",
+        )
         completion = client.chat.completions.create(
             model=self._api_model,
             messages=[
                 {
                     "role": "system",
                     # "content": system_prompt,
-                    'content': f'fileid://{file_object.id}'
+                    "content": f"fileid://{file_object.id}",
                 },
                 {
                     "role": "user",
@@ -167,4 +200,4 @@ class OpenAi:
             return result
         except Exception as e:
             raise Exception(f"解析 AI 响应错误: {e}")
-        pass
+        pass

+ 7 - 6
SourceCode/IntelligentRailwayCosting/app/core/api/response.py

@@ -1,4 +1,5 @@
-from typing import Dict, Any, Optional, List
+from typing import Dict, Any, Optional
+
 from flask import jsonify, Response, make_response
 
 
@@ -33,7 +34,7 @@ class ResponseBase:
 
     @staticmethod
     def error(
-        message: str = "操作失败", code: int = 400, data: Optional[Any] = None
+            message: str = "操作失败", code: int = 400, data: Optional[Any] = None
     ) -> Response:
         """错误响应
         Args:
@@ -51,10 +52,10 @@ class ResponseBase:
 
     @staticmethod
     def json_response(
-        success: bool = True,
-        code: int = 200,
-        message: str = "",
-        data: Optional[Any] = None,
+            success: bool = True,
+            code: int = 200,
+            message: str = "",
+            data: Optional[Any] = None,
     ) -> Response:
         """自定义响应
         Args:

+ 9 - 8
SourceCode/IntelligentRailwayCosting/app/core/api/table_response.py

@@ -1,12 +1,17 @@
-from typing import Dict, Any, Optional, List
-from flask import jsonify, Response
+from typing import Dict, List
+
+from flask import Response
+
 from .response import ResponseBase
 
+
 class TableResponse(ResponseBase):
     """表格数据响应结构"""
 
     @staticmethod
-    def success(rows: List[Dict] = None, total: int = 0, message: str = "操作成功") -> Response:
+    def success(
+            rows: List[Dict] = None, total: int = 0, message: str = "操作成功"
+    ) -> Response:
         """表格数据成功响应
         Args:
             rows: 表格数据行列表
@@ -16,11 +21,7 @@ class TableResponse(ResponseBase):
             Dict: 统一的表格数据响应格式
         """
         return ResponseBase.success(
-            data={
-                "rows": rows or [],
-                "total": total
-            },
-            message=message
+            data={"rows": rows or [], "total": total}, message=message
         )
 
     @staticmethod

+ 9 - 3
SourceCode/IntelligentRailwayCosting/app/core/dtos/chapter.py

@@ -1,8 +1,9 @@
-from typing import Optional, List
+from typing import Optional
 from pydantic import BaseModel
 from ..models.chapter import ChapterModel
 from ..models.total_budget_item import TotalBudgetItemModel
 
+
 class ChapterDto(BaseModel):
     # 章节表字段
     item_id: int
@@ -43,6 +44,7 @@ class ChapterDto(BaseModel):
 
     # 总概算条目字段
     budget_id: Optional[int] = None
+
     # project_quantity1: Optional[float] = None
     # project_quantity2: Optional[float] = None
     # budget_value: Optional[float] = None
@@ -63,7 +65,11 @@ class ChapterDto(BaseModel):
     # tax: Optional[float] = None
 
     @classmethod
-    def from_model(cls, chapter_model: ChapterModel, budget_item_model: Optional[TotalBudgetItemModel] = None) -> 'ChapterDto':
+    def from_model(
+            cls,
+            chapter_model: ChapterModel,
+            budget_item_model: Optional[TotalBudgetItemModel] = None,
+    ) -> "ChapterDto":
         """从数据库模型创建DTO对象"""
         dto = cls(
             # 章节表字段
@@ -133,4 +139,4 @@ class ChapterDto(BaseModel):
         return self.model_dump()
 
     class Config:
-        from_attributes = True
+        from_attributes = True

+ 5 - 2
SourceCode/IntelligentRailwayCosting/app/core/dtos/total_budget_item.py

@@ -2,8 +2,10 @@ from pydantic import BaseModel
 from typing import Optional
 from ..models.total_budget_item import TotalBudgetItemModel
 
+
 class TotalBudgetItemDto(BaseModel):
     """总概算条目DTO"""
+
     budget_id: int
     item_id: int
     # project_quantity1: Optional[float] = None
@@ -42,6 +44,7 @@ class TotalBudgetItemDto(BaseModel):
     project_name: Optional[str] = None
     unit: Optional[str] = None
     item_type: Optional[str] = None
+
     # project_code: Optional[str] = None
     # material_price_diff_code: Optional[str] = None
     # summary_method: Optional[str] = None
@@ -60,7 +63,7 @@ class TotalBudgetItemDto(BaseModel):
     # professional_name: Optional[str] = None
 
     @classmethod
-    def from_model(cls, model: TotalBudgetItemModel) -> 'TotalBudgetItemDto':
+    def from_model(cls, model) -> "TotalBudgetItemDto":
         """从数据库模型创建DTO对象"""
         return cls(
             budget_id=model.budget_id,
@@ -124,4 +127,4 @@ class TotalBudgetItemDto(BaseModel):
         return self.model_dump()
 
     class Config:
-        from_attributes = True
+        from_attributes = True

+ 45 - 43
SourceCode/IntelligentRailwayCosting/app/core/models/quota_input.py

@@ -1,53 +1,55 @@
-from sqlalchemy import Column, String, Integer, Float, Text, ForeignKey
+from sqlalchemy import Column, String, Integer, Float, Text
 from sqlalchemy.ext.declarative import declarative_base
 
 Base = declarative_base()
 
+
 class QuotaInputModel(Base):
-    __tablename__ = '定额输入'
+    __tablename__ = "定额输入"
 
-    quota_id = Column('定额序号', Integer, primary_key=True, autoincrement=True)
+    quota_id = Column("定额序号", Integer, primary_key=True, autoincrement=True)
     # budget_id = Column('总概算序号', Integer, ForeignKey('总概算信息.总概算序号'), nullable=False)
     # item_id = Column('条目序号', Integer, ForeignKey('章节表.条目序号'), nullable=False)
-    budget_id = Column('总概算序号', Integer, nullable=False)
-    item_id = Column('条目序号', Integer, nullable=False)
-    quota_code = Column('定额编号', String(255), nullable=False)
-    sequence_number = Column('顺号', Integer)
-    project_name = Column('工程或费用项目名称', String(255))
-    unit = Column('单位', String(20))
-    project_quantity = Column('工程数量', Float)
-    project_quantity_input = Column('工程数量输入', Text)
-    quota_adjustment = Column('定额调整', Text)
-    unit_price = Column('单价', Float)
-    compilation_unit_price = Column('编制期单价', Float)
-    total_price = Column('合价', Float)
-    compilation_total_price = Column('编制期合价', Float)
-    unit_weight = Column('单重', Float)
-    total_weight = Column('合重', Float)
-    labor_cost = Column('人工费', Float)
-    compilation_labor_cost = Column('编制期人工费', Float)
-    material_cost = Column('材料费', Float)
-    compilation_material_cost = Column('编制期材料费', Float)
-    deduct_material_cost = Column('扣料费', Float)
-    compilation_deduct_material_cost = Column('编制期扣料费', Float)
-    mechanical_cost = Column('机械使用费', Float)
-    compilation_mechanical_cost = Column('编制期机械使用费', Float)
-    equipment_cost = Column('设备费', Float)
-    compilation_equipment_cost = Column('编制期设备费', Float)
-    transport_cost = Column('运杂费', Float)
-    compilation_transport_cost = Column('编制期运杂费', Float)
-    quota_workday = Column('定额工日', Float)
-    total_workday = Column('工日合计', Float)
-    workday_salary = Column('工日工资', Float)
-    compilation_workday_salary = Column('编制期工日工资', Float)
-    quota_mechanical_workday = Column('定额机械工日', Float)
-    total_mechanical_workday = Column('机械工合计', Float)
-    mechanical_workday_salary = Column('机械工日工资', Float)
-    compilation_mechanical_workday_salary = Column('编制期机械工日工资', Float)
-    compiler = Column('编制人', String(50))
-    modify_date = Column('修改日期', String(50))
-    quota_consumption = Column('定额消耗', Text)
-    basic_quota = Column('基本定额', String(255))
+    budget_id = Column("总概算序号", Integer, nullable=False)
+    item_id = Column("条目序号", Integer, nullable=False)
+    quota_code = Column("定额编号", String(255), nullable=False)
+    sequence_number = Column("顺号", Integer)
+    project_name = Column("工程或费用项目名称", String(255))
+    unit = Column("单位", String(20))
+    project_quantity = Column("工程数量", Float)
+    project_quantity_input = Column("工程数量输入", Text)
+    quota_adjustment = Column("定额调整", Text)
+    unit_price = Column("单价", Float)
+    compilation_unit_price = Column("编制期单价", Float)
+    total_price = Column("合价", Float)
+    compilation_total_price = Column("编制期合价", Float)
+    unit_weight = Column("单重", Float)
+    total_weight = Column("合重", Float)
+    labor_cost = Column("人工费", Float)
+    compilation_labor_cost = Column("编制期人工费", Float)
+    material_cost = Column("材料费", Float)
+    compilation_material_cost = Column("编制期材料费", Float)
+    deduct_material_cost = Column("扣料费", Float)
+    compilation_deduct_material_cost = Column("编制期扣料费", Float)
+    mechanical_cost = Column("机械使用费", Float)
+    compilation_mechanical_cost = Column("编制期机械使用费", Float)
+    equipment_cost = Column("设备费", Float)
+    compilation_equipment_cost = Column("编制期设备费", Float)
+    transport_cost = Column("运杂费", Float)
+    compilation_transport_cost = Column("编制期运杂费", Float)
+    quota_workday = Column("定额工日", Float)
+    total_workday = Column("工日合计", Float)
+    workday_salary = Column("工日工资", Float)
+    compilation_workday_salary = Column("编制期工日工资", Float)
+    quota_mechanical_workday = Column("定额机械工日", Float)
+    total_mechanical_workday = Column("机械工合计", Float)
+    mechanical_workday_salary = Column("机械工日工资", Float)
+    compilation_mechanical_workday_salary = Column("编制期机械工日工资", Float)
+    compiler = Column("编制人", String(50))
+    modify_date = Column("修改日期", String(50))
+    quota_consumption = Column("定额消耗", Text)
+    basic_quota = Column("基本定额", String(255))
+
     # quota_comprehensive_unit_price = Column('定额综合单价', Float)
     # quota_comprehensive_total_price = Column('定额综合合价', Float)
 
@@ -55,4 +57,4 @@ class QuotaInputModel(Base):
     # chapter = relationship('ChapterModel')
 
     def __repr__(self):
-        return f"<QuotaInput(quota_id={self.quota_id}, quota_code='{self.quota_code}')>"
+        return f"<QuotaInput(quota_id={self.quota_id}, quota_code='{self.quota_code}')>"

+ 28 - 29
SourceCode/IntelligentRailwayCosting/app/core/models/total_budget_info.py

@@ -1,40 +1,39 @@
 from sqlalchemy import Column, String, Integer, Float, Boolean
 from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.orm import relationship
 
 Base = declarative_base()
 
+
 class TotalBudgetInfoModel(Base):
-    __tablename__ = '总概算信息'
+    __tablename__ = "总概算信息"
 
+    budget_id = Column("总概算序号", Integer, primary_key=True, autoincrement=True)
+    budget_code = Column("总概算编号", String(50), nullable=False)
+    compilation_scope = Column("编制范围", String(255))
+    project_quantity = Column("工程数量", Float, nullable=False)
+    unit = Column("单位", String(20), nullable=False)
+    budget_value = Column("概算价值", Float)
+    budget_index = Column("概算指标", Float)
+    price_diff_coefficient = Column("价差系数", String(50), nullable=False)
+    price_diff_area = Column("价差区号", String(20), nullable=False)
+    ending_scheme = Column("结尾方案", String(50), nullable=False)
+    material_cost_scheme = Column("材料费方案", String(50), nullable=False)
+    mechanical_cost_scheme = Column("机械费方案", String(50), nullable=False)
+    equipment_cost_scheme = Column("设备费方案", String(50), nullable=False)
+    labor_cost_scheme = Column("工费方案", String(50), nullable=False)
+    compilation_status = Column("编制状态", Integer, nullable=False)
+    train_interference_count = Column("行车干扰次数", Integer)
+    train_interference_10_count = Column("行干10号工次数", Integer)
+    deduct_supplied_materials = Column("扣甲供料", Integer)
+    auto_calculate_quantity = Column("是否自动计算工程量", Boolean)
+    mechanical_depreciation_adjustment = Column("机械折旧费调差系数", Float)
+    construction_supervision_group = Column("施工监理分组", Integer)
+    construction_management_group = Column("建设管理分组", Integer)
+    survey_group = Column("勘察分组", Integer)
+    design_group = Column("设计分组", Integer)
+    compilation_scope_group = Column("编制范围分组", Integer)
+    enable_total_budget_group = Column("启用总概算分组", Integer)
 
-    budget_id = Column('总概算序号', Integer, primary_key=True, autoincrement=True)
-    budget_code = Column('总概算编号', String(50), nullable=False)
-    compilation_scope = Column('编制范围', String(255))
-    project_quantity = Column('工程数量', Float, nullable=False)
-    unit = Column('单位', String(20), nullable=False)
-    budget_value = Column('概算价值', Float)
-    budget_index = Column('概算指标', Float)
-    price_diff_coefficient = Column('价差系数', String(50), nullable=False)
-    price_diff_area = Column('价差区号', String(20), nullable=False)
-    ending_scheme = Column('结尾方案', String(50), nullable=False)
-    material_cost_scheme = Column('材料费方案', String(50), nullable=False)
-    mechanical_cost_scheme = Column('机械费方案', String(50), nullable=False)
-    equipment_cost_scheme = Column('设备费方案', String(50), nullable=False)
-    labor_cost_scheme = Column('工费方案', String(50), nullable=False)
-    compilation_status = Column('编制状态', Integer, nullable=False)
-    train_interference_count = Column('行车干扰次数', Integer)
-    train_interference_10_count = Column('行干10号工次数', Integer)
-    deduct_supplied_materials = Column('扣甲供料', Integer)
-    auto_calculate_quantity = Column('是否自动计算工程量', Boolean)
-    mechanical_depreciation_adjustment = Column('机械折旧费调差系数', Float)
-    construction_supervision_group = Column('施工监理分组', Integer)
-    construction_management_group = Column('建设管理分组', Integer)
-    survey_group = Column('勘察分组', Integer)
-    design_group = Column('设计分组', Integer)
-    compilation_scope_group = Column('编制范围分组', Integer)
-    enable_total_budget_group = Column('启用总概算分组', Integer)
-    
     # items = relationship('TotalBudgetItemModel', back_populates='budget_info', lazy='dynamic')
 
     def __repr__(self):

+ 36 - 36
SourceCode/IntelligentRailwayCosting/app/core/models/total_budget_item.py

@@ -1,48 +1,48 @@
-from sqlalchemy import Column, String, Integer, Float, Text, ForeignKey
+from sqlalchemy import Column, String, Integer, Float, Text
 from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.orm import relationship
 
 Base = declarative_base()
 
+
 class TotalBudgetItemModel(Base):
-    __tablename__ = '总概算条目'
-    
+    __tablename__ = "总概算条目"
+
     # budget_id = Column('总概算序号', Integer, ForeignKey('总概算信息.总概算序号'), primary_key=True)
     # item_id = Column('条目序号', Integer, ForeignKey('章节表.条目序号'), primary_key=True)
-    budget_id = Column('总概算序号', Integer, primary_key=True)
-    item_id = Column('条目序号', Integer, primary_key=True)
-    project_quantity1 = Column('工程数量1', Float)
-    project_quantity2 = Column('工程数量2', Float)
-    budget_value = Column('概算价值', Float)
-    budget_index1 = Column('概算指标1', Float)
-    budget_index2 = Column('概算指标2', Float)
-    construction_cost = Column('建筑工程费', Float)
-    installation_cost = Column('安装工程费', Float)
-    equipment_cost = Column('设备工器具', Float)
-    other_cost = Column('其他费', Float)
-    selected_labor_cost = Column('选用工费', Integer)
-    shift_labor_cost = Column('台班工费', Integer)
-    rate_scheme = Column('费率方案', String(50))
-    formula_code = Column('公式代码', String(50))
-    transport_scheme = Column('运输方案', String(50))
-    transport_unit_price = Column('运输单价', Float)
-    parameter_adjustment = Column('参数调整', Text)
-    calculation_formula = Column('计算公式', Text)
-    unit1 = Column('单位1', String(20))
-    unit2 = Column('单位2', String(20))
-    project_quantity1_input = Column('工程数量1输入', String(100))
-    project_quantity2_input = Column('工程数量2输入', String(100))
-    seat_count = Column('座数', Integer)
-    installation_sub_item = Column('安装子目', String(500))
-    cooperation_fee_code = Column('配合费代码', String(50))
-    tax_category = Column('税金类别', String(50))
-    tax_rate = Column('税率', Float)
-    bridge_type = Column('桥梁类型', String(50))
-    electricity_price_category = Column('电价分类', Integer)
-    tax = Column('税金', Float)
+    budget_id = Column("总概算序号", Integer, primary_key=True)
+    item_id = Column("条目序号", Integer, primary_key=True)
+    project_quantity1 = Column("工程数量1", Float)
+    project_quantity2 = Column("工程数量2", Float)
+    budget_value = Column("概算价值", Float)
+    budget_index1 = Column("概算指标1", Float)
+    budget_index2 = Column("概算指标2", Float)
+    construction_cost = Column("建筑工程费", Float)
+    installation_cost = Column("安装工程费", Float)
+    equipment_cost = Column("设备工器具", Float)
+    other_cost = Column("其他费", Float)
+    selected_labor_cost = Column("选用工费", Integer)
+    shift_labor_cost = Column("台班工费", Integer)
+    rate_scheme = Column("费率方案", String(50))
+    formula_code = Column("公式代码", String(50))
+    transport_scheme = Column("运输方案", String(50))
+    transport_unit_price = Column("运输单价", Float)
+    parameter_adjustment = Column("参数调整", Text)
+    calculation_formula = Column("计算公式", Text)
+    unit1 = Column("单位1", String(20))
+    unit2 = Column("单位2", String(20))
+    project_quantity1_input = Column("工程数量1输入", String(100))
+    project_quantity2_input = Column("工程数量2输入", String(100))
+    seat_count = Column("座数", Integer)
+    installation_sub_item = Column("安装子目", String(500))
+    cooperation_fee_code = Column("配合费代码", String(50))
+    tax_category = Column("税金类别", String(50))
+    tax_rate = Column("税率", Float)
+    bridge_type = Column("桥梁类型", String(50))
+    electricity_price_category = Column("电价分类", Integer)
+    tax = Column("税金", Float)
 
     # budget_info = relationship('TotalBudgetInfoModel', back_populates='items')
     # chapter = relationship('ChapterModel')
 
     def __repr__(self):
-        return f"<TotalBudgetItem(budget_id={self.budget_id}, item_id={self.item_id})>"
+        return f"<TotalBudgetItem(budget_id={self.budget_id}, item_id={self.item_id})>"

+ 16 - 8
SourceCode/IntelligentRailwayCosting/app/core/user_session/current_user.py

@@ -1,20 +1,27 @@
 from dataclasses import dataclass
-from typing import Optional, List
+from typing import Optional
 from flask_login import UserMixin
-from core.dtos import UserDto
+
 
 @dataclass
 class CurrentUser(UserMixin):
     """当前用户信息结构体"""
+
     _user_id: Optional[int] = None
     _username: Optional[str] = None
     _item_range: Optional[str] = None
     _specialty: Optional[str] = None
+
     # _auth_supplement_quota: Optional[str] = None
     # _project_supplement: Optional[str] = None
 
-
-    def __init__(self, user_id: Optional[int] = None, username: Optional[str] = None, item_range: Optional[str] = None, specialty: Optional[str] = None):
+    def __init__(
+            self,
+            user_id: Optional[int] = None,
+            username: Optional[str] = None,
+            item_range: Optional[str] = None,
+            specialty: Optional[str] = None,
+    ):
         self._user_id = user_id
         self._username = username
         self._item_range = item_range
@@ -22,16 +29,20 @@ class CurrentUser(UserMixin):
 
     def get_id(self):
         return self.user_id
+
     @property
     def user_id(self):
         """实现Flask-Login要求的get_id方法"""
         return str(self._user_id) if self._user_id else None
+
     @property
     def username(self):
         return self._username
+
     @property
     def item_range(self):
         return self._item_range
+
     @property
     def specialty(self):
         return self._specialty
@@ -45,7 +56,6 @@ class CurrentUser(UserMixin):
         """
         return self.user_id is not None and self.username is not None
 
-
     @property
     def is_admin(self) -> bool:
         """检查用户是否为超级管理员
@@ -53,6 +63,4 @@ class CurrentUser(UserMixin):
         Returns:
             bool: 如果用户是超级管理员返回True,否则返回False
         """
-        return self.username == 'admin'
-
-
+        return self.username == "admin"

+ 1 - 1
SourceCode/IntelligentRailwayCosting/app/stores/quota_input.py

@@ -1,6 +1,6 @@
 from sqlalchemy import and_, or_
 from datetime import datetime
-from typing import Optional, List, Tuple
+from typing import Optional
 
 import tools.db_helper as db_helper
 from core.dtos import QuotaInputDto

+ 1 - 1
SourceCode/IntelligentRailwayCosting/app/stores/railway_costing_sqlserver/log.py

@@ -1,6 +1,6 @@
 from typing import List, Optional, Dict, Any
 from datetime import datetime
-from sqlalchemy import and_, desc
+from sqlalchemy import and_
 import tools.db_helper as db_helper
 
 from core.models import LogModel

+ 57 - 38
SourceCode/IntelligentRailwayCosting/app/test/mysqy_test.py

@@ -1,32 +1,33 @@
 import unittest
-from tools.db_helper.mysql import MySQLHelper
-from tools.db_helper.base import DBHelper, Base
+import tools.db_helper.mysql_helper
+from tools.db_helper.base import Base
 from sqlalchemy import Column, Integer, String
-from typing import Optional, Dict
+
 
 # 定义测试用的模型类
 class TestUser(Base):
-    __tablename__ = 'test_users'
-    
+    __tablename__ = "test_users"
+
     id = Column(Integer, primary_key=True)
     name = Column(String(50), nullable=False)
     email = Column(String(100), unique=True)
 
+
 class TestMySQLHelper(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         """测试类初始化"""
-        cls.db_helper = MySQLHelper()
+        cls.db_helper = tools.db_helper.mysql_helper.MySQLHelper()
         # 设置测试数据库配置
         cls.test_config = {
-            'host': 'localhost',
-            'port': 3306,
-            'user': 'test_user',
-            'password': 'test_password',
-            'db': 'test_db',
-            'charset': 'utf8mb4'
+            "host": "localhost",
+            "port": 3306,
+            "user": "test_user",
+            "password": "test_password",
+            "db": "test_db",
+            "charset": "utf8mb4",
         }
-        cls.test_db = 'test_db'
+        cls.test_db = "test_db"
 
     def setUp(self):
         """每个测试用例执行前的设置"""
@@ -43,16 +44,16 @@ class TestMySQLHelper(unittest.TestCase):
 
     def test_singleton(self):
         """测试单例模式"""
-        mysql1 = MySQLHelper()
-        mysql2 = MySQLHelper()
+        mysql1 = tools.db_helper.mysql_helper.MySQLHelper()
+        mysql2 = tools.db_helper.mysql_helper.MySQLHelper()
         self.assertIs(mysql1, mysql2)
 
     def test_set_default_config(self):
         """测试设置默认配置"""
-        test_config = {'host': 'test_host', 'port': 3307}
+        test_config = {"host": "test_host", "port": 3307}
         self.db_helper.set_default_config(test_config)
-        self.assertEqual(self.db_helper._default_config['host'], 'test_host')
-        self.assertEqual(self.db_helper._default_config['port'], 3307)
+        self.assertEqual(self.db_helper._default_config["host"], "test_host")
+        self.assertEqual(self.db_helper._default_config["port"], 3307)
 
     def test_get_config_for_database(self):
         """测试获取数据库配置"""
@@ -75,13 +76,15 @@ class TestMySQLHelper(unittest.TestCase):
 
             # 插入测试数据
             insert_sql = "INSERT INTO test_table (name) VALUES (%s)"
-            self.db_helper.execute_non_query(self.test_db, insert_sql, ('test_name',))
+            self.db_helper.execute_non_query(self.test_db, insert_sql, ("test_name",))
 
             # 测试查询
             query_sql = "SELECT * FROM test_table WHERE name = %s"
-            results = self.db_helper.execute_query(self.test_db, query_sql, ('test_name',))
+            results = self.db_helper.execute_query(
+                self.test_db, query_sql, ("test_name",)
+            )
             self.assertTrue(len(results) > 0)
-            self.assertEqual(results[0][1], 'test_name')
+            self.assertEqual(results[0][1], "test_name")
 
         except Exception as e:
             self.fail(f"查询操作测试失败: {str(e)}")
@@ -90,18 +93,23 @@ class TestMySQLHelper(unittest.TestCase):
         """测试标量查询"""
         try:
             # 创建测试表并插入数据
-            self.db_helper.execute_non_query(self.test_db, """
+            self.db_helper.execute_non_query(
+                self.test_db,
+                """
                 CREATE TABLE IF NOT EXISTS test_scalar (
                     id INT PRIMARY KEY AUTO_INCREMENT,
                     value INT NOT NULL
                 )
-            """)
-            self.db_helper.execute_non_query(self.test_db, 
-                "INSERT INTO test_scalar (value) VALUES (%s)", (42,))
+            """,
+            )
+            self.db_helper.execute_non_query(
+                self.test_db, "INSERT INTO test_scalar (value) VALUES (%s)", (42,)
+            )
 
             # 测试标量查询
-            result = self.db_helper.execute_scalar(self.test_db, 
-                "SELECT value FROM test_scalar WHERE id = 1")
+            result = self.db_helper.execute_scalar(
+                self.test_db, "SELECT value FROM test_scalar WHERE id = 1"
+            )
             self.assertEqual(result, 42)
 
         except Exception as e:
@@ -111,30 +119,41 @@ class TestMySQLHelper(unittest.TestCase):
         """测试会话作用域和事务管理"""
         try:
             # 测试成功的事务
-            with self.db_helper.session_scope(self.test_db, self.test_config) as session:
-                user = TestUser(name='test_user', email='test@example.com')
+            with self.db_helper.session_scope(
+                    self.test_db, self.test_config
+            ) as session:
+                user = TestUser(name="test_user", email="test@example.com")
                 session.add(user)
 
             # 验证数据已保存
-            with self.db_helper.session_scope(self.test_db, self.test_config) as session:
-                saved_user = session.query(TestUser).filter_by(name='test_user').first()
+            with self.db_helper.session_scope(
+                    self.test_db, self.test_config
+            ) as session:
+                saved_user = session.query(TestUser).filter_by(name="test_user").first()
                 self.assertIsNotNone(saved_user)
-                self.assertEqual(saved_user.email, 'test@example.com')
+                self.assertEqual(saved_user.email, "test@example.com")
 
             # 测试事务回滚
             with self.assertRaises(Exception):
-                with self.db_helper.session_scope(self.test_db, self.test_config) as session:
-                    user = TestUser(name='rollback_user', email='invalid_email')
+                with self.db_helper.session_scope(
+                        self.test_db, self.test_config
+                ) as session:
+                    user = TestUser(name="rollback_user", email="invalid_email")
                     session.add(user)
                     raise Exception("测试回滚")
 
             # 验证数据已回滚
-            with self.db_helper.session_scope(self.test_db, self.test_config) as session:
-                rollback_user = session.query(TestUser).filter_by(name='rollback_user').first()
+            with self.db_helper.session_scope(
+                    self.test_db, self.test_config
+            ) as session:
+                rollback_user = (
+                    session.query(TestUser).filter_by(name="rollback_user").first()
+                )
                 self.assertIsNone(rollback_user)
 
         except Exception as e:
             self.fail(f"会话作用域测试失败: {str(e)}")
 
-if __name__ == '__main__':
-    unittest.main()
+
+if __name__ == "__main__":
+    unittest.main()

+ 23 - 16
SourceCode/IntelligentRailwayCosting/app/test/sqlserver_test.py

@@ -1,21 +1,24 @@
+import time
 import unittest
-from tools.db_helper.sqlserver_helper import SQLServerHelper
-from tools.db_helper.base import Base
+
 from sqlalchemy import Column, Integer, String
-from typing import Dict, Any
-import time
+
+from tools.db_helper.base import Base
+from tools.db_helper.sqlserver_helper import SQLServerHelper
+
 
 class TestTable(Base):
-    __tablename__ = 'test_table'
+    __tablename__ = "test_table"
     id = Column(Integer, primary_key=True)
     name = Column(String(50), nullable=False)
 
+
 class TestSQLServerHelper(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         """测试类初始化,创建数据库帮助类实例"""
         cls.db_helper = SQLServerHelper()
-        cls.database = 'Iwb_RecoData2024'  # 使用配置文件中定义的测试数据库
+        cls.database = "Iwb_RecoData2024"  # 使用配置文件中定义的测试数据库
 
     def setUp(self):
         """每个测试用例开始前的准备工作"""
@@ -64,32 +67,34 @@ class TestSQLServerHelper(unittest.TestCase):
         """测试基本数据库操作"""
         try:
             # 测试查询操作
-            query_result = self.db_helper.execute_query(self.database, 'SELECT @@VERSION')
+            query_result = self.db_helper.execute_query(
+                self.database, "SELECT @@VERSION"
+            )
             self.assertIsNotNone(query_result, "查询操作失败")
             self.assertTrue(len(query_result) > 0, "查询结果为空")
 
             # 测试标量查询
-            scalar_result = self.db_helper.execute_scalar(self.database, 'SELECT DB_NAME()')
+            scalar_result = self.db_helper.execute_scalar(
+                self.database, "SELECT DB_NAME()"
+            )
             self.assertIsNotNone(scalar_result, "标量查询失败")
             self.assertEqual(scalar_result, self.database, "数据库名称不匹配")
 
             # 测试非查询操作
             # 创建临时表并插入数据
             self.db_helper.execute_non_query(
-                self.database,
-                "CREATE TABLE #temp_test (id INT, name NVARCHAR(50))"
+                self.database, "CREATE TABLE #temp_test (id INT, name NVARCHAR(50))"
             )
             insert_result = self.db_helper.execute_non_query(
                 self.database,
                 "INSERT INTO #temp_test (id, name) VALUES (:id, :name)",
-                {"id": 1, "name": "test"}
+                {"id": 1, "name": "test"},
             )
             self.assertEqual(insert_result, 1, "插入操作失败")
 
             # 验证插入结果
             result = self.db_helper.execute_scalar(
-                self.database,
-                "SELECT name FROM #temp_test WHERE id = 1"
+                self.database, "SELECT name FROM #temp_test WHERE id = 1"
             )
             self.assertEqual(result, "test", "数据验证失败")
 
@@ -108,7 +113,9 @@ class TestSQLServerHelper(unittest.TestCase):
 
             # 验证回滚成功
             with self.db_helper.session_scope(self.database) as session:
-                result = session.query(TestTable).filter_by(name="test_rollback").first()
+                result = (
+                    session.query(TestTable).filter_by(name="test_rollback").first()
+                )
                 self.assertIsNone(result, "事务回滚失败")
 
             # 测试正常事务提交
@@ -125,6 +132,6 @@ class TestSQLServerHelper(unittest.TestCase):
         except Exception as e:
             self.fail(f"会话管理测试失败: {str(e)}")
 
-if __name__ == '__main__':
-    unittest.main()
 
+if __name__ == "__main__":
+    unittest.main()

+ 1 - 2
SourceCode/IntelligentRailwayCosting/app/tools/db_helper/__init__.py

@@ -1,6 +1,5 @@
-from sqlalchemy.orm import Session
 from contextlib import contextmanager
-from typing import Generator, Optional, Dict, Any
+from typing import Optional, Dict, Any
 
 from .mysql_helper import MySQLHelper
 from .sqlserver_helper import SQLServerHelper

+ 9 - 9
SourceCode/IntelligentRailwayCosting/app/tools/db_helper/sqlserver_helper.py

@@ -41,7 +41,7 @@ class SQLServerHelper(DBHelper):
         )
 
     def _build_connection_string(
-            self, database: str, config: Optional[Dict[str, str]] = None
+        self, database: str, config: Optional[Dict[str, str]] = None
     ) -> str:
         """构建连接字符串"""
         conn_config = self._default_config.copy()
@@ -76,7 +76,7 @@ class SQLServerHelper(DBHelper):
         return conn_url
 
     def get_engine(
-            self, database: str, config: Optional[Dict[str, str]] = None
+        self, database: str, config: Optional[Dict[str, str]] = None
     ) -> Engine:
         """获取或创建数据库引擎"""
         conn_str = self._build_connection_string(database, config)
@@ -87,7 +87,7 @@ class SQLServerHelper(DBHelper):
         return engine
 
     def execute_query(
-            self, database: str, query: str, params: Optional[Dict[str, Any]] = None
+        self, database: str, query: str, params: Optional[Dict[str, Any]] = None
     ) -> List[Tuple]:
         """执行查询并返回结果"""
         with self.session_scope(database) as session:
@@ -95,7 +95,7 @@ class SQLServerHelper(DBHelper):
             return [tuple(row) for row in result.fetchall()]
 
     def execute_non_query(
-            self, database: str, query: str, params: Optional[Dict[str, Any]] = None
+        self, database: str, query: str, params: Optional[Dict[str, Any]] = None
     ) -> int:
         """执行非查询操作(如INSERT, UPDATE, DELETE)"""
         with self.session_scope(database) as session:
@@ -103,7 +103,7 @@ class SQLServerHelper(DBHelper):
             return result.rowcount
 
     def execute_scalar(
-            self, database: str, query: str, params: Optional[Dict[str, Any]] = None
+        self, database: str, query: str, params: Optional[Dict[str, Any]] = None
     ) -> Any:
         """执行查询并返回第一行第一列的值"""
         with self.session_scope(database) as session:
@@ -112,10 +112,10 @@ class SQLServerHelper(DBHelper):
             return row[0] if row else None
 
     def execute_procedure(
-            self,
-            database: str,
-            procedure_name: str,
-            params: Optional[Dict[str, Any]] = None,
+        self,
+        database: str,
+        procedure_name: str,
+        params: Optional[Dict[str, Any]] = None,
     ) -> List[Tuple]:
         """执行存储过程"""
         params = params or {}