YueYunyun 3 月之前
父节点
当前提交
e921c410bc
共有 32 个文件被更改,包括 1994 次插入0 次删除
  1. 0 0
      SourceCode/IntelligentRailwayCosting/app/__init__.py
  2. 38 0
      SourceCode/IntelligentRailwayCosting/app/config.yml
  3. 0 0
      SourceCode/IntelligentRailwayCosting/app/core/__init__.py
  4. 0 0
      SourceCode/IntelligentRailwayCosting/app/core/user_session/__init__.py
  5. 99 0
      SourceCode/IntelligentRailwayCosting/app/core/user_session/current_user.py
  6. 124 0
      SourceCode/IntelligentRailwayCosting/app/core/user_session/user_session.py
  7. 13 0
      SourceCode/IntelligentRailwayCosting/app/main.py
  8. 0 0
      SourceCode/IntelligentRailwayCosting/app/models/__init__.py
  9. 43 0
      SourceCode/IntelligentRailwayCosting/app/models/project.py
  10. 16 0
      SourceCode/IntelligentRailwayCosting/app/models/team.py
  11. 19 0
      SourceCode/IntelligentRailwayCosting/app/models/user.py
  12. 0 0
      SourceCode/IntelligentRailwayCosting/app/routes/__init__.py
  13. 0 0
      SourceCode/IntelligentRailwayCosting/app/services/__init__.py
  14. 23 0
      SourceCode/IntelligentRailwayCosting/app/services/user.py
  15. 0 0
      SourceCode/IntelligentRailwayCosting/app/stores/__init__.py
  16. 33 0
      SourceCode/IntelligentRailwayCosting/app/stores/user.py
  17. 0 0
      SourceCode/IntelligentRailwayCosting/app/test/__init__.py
  18. 140 0
      SourceCode/IntelligentRailwayCosting/app/test/mysqy_test.py
  19. 130 0
      SourceCode/IntelligentRailwayCosting/app/test/sqlserver_test.py
  20. 0 0
      SourceCode/IntelligentRailwayCosting/app/tools/__init__.py
  21. 0 0
      SourceCode/IntelligentRailwayCosting/app/tools/db_helper/__init__.py
  22. 196 0
      SourceCode/IntelligentRailwayCosting/app/tools/db_helper/base.py
  23. 178 0
      SourceCode/IntelligentRailwayCosting/app/tools/db_helper/mysql_helper.py
  24. 107 0
      SourceCode/IntelligentRailwayCosting/app/tools/db_helper/sqlserver_helper.py
  25. 168 0
      SourceCode/IntelligentRailwayCosting/app/tools/utils/__init__.py
  26. 168 0
      SourceCode/IntelligentRailwayCosting/app/tools/utils/ai_helper.py
  27. 91 0
      SourceCode/IntelligentRailwayCosting/app/tools/utils/config_helper.py
  28. 186 0
      SourceCode/IntelligentRailwayCosting/app/tools/utils/file_helper.py
  29. 113 0
      SourceCode/IntelligentRailwayCosting/app/tools/utils/logger_helper.py
  30. 103 0
      SourceCode/IntelligentRailwayCosting/app/tools/utils/string_helper.py
  31. 0 0
      SourceCode/IntelligentRailwayCosting/app/views/__init__.py
  32. 6 0
      SourceCode/IntelligentRailwayCosting/requirements.txt

+ 0 - 0
SourceCode/IntelligentRailwayCosting/app/__init__.py


+ 38 - 0
SourceCode/IntelligentRailwayCosting/app/config.yml

@@ -0,0 +1,38 @@
+db:
+  # SQL Server 配置
+  # SQL Server 2008:'{SQL Server}' 或 '{SQL Server Native Client 10.0}'
+  # SQL Server 2016:'{ODBC Driver 13 for SQL Server}'
+  # SQL Server 2020:'{ODBC Driver 17 for SQL Server}'
+  # SQL Server 2022:'{ODBC Driver 18 for SQL Server}'
+  # 在Windows系统的ODBC数据源管理器中查看已安装的驱动程序,选择相应的驱动名称
+  # 在开始菜单的列表里面找到"Windows管理工具"打开, 然后点开里面的"ODBC数据源"。
+  # 打开以后,点开上方"驱动程序"。 就可以看到系统所安装的ODBC驱动程序
+  sqlserver_mian:
+    driver: '{ODBC Driver 17 for SQL Server}'
+    server: shvber.com,50535
+    username: sa
+    password: Iwb2017
+    database: Iwb_RecoData2024
+    trusted_connection: false
+  Iwb_RecoData2024:
+    driver: '{ODBC Driver 17 for SQL Server}'
+    server: shvber.com,50535
+    username: sa
+    password: Iwb2017
+    database: Iwb_RecoData2024
+    trusted_connection: false
+  # MySQL 配置
+  mysql_main:
+    db: iwb_data
+    host: localhost
+    port: 3306
+    user: root
+    password: your_password
+    charset: utf8mb4
+  # MySQL 示例数据库配置
+  example_db:
+    host: localhost
+    port: 3306
+    user: example_user
+    password: example_password
+    charset: utf8mb4

+ 0 - 0
SourceCode/IntelligentRailwayCosting/app/core/__init__.py


+ 0 - 0
SourceCode/IntelligentRailwayCosting/app/core/user_session/__init__.py


+ 99 - 0
SourceCode/IntelligentRailwayCosting/app/core/user_session/current_user.py

@@ -0,0 +1,99 @@
+from dataclasses import dataclass
+from typing import Optional, List
+from flask_login import UserMixin
+
+@dataclass
+class CurrentUser(UserMixin):
+    """当前用户信息结构体"""
+    user_id: Optional[int] = None
+    username: Optional[str] = None
+    name: Optional[str] = None
+    roles: List[str] = None
+
+    def __post_init__(self):
+        """初始化角色列表"""
+        if self.roles is None:
+            self.roles = []
+
+    def get_id(self):
+        """实现Flask-Login要求的get_id方法"""
+        return str(self.user_id) if self.user_id else None
+
+    @property
+    def is_authenticated(self) -> bool:
+        """检查用户是否已认证
+
+        Returns:
+            bool: 如果用户已认证返回True,否则返回False
+        """
+        return self.user_id is not None and self.username is not None
+
+    @property
+    def is_super_admin(self) -> bool:
+        """检查用户是否为超级管理员
+
+        Returns:
+            bool: 如果用户是超级管理员返回True,否则返回False
+        """
+        return 'super_admin' in self.roles
+
+    @property
+    def is_admin(self) -> bool:
+        """检查用户是否为超级管理员
+
+        Returns:
+            bool: 如果用户是超级管理员返回True,否则返回False
+        """
+        return self.username == 'admin' or 'admin' in self.roles
+
+    @property
+    def is_sys(self) -> bool:
+        """检查用户是否为系统管理员
+
+        Returns:
+            bool: 如果用户是系统管理员返回True,否则返回False
+        """
+        return self.is_admin or 'sys' in self.roles
+
+    @property
+    def is_edit(self) -> bool:
+        """检查用户是否有编辑权限
+
+        Returns:
+            bool: 如果用户有编辑权限返回True,否则返回False
+        """
+        return self.is_admin or 'edit' in self.roles
+
+
+    def has_role(self, role: str) -> bool:
+        """检查用户是否拥有指定角色
+
+        Args:
+            role (str): 角色名称
+
+        Returns:
+            bool: 如果用户拥有指定角色返回True,否则返回False
+        """
+        return role in self.roles
+    
+    def has_any_role(self, roles: List[str]) -> bool:
+        """检查用户是否拥有指定角色列表中的任意一个角色
+
+        Args:
+            roles (List[str]): 角色名称列表
+
+        Returns:
+            bool: 如果用户拥有指定角色列表中的任意一个角色返回True,否则返回False
+        """
+        return any(role in self.roles for role in roles)
+    
+    def has_all_roles(self, roles: List[str]) -> bool:
+        """检查用户是否拥有指定角色列表中的所有角色
+
+        Args:
+            roles (List[str]): 角色名称列表
+
+        Returns:
+            bool: 如果用户拥有指定角色列表中的所有角色返回True,否则返回False
+        """
+        return all(role in self.roles for role in roles)

+ 124 - 0
SourceCode/IntelligentRailwayCosting/app/core/user_session/user_session.py

@@ -0,0 +1,124 @@
+from flask import session
+from typing import Optional
+from .current_user import CurrentUser
+
+
+class UserSession:
+    """用户会话管理类"""
+
+    @staticmethod
+    def set_user(user_id: int, username: str, name: str, roles: Optional[list] = None) -> None:
+        """设置用户登录状态
+
+        Args:
+            user_id (int): 用户ID
+            username (str): 用户名
+            name (str): 用户姓名
+            roles (Optional[list], optional): 用户角色列表. Defaults to None.
+        """
+        session['user_id'] = user_id
+        session['username'] = username
+        session['name'] = name
+        session['roles'] = roles or []
+
+    @staticmethod
+    def get_current_user() -> CurrentUser:
+        """获取当前登录用户信息
+
+        Returns:
+            CurrentUser: 返回当前用户信息结构体
+        """
+        return CurrentUser(
+            user_id=session.get('user_id'),
+            username=session.get('username'),
+            name=session.get('name'),
+            roles=session.get('roles', [])
+        )
+    
+    @staticmethod
+    def get_current_username() -> Optional[str]:
+        """获取当前登录用户名
+
+        Returns:
+            Optional[str]: 返回用户名,未登录则返回None
+        """
+        return session.get('username')
+    
+    @staticmethod
+    def get_current_user_id() -> Optional[int]:
+        """获取当前登录用户ID
+
+        Returns:
+            Optional[int]: 返回用户ID,未登录则返回None
+        """
+        return session.get('user_id')
+    
+    @staticmethod
+    def get_current_name() -> Optional[str]:
+        """获取当前登录用户姓名
+
+        Returns:
+            Optional[str]: 返回用户姓名,未登录则返回None
+        """
+        return session.get('name')
+
+    
+    @staticmethod
+    def clear_user() -> None:
+        """清除用户登录状态"""
+        session.pop('user_id', None)
+        session.pop('username', None)
+        session.pop('name', None)
+        session.pop('roles', None)
+    
+    @staticmethod
+    def is_logged_in() -> bool:
+        """检查用户是否已登录
+
+        Returns:
+            bool: 如果用户已登录返回True,否则返回False
+        """
+        return 'user_id' in session and 'username' in session
+    
+    @staticmethod
+    def is_admin() -> bool:
+        """检查当前用户是否为管理员
+
+        Returns:
+            bool: 如果当前用户是管理员返回True,否则返回False
+        """
+        return session.get('username') == 'admin'
+    
+    @staticmethod
+    def get_current_roles() -> list:
+        """获取当前用户角色列表
+
+        Returns:
+            list: 返回用户角色列表,未登录则返回空列表
+        """
+        return session.get('roles', [])
+
+    @staticmethod
+    def has_role(role: str) -> bool:
+        """检查当前用户是否拥有指定角色
+
+        Args:
+            role (str): 角色名称
+
+        Returns:
+            bool: 如果用户拥有指定角色返回True,否则返回False
+        """
+        return role in session.get('roles', [])
+
+    @staticmethod
+    def has_all_roles(roles: list) -> bool:
+        """检查当前用户是否拥有指定角色列表中的所有角色
+
+        Args:
+            roles (list): 角色名称列表
+
+        Returns:
+            bool: 如果用户拥有指定角色列表中的所有角色返回True,否则返回False
+        """
+        user_roles = session.get('roles', [])
+        return all(role in user_roles for role in roles)

+ 13 - 0
SourceCode/IntelligentRailwayCosting/app/main.py

@@ -0,0 +1,13 @@
+from services.user import UserService
+def main():
+    user_service = UserService()
+    user = user_service.get_user_by_id(1)
+    if user:
+        print(user.username)
+        print(user.password)
+
+    print(user)
+    pass
+
+if __name__ == '__main__':
+    main()

+ 0 - 0
SourceCode/IntelligentRailwayCosting/app/models/__init__.py


+ 43 - 0
SourceCode/IntelligentRailwayCosting/app/models/project.py

@@ -0,0 +1,43 @@
+from sqlalchemy import Column, String, Integer, Float, DateTime, Boolean, Text
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+
+class Project(Base):
+    __tablename__ = '项目信息'
+
+    project_id = Column('项目编号', String(30), primary_key=True)
+    compilation_method = Column('编制办法文号', String(50))
+    project_name = Column('建设项目名称', String(255))
+    short_name = Column('简称', String(10))
+    design_stage = Column('设计阶段', String(50))
+    compilation_scope = Column('编制范围', String(255))
+    total_engineering = Column('工程总量', Float)
+    unit = Column('单位', String(20))
+    project_manager = Column('项目负责人', String(20))
+    total_budget = Column('概算总值', Float)
+    budget_index = Column('概算指标', Float)
+    standard_quota = Column('标准定额应用', Text)
+    train_transport_standard = Column('火车运输标准', String(50))
+    project_version = Column('项目版本号', String(50))
+    create_time = Column('创建时间', DateTime)
+    material_library = Column('材料库', String(50))
+    work_shift_library = Column('台班库', String(50))
+    equipment_library = Column('设备库', String(50))
+    review_status = Column('审查状态', Integer)
+    years_to_construction = Column('编制年至开工年年限', Integer)
+    project_password = Column('项目密码', String(10))
+    railway_grade = Column('铁路等级', String(10))
+    main_line_count = Column('正线数目', Integer)
+    traction_type = Column('牵引种类', String(10))
+    blocking_mode = Column('闭塞方式', String(10))
+    station_count = Column('车站数量', String(50))
+    project_description = Column('项目简介', Text)
+    target_speed = Column('速度目标值', Integer)
+    print_compilation_review = Column('打印编制复核', Boolean)
+    project_type = Column('项目类型', String(20))
+    unit_conversion = Column('单位换算', Boolean)
+    completion_status = Column('完成状态', String(10))
+
+    def __repr__(self):
+        return f"<Project(project_id='{self.project_id}', project_name='{self.project_name}')>"

+ 16 - 0
SourceCode/IntelligentRailwayCosting/app/models/team.py

@@ -0,0 +1,16 @@
+from sqlalchemy import Column, String, Integer, Text
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+
+class Team(Base):
+    __tablename__ = '团队人员'
+
+    project_id = Column('项目编号', String(30), primary_key=True)
+    name = Column('姓名', String(50), primary_key=True)
+    operation_permission = Column('操作权限', String(2000))
+    item_number = Column('条目编号', Text)
+    compilation_status = Column('编制状态', Integer)
+
+    def __repr__(self):
+        return f"<Team(project_id='{self.project_id}', name='{self.name}')>"

+ 19 - 0
SourceCode/IntelligentRailwayCosting/app/models/user.py

@@ -0,0 +1,19 @@
+from sqlalchemy import Column, Integer, String, Text
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+
+class User(Base):
+    __tablename__ = '系统用户'
+
+    id = Column('序号', Integer, primary_key=True)
+    order_number = Column('顺号', Integer, nullable=False)
+    username = Column('用户名称', String(20))
+    password = Column('用户密码', String(20))
+    specialty = Column('专业名称', String(50))
+    auth_supplement_quota = Column('授权补充定额', Integer)
+    item_range_30 = Column('条目范围30', Text)
+    project_supplement = Column('项目补充', Integer)
+
+    def __repr__(self):
+        return f"<User(username='{self.username}', specialty='{self.specialty}')"

+ 0 - 0
SourceCode/IntelligentRailwayCosting/app/routes/__init__.py


+ 0 - 0
SourceCode/IntelligentRailwayCosting/app/services/__init__.py


+ 23 - 0
SourceCode/IntelligentRailwayCosting/app/services/user.py

@@ -0,0 +1,23 @@
+from stores.user import UserStore
+from models.user import User
+
+class UserService:
+    def __init__(self):
+        self.user_store = UserStore()
+
+    def get_user_by_id(self, user_id: int) -> User:
+        user = self.user_store.get_user_by_id(user_id)
+        return user
+
+    def get_user_by_username(self, username: str) -> User:
+        user = self.user_store.get_user_by_username(username)
+        return user
+
+    def authenticate_user(self, username: str, password: str) -> User:
+        user = self.user_store.authenticate_user(username, password)
+        return user
+
+    def get_all_users(self) -> list[User]:
+        users = self.user_store.get_all_users()
+        return users
+

+ 0 - 0
SourceCode/IntelligentRailwayCosting/app/stores/__init__.py


+ 33 - 0
SourceCode/IntelligentRailwayCosting/app/stores/user.py

@@ -0,0 +1,33 @@
+from sqlalchemy.orm import Session
+from typing import Optional, List
+from models.user import User
+from tools.db_helper.sqlserver_helper import SQLServerHelper
+
+class UserStore:
+    def __init__(self, session: Session = None):
+        if session is None:
+            session_maker = SQLServerHelper().get_session_maker("db.sqlserver_mian")
+            self.session = session_maker()
+        else:
+            self.session = session
+
+    def get_user_by_id(self, user_id: int) -> Optional[User]:
+        """根据用户ID获取用户信息"""
+        return self.session.query(User).filter(User.id == user_id).first()
+
+    def get_user_by_username(self, username: str) -> Optional[User]:
+        """根据用户名获取用户信息"""
+        return self.session.query(User).filter(User.username == username).first()
+
+    def get_all_users(self) -> List[User]:
+        """获取所有用户列表"""
+        return self.session.query(User).all()
+
+
+
+    def authenticate_user(self, username: str, password: str) -> Optional[User]:
+        """用户认证"""
+        user = self.get_user_by_username(username)
+        if user and user.password == password:
+            return user
+        return None

+ 0 - 0
SourceCode/IntelligentRailwayCosting/app/test/__init__.py


+ 140 - 0
SourceCode/IntelligentRailwayCosting/app/test/mysqy_test.py

@@ -0,0 +1,140 @@
+import unittest
+from tools.db_helper.mysql import MySQLHelper
+from tools.db_helper.base import DBHelper, Base
+from sqlalchemy import Column, Integer, String
+from typing import Optional, Dict
+
+# 定义测试用的模型类
+class TestUser(Base):
+    __tablename__ = 'test_users'
+    
+    id = Column(Integer, primary_key=True)
+    name = Column(String(50), nullable=False)
+    email = Column(String(100), unique=True)
+
+class TestMySQLHelper(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        """测试类初始化"""
+        cls.db_helper = MySQLHelper()
+        # 设置测试数据库配置
+        cls.test_config = {
+            'host': 'localhost',
+            'port': 3306,
+            'user': 'test_user',
+            'password': 'test_password',
+            'db': 'test_db',
+            'charset': 'utf8mb4'
+        }
+        cls.test_db = 'test_db'
+
+    def setUp(self):
+        """每个测试用例执行前的设置"""
+        # 创建测试表
+        engine = self.db_helper.get_engine(self.test_db, self.test_config)
+        Base.metadata.create_all(engine)
+
+    def tearDown(self):
+        """每个测试用例执行后的清理"""
+        # 删除测试表
+        engine = self.db_helper.get_engine(self.test_db, self.test_config)
+        Base.metadata.drop_all(engine)
+        self.db_helper.dispose_all()
+
+    def test_singleton(self):
+        """测试单例模式"""
+        mysql1 = MySQLHelper()
+        mysql2 = MySQLHelper()
+        self.assertIs(mysql1, mysql2)
+
+    def test_set_default_config(self):
+        """测试设置默认配置"""
+        test_config = {'host': 'test_host', 'port': 3307}
+        self.db_helper.set_default_config(test_config)
+        self.assertEqual(self.db_helper._default_config['host'], 'test_host')
+        self.assertEqual(self.db_helper._default_config['port'], 3307)
+
+    def test_get_config_for_database(self):
+        """测试获取数据库配置"""
+        try:
+            config = self.db_helper.get_config_for_database(self.test_db)
+            self.assertIsInstance(config, dict)
+        except Exception as e:
+            self.fail(f"获取数据库配置失败: {str(e)}")
+
+    def test_execute_query(self):
+        """测试查询操作"""
+        try:
+            # 创建测试表
+            create_table_sql = """
+            CREATE TABLE IF NOT EXISTS test_table (
+                id INT PRIMARY KEY AUTO_INCREMENT,
+                name VARCHAR(50) NOT NULL
+            )"""
+            self.db_helper.execute_non_query(self.test_db, create_table_sql)
+
+            # 插入测试数据
+            insert_sql = "INSERT INTO test_table (name) VALUES (%s)"
+            self.db_helper.execute_non_query(self.test_db, insert_sql, ('test_name',))
+
+            # 测试查询
+            query_sql = "SELECT * FROM test_table WHERE name = %s"
+            results = self.db_helper.execute_query(self.test_db, query_sql, ('test_name',))
+            self.assertTrue(len(results) > 0)
+            self.assertEqual(results[0][1], 'test_name')
+
+        except Exception as e:
+            self.fail(f"查询操作测试失败: {str(e)}")
+
+    def test_execute_scalar(self):
+        """测试标量查询"""
+        try:
+            # 创建测试表并插入数据
+            self.db_helper.execute_non_query(self.test_db, """
+                CREATE TABLE IF NOT EXISTS test_scalar (
+                    id INT PRIMARY KEY AUTO_INCREMENT,
+                    value INT NOT NULL
+                )
+            """)
+            self.db_helper.execute_non_query(self.test_db, 
+                "INSERT INTO test_scalar (value) VALUES (%s)", (42,))
+
+            # 测试标量查询
+            result = self.db_helper.execute_scalar(self.test_db, 
+                "SELECT value FROM test_scalar WHERE id = 1")
+            self.assertEqual(result, 42)
+
+        except Exception as e:
+            self.fail(f"标量查询测试失败: {str(e)}")
+
+    def test_session_scope(self):
+        """测试会话作用域和事务管理"""
+        try:
+            # 测试成功的事务
+            with self.db_helper.session_scope(self.test_db, self.test_config) as session:
+                user = TestUser(name='test_user', email='test@example.com')
+                session.add(user)
+
+            # 验证数据已保存
+            with self.db_helper.session_scope(self.test_db, self.test_config) as session:
+                saved_user = session.query(TestUser).filter_by(name='test_user').first()
+                self.assertIsNotNone(saved_user)
+                self.assertEqual(saved_user.email, 'test@example.com')
+
+            # 测试事务回滚
+            with self.assertRaises(Exception):
+                with self.db_helper.session_scope(self.test_db, self.test_config) as session:
+                    user = TestUser(name='rollback_user', email='invalid_email')
+                    session.add(user)
+                    raise Exception("测试回滚")
+
+            # 验证数据已回滚
+            with self.db_helper.session_scope(self.test_db, self.test_config) as session:
+                rollback_user = session.query(TestUser).filter_by(name='rollback_user').first()
+                self.assertIsNone(rollback_user)
+
+        except Exception as e:
+            self.fail(f"会话作用域测试失败: {str(e)}")
+
+if __name__ == '__main__':
+    unittest.main()

+ 130 - 0
SourceCode/IntelligentRailwayCosting/app/test/sqlserver_test.py

@@ -0,0 +1,130 @@
+import unittest
+from tools.db_helper.sqlserver_helper import SQLServerHelper
+from tools.db_helper.base import Base
+from sqlalchemy import Column, Integer, String
+from typing import Dict, Any
+import time
+
+class TestTable(Base):
+    __tablename__ = 'test_table'
+    id = Column(Integer, primary_key=True)
+    name = Column(String(50), nullable=False)
+
+class TestSQLServerHelper(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        """测试类初始化,创建数据库帮助类实例"""
+        cls.db_helper = SQLServerHelper()
+        cls.database = 'Iwb_RecoData2024'  # 使用配置文件中定义的测试数据库
+
+    def setUp(self):
+        """每个测试用例开始前的准备工作"""
+        self.start_time = time.time()
+        # 创建测试表
+        engine = self.db_helper.get_engine(self.database)
+        Base.metadata.create_all(engine)
+
+    def tearDown(self):
+        """每个测试用例结束后的清理工作"""
+        # 删除测试表
+        engine = self.db_helper.get_engine(self.database)
+        Base.metadata.drop_all(engine)
+        test_duration = time.time() - self.start_time
+        print(f"\n测试用例耗时: {test_duration:.3f}秒")
+
+    @classmethod
+    def tearDownClass(cls):
+        """测试类结束时的清理工作"""
+        cls.db_helper.dispose_all()
+
+    def test_singleton(self):
+        """测试单例模式"""
+        helper1 = SQLServerHelper()
+        helper2 = SQLServerHelper()
+        self.assertIs(helper1, helper2)
+
+    def test_database_connection(self):
+        """测试数据库连接"""
+        try:
+            # 测试获取数据库引擎
+            engine = self.db_helper.get_engine(self.database)
+            self.assertIsNotNone(engine, "数据库引擎创建失败")
+
+            # 测试会话作用域
+            with self.db_helper.session_scope(self.database) as session:
+                self.assertIsNotNone(session, "数据库会话创建失败")
+                # 测试简单查询
+                result = session.execute("SELECT 1").scalar()
+                self.assertEqual(result, 1, "数据库连接测试失败")
+
+        except Exception as e:
+            self.fail(f"数据库连接测试失败: {str(e)}")
+
+    def test_basic_operations(self):
+        """测试基本数据库操作"""
+        try:
+            # 测试查询操作
+            query_result = self.db_helper.execute_query(self.database, 'SELECT @@VERSION')
+            self.assertIsNotNone(query_result, "查询操作失败")
+            self.assertTrue(len(query_result) > 0, "查询结果为空")
+
+            # 测试标量查询
+            scalar_result = self.db_helper.execute_scalar(self.database, 'SELECT DB_NAME()')
+            self.assertIsNotNone(scalar_result, "标量查询失败")
+            self.assertEqual(scalar_result, self.database, "数据库名称不匹配")
+
+            # 测试非查询操作
+            # 创建临时表并插入数据
+            self.db_helper.execute_non_query(
+                self.database,
+                "CREATE TABLE #temp_test (id INT, name NVARCHAR(50))"
+            )
+            insert_result = self.db_helper.execute_non_query(
+                self.database,
+                "INSERT INTO #temp_test (id, name) VALUES (:id, :name)",
+                {"id": 1, "name": "test"}
+            )
+            self.assertEqual(insert_result, 1, "插入操作失败")
+
+            # 验证插入结果
+            result = self.db_helper.execute_scalar(
+                self.database,
+                "SELECT name FROM #temp_test WHERE id = 1"
+            )
+            self.assertEqual(result, "test", "数据验证失败")
+
+        except Exception as e:
+            self.fail(f"基本操作测试失败: {str(e)}")
+
+    def test_session_management(self):
+        """测试会话管理和事务"""
+        try:
+            # 测试事务回滚
+            with self.assertRaises(Exception):
+                with self.db_helper.session_scope(self.database) as session:
+                    test_record = TestTable(name="test_rollback")
+                    session.add(test_record)
+                    raise Exception("触发回滚")
+
+            # 验证回滚成功
+            with self.db_helper.session_scope(self.database) as session:
+                result = session.query(TestTable).filter_by(name="test_rollback").first()
+                self.assertIsNone(result, "事务回滚失败")
+
+            # 测试正常事务提交
+            with self.db_helper.session_scope(self.database) as session:
+                test_record = TestTable(name="test_commit")
+                session.add(test_record)
+
+            # 验证提交成功
+            with self.db_helper.session_scope(self.database) as session:
+                result = session.query(TestTable).filter_by(name="test_commit").first()
+                self.assertIsNotNone(result, "事务提交失败")
+                self.assertEqual(result.name, "test_commit")
+
+        except Exception as e:
+            self.fail(f"会话管理测试失败: {str(e)}")
+
+if __name__ == '__main__':
+    unittest.main()
+

+ 0 - 0
SourceCode/IntelligentRailwayCosting/app/tools/__init__.py


+ 0 - 0
SourceCode/IntelligentRailwayCosting/app/tools/db_helper/__init__.py


+ 196 - 0
SourceCode/IntelligentRailwayCosting/app/tools/db_helper/base.py

@@ -0,0 +1,196 @@
+from typing import Dict, Optional, Any, List, Tuple, Generator
+from contextlib import contextmanager
+import threading
+import tools.utils as utils
+from sqlalchemy.orm import sessionmaker, declarative_base
+
+# 创建基础模型类
+Base = declarative_base()
+
+class DBHelper:
+    _instance = None
+    _lock = threading.Lock()
+    _main_config_key = ""
+    def __new__(cls, *args, **kwargs):
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super(DBHelper, cls).__new__(cls)
+                cls._instance._initialized = False
+            return cls._instance
+    
+    def __init__(self):
+        if self._initialized:
+            return
+            
+        self._config_cache: Dict[str, Dict[str, str]] = {}
+        self._engines: Dict[str, Any] = {}
+        self._sessions: Dict[str, sessionmaker] = {}
+        self._default_config = {
+            'pool_size': 5,
+            'max_overflow': 10,
+            'pool_timeout': 30,
+            'pool_recycle': 3600
+        }
+        self._initialized = True
+        
+    def set_default_config(self, config: Dict[str, str]) -> None:
+        """设置默认连接配置"""
+        self._default_config.update(config)
+    
+    def get_config_for_database(self, database: str) -> Dict[str, str]:
+        """获取数据库的连接配置
+        
+        按照以下顺序查找配置:
+        1. 配置缓存
+        2. 配置文件中特定数据库的配置
+        3. 配置文件中的main_config配置
+        4. 默认配置
+        
+        Args:
+            database: 数据库名称
+            
+        Returns:
+            数据库连接配置
+            
+        Raises:
+            Exception: 当找不到特定数据库配置且main_config配置也不存在时
+        """
+        if database in self._config_cache:
+            return self._config_cache[database]
+        
+        db_config = utils.get_config_object(f"db.{database}")
+        if db_config:
+            self._config_cache[database] = db_config
+            return db_config
+        
+        main_config = utils.get_config_object(self._main_config_key)
+        if not main_config:
+            raise Exception(f"未找到数据库 {database} 的配置,且main_config配置不存在")
+        
+        self._config_cache[database] = main_config
+        return main_config
+    
+    def execute_query(self, database: str, query: str, params: Optional[Any] = None) -> List[Tuple]:
+        """执行查询并返回结果
+        
+        Args:
+            database: 数据库名称
+            query: SQL查询语句
+            params: 查询参数
+            
+        Returns:
+            查询结果列表
+        """
+        raise NotImplementedError("子类必须实现execute_query方法")
+    
+    def execute_non_query(self, database: str, query: str, params: Optional[Any] = None) -> int:
+        """执行非查询操作(如INSERT, UPDATE, DELETE)
+        
+        Args:
+            database: 数据库名称
+            query: SQL语句
+            params: 查询参数
+            
+        Returns:
+            受影响的行数
+        """
+        raise NotImplementedError("子类必须实现execute_non_query方法")
+    
+    def execute_scalar(self, database: str, query: str, params: Optional[Any] = None) -> Any:
+        """执行查询并返回第一行第一列的值
+        
+        Args:
+            database: 数据库名称
+            query: SQL查询语句
+            params: 查询参数
+            
+        Returns:
+            查询结果的第一行第一列的值
+        """
+        raise NotImplementedError("子类必须实现execute_scalar方法")
+    
+    def execute_procedure(self, database: str, procedure_name: str, params: Optional[Dict[str, Any]] = None) -> List[Tuple]:
+        """执行存储过程
+        
+        Args:
+            database: 数据库名称
+            procedure_name: 存储过程名称
+            params: 存储过程参数
+            
+        Returns:
+            存储过程执行结果
+        """
+        raise NotImplementedError("子类必须实现execute_procedure方法")
+    
+    def get_engine(self, database: str, config: Optional[Dict[str, Any]] = None) -> Any:
+        """获取或创建数据库引擎
+        
+        Args:
+            database: 数据库名称
+            config: 数据库配置信息
+            
+        Returns:
+            SQLAlchemy引擎实例
+        """
+        raise NotImplementedError("子类必须实现get_engine方法")
+    
+    def get_session_maker(self, database: str, config: Optional[Dict[str, Any]] = None) -> sessionmaker:
+        """获取或创建会话工厂
+        
+        Args:
+            database: 数据库名称
+            config: 数据库配置信息
+            
+        Returns:
+            会话工厂实例
+        """
+        if database in self._sessions:
+            return self._sessions[database]
+        
+        engine = self.get_engine(database, config)
+        session = sessionmaker(bind=engine)
+        self._sessions[database] = session
+        return session
+    
+    @contextmanager
+    def session_scope(self, database: str, config: Optional[Dict[str, Any]] = None) -> Generator:
+        """创建会话的上下文管理器
+        
+        Args:
+            database: 数据库名称
+            config: 数据库配置信息
+            
+        Yields:
+            数据库会话
+        """
+        session = self.get_session_maker(database, config)
+        session = session()
+        try:
+            yield session
+            session.commit()
+        except:
+            session.rollback()
+            raise
+        finally:
+            session.close()
+    
+    def dispose_engine(self, database: str) -> None:
+        """释放指定数据库的引擎资源
+        
+        Args:
+            database: 数据库名称
+        """
+        if database in self._engines:
+            self._engines[database].dispose()
+            del self._engines[database]
+        if database in self._sessions:
+            del self._sessions[database]
+    
+    def dispose_all(self) -> None:
+        """释放所有数据库资源"""
+        for database in list(self._engines.keys()):
+            self.dispose_engine(database)
+    
+    def __del__(self):
+        """析构函数,确保所有资源被释放"""
+        self.dispose_all()

+ 178 - 0
SourceCode/IntelligentRailwayCosting/app/tools/db_helper/mysql_helper.py

@@ -0,0 +1,178 @@
+import pymysql
+from typing import Dict, Optional, Any, List, Tuple
+from contextlib import contextmanager
+from sqlalchemy import create_engine
+from sqlalchemy.engine import Engine
+from .base import DBHelper
+from pymysql import Error
+
+class MySQLHelper(DBHelper):
+    def __init__(self):
+        super().__init__()
+        self._connections: Dict[str, pymysql.Connection] = {}
+        self._default_config = {
+            'db': '',
+            'host': 'localhost',
+            'port': 3306,
+            'user': '',
+            'password': '',
+            'charset': 'utf8mb4'
+        }
+        self._main_config_key = "db.mysql_main"
+
+    def get_engine(self, database: str, config: Optional[Dict[str, Any]] = None) -> Engine:
+        """获取或创建数据库引擎
+        
+        Args:
+            database: 数据库名称
+            config: 可选的连接配置
+            
+        Returns:
+            SQLAlchemy引擎实例
+        """
+        if database in self._engines:
+            return self._engines[database]
+        
+        conn_config = self._default_config.copy()
+        db_config = self.get_config_for_database(database)
+        conn_config.update(db_config)
+        if config:
+            conn_config.update(config)
+        
+        if 'db' not in conn_config or not conn_config['db']:
+            conn_config['db'] = database
+        
+        url = f"mysql+pymysql://{conn_config['user']}:{conn_config['password']}@{conn_config['host']}:{conn_config['port']}/{conn_config['db']}"
+        self._engines[database] = create_engine(
+            url,
+            pool_size=self._default_config['pool_size'],
+            max_overflow=self._default_config['max_overflow'],
+            pool_timeout=self._default_config['pool_timeout'],
+            pool_recycle=self._default_config['pool_recycle']
+        )
+        return self._engines[database]
+    
+    def connect(self, database: str, config: Optional[Dict[str, str]] = None) -> pymysql.Connection:
+        """连接到指定数据库
+        
+        Args:
+            database: 数据库名称
+            config: 可选的连接配置,如果不提供则使用默认配置
+            
+        Returns:
+            数据库连接对象
+        """
+        if database in self._connections:
+            try:
+                self._connections[database].ping(reconnect=True)
+                return self._connections[database]
+            except Error:
+                try:
+                    self._connections[database].close()
+                except Error:
+                    pass
+                del self._connections[database]
+        
+        conn_config = self._default_config.copy()
+        db_config = self.get_config_for_database(database)
+        conn_config.update(db_config)
+        if config:
+            conn_config.update(config)
+        
+        if 'db' not in conn_config or not conn_config['db']:
+            conn_config['db'] = database
+        
+        connection = pymysql.connect(**conn_config)
+        self._connections[database] = connection
+        return connection
+    
+    @contextmanager
+    def connection(self, database: str):
+        """获取数据库连接的上下文管理器
+        
+        Args:
+            database: 数据库名称
+            
+        Yields:
+            数据库连接对象
+        """
+        connection = self.connect(database)
+        try:
+            yield connection
+        finally:
+            pass
+    
+    def execute_query(self, database: str, query: str, params: Optional[Tuple] = None) -> List[Tuple]:
+        connection = self.connect(database)
+        cursor = connection.cursor()
+        
+        try:
+            if params:
+                cursor.execute(query, params)
+            else:
+                cursor.execute(query)
+            
+            results = cursor.fetchall()
+            return [tuple(row) for row in results]
+        finally:
+            cursor.close()
+    
+    def execute_non_query(self, database: str, query: str, params: Optional[Tuple] = None) -> int:
+        connection = self.connect(database)
+        cursor = connection.cursor()
+        
+        try:
+            if params:
+                cursor.execute(query, params)
+            else:
+                cursor.execute(query)
+            
+            connection.commit()
+            return cursor.rowcount
+        finally:
+            cursor.close()
+    
+    def execute_scalar(self, database: str, query: str, params: Optional[Tuple] = None) -> Any:
+        connection = self.connect(database)
+        cursor = connection.cursor()
+        
+        try:
+            if params:
+                cursor.execute(query, params)
+            else:
+                cursor.execute(query)
+            
+            row = cursor.fetchone()
+            return row[0] if row else None
+        finally:
+            cursor.close()
+    
+    def execute_procedure(self, database: str, procedure_name: str, params: Optional[Dict[str, Any]] = None) -> List[Tuple]:
+        connection = self.connect(database)
+        cursor = connection.cursor()
+        
+        try:
+            if params:
+                param_placeholders = ", ".join([f"%s" for _ in params.keys()])
+                query = f"CALL {procedure_name}({param_placeholders})"
+                cursor.execute(query, list(params.values()))
+            else:
+                cursor.execute(f"CALL {procedure_name}()")
+            
+            results = cursor.fetchall()
+            return [tuple(row) for row in results]
+        finally:
+            cursor.close()
+    
+    def dispose_all(self) -> None:
+        """释放所有数据库连接"""
+        for conn in self._connections.values():
+            try:
+                conn.close()
+            except Error:
+                pass
+        self._connections.clear()
+    
+    def __del__(self):
+        """析构函数,确保所有连接被关闭"""
+        self.dispose_all()

+ 107 - 0
SourceCode/IntelligentRailwayCosting/app/tools/db_helper/sqlserver_helper.py

@@ -0,0 +1,107 @@
+from typing import Dict, Optional, Any, List, Tuple
+
+from sqlalchemy import create_engine, text
+from sqlalchemy.engine import Engine
+from sqlalchemy.orm import sessionmaker
+
+from .base import DBHelper
+
+
+class SQLServerHelper(DBHelper):
+    def __init__(self):
+        super().__init__()
+        self._engines: Dict[str, Engine] = {}
+        self._session_makers: Dict[str, sessionmaker] = {}
+        self._default_config = {
+            'driver': 'ODBC Driver 17 for SQL Server',
+            'server': 'localhost',
+            'username': '',
+            'password': '',
+            'trusted_connection': 'yes'
+        }
+        self._main_config_key = "db.sqlserver_mian"
+
+    def _build_connection_string(self, database: str, config: Optional[Dict[str, str]] = None) -> str:
+        """构建连接字符串"""
+        conn_config = self._default_config.copy()
+        db_config = self.get_config_for_database(database)
+        conn_config.update(db_config)
+        if config:
+            conn_config.update(config)
+
+        # 构建认证字符串
+        auth_params = []
+        if conn_config.get('trusted_connection', True):
+            auth_params.append("Trusted_Connection=yes")
+        else:
+            auth_params.extend([
+                f"UID={conn_config['username']}",
+                f"PWD={conn_config['password']}"
+            ])
+
+        # 构建ODBC连接字符串
+        conn_parts = [
+            f"DRIVER={conn_config['driver']}",
+            f"SERVER={conn_config['server']}",
+            f"DATABASE={conn_config['database'] if 'database' in conn_config else database}"
+        ]
+        conn_parts.extend(auth_params)
+        
+        # 构建SQLAlchemy连接URL
+        conn_str = ";".join(conn_parts)
+        conn_url = f"mssql+pyodbc:///?odbc_connect={conn_str}"
+
+        return conn_url
+
+    def get_engine(self, database: str, config: Optional[Dict[str, str]] = None) -> Engine:
+        """获取或创建数据库引擎"""
+        if database not in self._engines:
+            conn_str = self._build_connection_string(database, config)
+            self._engines[database] = create_engine(
+                conn_str,
+                pool_size=5,
+                max_overflow=10,
+                pool_timeout=30,
+                pool_recycle=1800
+            )
+        return self._engines[database]
+
+    def execute_query(self, database: str, query: str, params: Optional[Dict[str, Any]] = None) -> List[Tuple]:
+        """执行查询并返回结果"""
+        with self.session_scope(database) as session:
+            result = session.execute(text(query), params or {})
+            return [tuple(row) for row in result.fetchall()]
+
+    def execute_non_query(self, database: str, query: str, params: Optional[Dict[str, Any]] = None) -> int:
+        """执行非查询操作(如INSERT, UPDATE, DELETE)"""
+        with self.session_scope(database) as session:
+            result = session.execute(text(query), params or {})
+            return result.rowcount
+
+    def execute_scalar(self, database: str, query: str, params: Optional[Dict[str, Any]] = None) -> Any:
+        """执行查询并返回第一行第一列的值"""
+        with self.session_scope(database) as session:
+            result = session.execute(text(query), params or {})
+            row = result.fetchone()
+            return row[0] if row else None
+
+    def execute_procedure(self, database: str, procedure_name: str, params: Optional[Dict[str, Any]] = None) -> List[Tuple]:
+        """执行存储过程"""
+        params = params or {}
+        param_str = ", ".join([f"@{key}=:{key}" for key in params.keys()])
+        query = f"EXEC {procedure_name} {param_str}"
+        
+        with self.session_scope(database) as session:
+            result = session.execute(text(query), params)
+            return [tuple(row) for row in result.fetchall()]
+
+    def dispose_all(self) -> None:
+        """释放所有数据库引擎资源"""
+        for engine in self._engines.values():
+            engine.dispose()
+        self._engines.clear()
+        self._session_makers.clear()
+
+    def __del__(self):
+        """析构函数,确保所有引擎资源被释放"""
+        self.dispose_all()

+ 168 - 0
SourceCode/IntelligentRailwayCosting/app/tools/utils/__init__.py

@@ -0,0 +1,168 @@
+import json
+
+from tools.utils.ai_helper import AiHelper
+from tools.utils.config_helper import ConfigHelper
+from tools.utils.file_helper import FileHelper
+from tools.utils.logger_helper import LoggerHelper
+from tools.utils.string_helper import StringHelper
+
+def get_logger():
+    """
+    获取日志记录器实例。
+
+    该函数通过调用LoggerHelper类的静态方法get_logger()来获取一个日志记录器实例。
+    主要用于需要记录日志的位置,通过该函数获取日志记录器实例,然后进行日志记录。
+    这样做可以保持日志记录的一致性和集中管理。
+
+    :return: Logger实例,用于记录日志。
+    """
+    return LoggerHelper.get_logger()
+
+def clean_log_file(day: int):
+    """
+    清理指定天数之前的日志文件。
+
+    :param day: 整数,表示清理多少天前的日志文件。
+    """
+    LoggerHelper.clean_log_file(day)
+
+def get_config():
+    """
+    获取配置管理器实例。
+
+    该函数返回一个ConfigHelper实例,用于读取和管理应用程序的配置信息。
+
+    :return: ConfigHelper实例,用于配置管理。
+    """
+    return ConfigHelper()
+
+def reload_config():
+    """
+    重新加载配置文件。
+
+    该函数会重新加载配置文件中的内容,适用于配置文件发生更改后需要重新加载的情况。
+    """
+    get_config().load_config()
+
+def get_config_value(key: str, default: str = None):
+    """
+    获取配置项的值。
+
+    :param key: 字符串,配置项的键。
+    :param default: 字符串,默认值(可选)。
+    :return: 配置项的值,如果不存在则返回默认值。
+    """
+    return get_config().get(key, default)
+
+def get_config_int(key: str, default: int = None):
+    """
+    获取配置项的整数值。
+
+    :param key: 字符串,配置项的键。
+    :param default: 整数,默认值(可选)。
+    :return: 配置项的整数值,如果不存在则返回默认值。
+    """
+    return get_config().get_int(key, default)
+
+def get_config_object(key: str, default: dict = None):
+    """
+    获取配置项的JSON对象。
+
+    :param key: 字符串,配置项的键。
+    :param default: 字典,默认值(可选)。
+    :return: 字典,配置项的JSON对象。
+    """
+    return get_config().get_object(key, default)
+def get_config_bool(key: str):
+    """
+    获取配置项的布尔值。
+
+    :param key: 字符串,配置项的键。
+    :return: 配置项的布尔值。
+    """
+    return get_config().get_bool(key)
+
+def download_remote_file(file_url: str, file_name: str) -> str:
+    """
+    下载远程文件并保存到本地。
+
+    :param file_url: 字符串,远程文件的URL。
+    :param file_name: 字符串,保存到本地的文件名。
+    :return: 字符串,下载后的文件路径。
+    """
+    return FileHelper().download_remote_file(file_url, file_name)
+
+def clean_attach_file(day: int):
+    """
+    清理指定天数之前的附件文件。
+
+    :param day: 整数,表示清理多少天前的附件文件。
+    """
+    FileHelper().clean_attach_file(day)
+
+def save_report_excel(data, file_name: str = None) -> str:
+    """
+    保存报表数据到Excel文件。
+
+    :param data: 列表,报表数据。
+    :param file_name: 字符串,保存的文件名(可选)。
+    :return: 字符串,保存的文件路径。
+    """
+    return FileHelper().save_report_excel(data, file_name)
+
+def clean_report_file(day: int):
+    """
+    清理指定天数之前的报表文件。
+
+    :param day: 整数,表示清理多少天前的报表文件。
+    """
+    FileHelper().clean_report_file(day)
+
+def encode_file(path: str):
+    return FileHelper.encode_file(path)
+
+def to_array(s: str, split: str = ",") -> list[str]:
+    """
+    将字符串按指定分隔符拆分为数组。
+
+    :param s: 字符串,待拆分的字符串。
+    :param split: 字符串,分隔符。
+    :return: 列表,拆分后的数组。
+    """
+    return StringHelper.to_array(s, split)
+
+def to_str(data:dict|list|tuple):
+    """
+    将对象转成字符串
+    :param data:
+    :return:
+    """
+    return StringHelper.to_str(data)
+
+def is_email(email: str) -> bool:
+    """
+    判断字符串是否为有效的电子邮件地址。
+
+    :param email: 字符串,待判断的电子邮件地址。
+    :return: 布尔值,是否为有效的电子邮件地址。
+    """
+    return StringHelper.is_email(email)
+
+def is_phone(phone: str) -> bool:
+    """
+    判断字符串是否为有效的手机号码。
+
+    :param phone: 字符串,待判断的手机号码。
+    :return: 布尔值,是否为有效的手机号码。
+    """
+    return StringHelper.is_phone(phone)
+
+def call_openai(system_prompt: str, user_prompt: str) -> json:
+    """
+    调用OpenAI API进行对话。
+
+    :param system_prompt: 字符串,系统提示信息。
+    :param user_prompt: 字符串,用户输入的提示信息。
+    :return: JSON对象,API返回的结果。
+    """
+    return AiHelper().call_openai(system_prompt, user_prompt)

+ 168 - 0
SourceCode/IntelligentRailwayCosting/app/tools/utils/ai_helper.py

@@ -0,0 +1,168 @@
+import json, re
+
+import tools.utils as  utils
+from openai import OpenAI
+from pathlib import Path
+
+class AiHelper:
+
+    _ai_api_key = None
+    _ai_api_url = None
+    _ai_max_tokens = 150
+
+    def __init__(self, api_url: str=None, api_key: str=None, api_model: str=None):
+        self._ai_api_url = api_url if api_url else utils.get_config_value("ai.url")
+        self._ai_api_key = api_key if api_key else utils.get_config_value("ai.key")
+        self._api_model = api_model if api_model else utils.get_config_value("ai.model")
+        max_tokens = utils.get_config_value("ai.max_tokens")
+        if max_tokens:
+            self._ai_max_tokens = int(max_tokens)
+
+    def call_openai(self, system_prompt: str, user_prompt: str,api_url: str=None,api_key: str=None,api_model: str=None) -> json:
+        self.check_api(api_key, api_model, api_url)
+        utils.get_logger().info(f"调用AI API ==> Url:{self._ai_api_url},Model:{self._api_model}")
+
+        client = OpenAI(api_key=self._ai_api_key, base_url=self._ai_api_url)
+        completion = client.chat.completions.create(
+            model=self._api_model,
+            messages=[
+                {
+                    "role": "system",
+                    "content": system_prompt,
+                },
+                {
+                    "role": "user",
+                    "content": user_prompt,
+                },
+            ],
+            stream=False,
+            temperature=0.7,
+            response_format={"type": "json_object"},
+            # max_tokens=self._ai_max_tokens,
+        )
+        try:
+            response = completion.model_dump_json()
+            result = {}
+            response_json = json.loads(response)
+            res_str = self._extract_message_content(response_json)
+            result_data = self._parse_response(res_str, True)
+            if result_data:
+                result["data"] = result_data
+                usage = response_json["usage"]
+                result["completion_tokens"] = usage.get("completion_tokens", 0)
+                result["prompt_tokens"] = usage.get("prompt_tokens", 0)
+                result["total_tokens"] = usage.get("total_tokens", 0)
+                utils.get_logger().info(f"AI Process JSON: {result}")
+            else:
+                utils.get_logger().info(f"AI Response: {response}")
+            return result
+        except Exception as e:
+            raise Exception(f"解析 AI 响应错误: {e}")
+
+    def check_api(self, api_key, api_model, api_url):
+        if api_url:
+            self._ai_api_url = api_url
+        if api_key:
+            self._ai_api_key = api_key
+        if api_model:
+            self._api_model = api_model
+        if self._ai_api_key is None:
+            raise Exception("AI API key 没有配置")
+        if self._ai_api_url is None:
+            raise Exception("AI API url 没有配置")
+        if self._api_model is None:
+            raise Exception("AI API model 没有配置")
+
+    @staticmethod
+    def _extract_message_content(response_json: dict) -> str:
+        utils.get_logger().info(f"AI Response JSON: {response_json}")
+        if "choices" in response_json and len(response_json["choices"]) > 0:
+            choice = response_json["choices"][0]
+            message_content = choice.get("message", {}).get("content", "")
+        elif "message" in response_json:
+            message_content = response_json["message"].get("content", "")
+        else:
+            raise Exception("AI 响应中未找到有效的 choices 或 message 数据")
+
+        # 移除多余的 ```json 和 ```
+        if message_content.startswith("```json") and message_content.endswith(
+                "```"):
+            message_content = message_content[6:-3]
+
+        # 去除开头的 'n' 字符
+        if message_content.startswith("n"):
+            message_content = message_content[1:]
+        # 移除无效的转义字符和时间戳前缀
+        message_content = re.sub(r"\\[0-9]{2}", "",
+                                 message_content)  # 移除 \32 等无效转义字符
+        message_content = re.sub(r"\d{4}-\d{2}-\dT\d{2}:\d{2}:\d{2}\.\d+Z", "",
+                                 message_content)  # 移除时间戳
+        message_content = message_content.strip()  # 去除首尾空白字符
+
+        # 替换所有的反斜杠
+        message_content = message_content.replace("\\", "")
+
+        return message_content
+
+    def _parse_response(self, response: str, first=True) -> json:
+        # utils.get_logger().info(f"AI Response JSON STR: {response}")
+        try:
+            data = json.loads(response)
+            return data
+
+        except json.JSONDecodeError as e:
+            if first:
+                utils.get_logger().error(f"JSON 解析错误,去除部分特殊字符重新解析一次: {e}")
+                # 替换中文引号为空
+                message_content = re.sub(r"[“”]", "", response)  # 替换双引号
+                message_content = re.sub(r"[‘’]", "", message_content)  # 替换单引号
+                return self._parse_response(message_content, False)
+            else:
+                raise Exception(f"解析 AI 响应错误: {response} {e}")
+
+    def call_openai_with_image(self, image_path,system_prompt: str, user_prompt: str, api_url: str=None,api_key: str=None,api_model: str=None) -> json:
+        pass
+
+    def call_openai_with_file(self, file_path,system_prompt: str, user_prompt: str, api_url: str=None,api_key: str=None,api_model: str=None)->json:
+        self.check_api(api_key, api_model, api_url)
+        utils.get_logger().info(f"调用AI API File==> Url:{self._ai_api_url},Model:{self._api_model}")
+
+        client = OpenAI(api_key=self._ai_api_key, base_url=self._ai_api_url)
+        file_object = client.files.create(    file=Path(file_path),purpose='file-extract',)
+        completion = client.chat.completions.create(
+            model=self._api_model,
+            messages=[
+                {
+                    "role": "system",
+                    # "content": system_prompt,
+                    'content': f'fileid://{file_object.id}'
+                },
+                {
+                    "role": "user",
+                    "content": user_prompt,
+                },
+            ],
+            stream=False,
+            temperature=0.7,
+            response_format={"type": "json_object"},
+            # max_tokens=self._ai_max_tokens,
+        )
+        try:
+            response = completion.model_dump_json()
+            result = {}
+            response_json = json.loads(response)
+            res_str = self._extract_message_content(response_json)
+            result_data = self._parse_response(res_str, True)
+            if result_data:
+                result["data"] = result_data
+                usage = response_json["usage"]
+                result["completion_tokens"] = usage.get("completion_tokens", 0)
+                result["prompt_tokens"] = usage.get("prompt_tokens", 0)
+                result["total_tokens"] = usage.get("total_tokens", 0)
+                utils.get_logger().info(f"AI Process JSON: {result}")
+            else:
+                utils.get_logger().info(f"AI Response: {response}")
+            return result
+        except Exception as e:
+            raise Exception(f"解析 AI 响应错误: {e}")
+        pass

+ 91 - 0
SourceCode/IntelligentRailwayCosting/app/tools/utils/config_helper.py

@@ -0,0 +1,91 @@
+import os, yaml
+
+class ConfigHelper:
+    _instance = None
+
+    # 默认配置文件路径
+    default_config_path = os.path.join(os.path.dirname(__file__), "..\..", "config.yml")
+
+    # 类变量存储加载的配置
+    _config = None
+    _path = None
+
+    def __new__(cls, *args, **kwargs):
+        if not cls._instance:
+            cls._instance = super(ConfigHelper, cls).__new__(cls)
+        return cls._instance
+
+    def load_config(self, path=None):
+        if self._config is None:
+            if not path:
+                # 从环境变量中获取配置路径
+                path = os.environ.get("CONFIG_PATH", self.default_config_path)
+                self._path = path
+            else:
+                self._path = path
+            if not os.path.exists(self._path):
+                raise FileNotFoundError(f"没有找到配置文件或目录:'{self._path}'")
+        with open(self._path, "r", encoding="utf-8") as file:
+            self._config = yaml.safe_load(file)
+        # 合并环境变量配置
+        self._merge_env_vars()
+        # print(f"加载的配置文件内容:{self._config}")
+        return self._config
+
+    def _merge_env_vars(self, env_prefix="APP_"):  # 环境变量前缀为 APP_
+        for key, value in os.environ.items():
+            if key.startswith(env_prefix):
+                config_key = key[len(env_prefix) :].lower()
+                self._set_nested_key(self._config, config_key.split("__"), value)
+
+    def _set_nested_key(self, config, keys, value):
+        if len(keys) > 1:
+            if keys[0] not in config or not isinstance(config[keys[0]], dict):
+                config[keys[0]] = {}
+            self._set_nested_key(config[keys[0]], keys[1:], value)
+        else:
+            config[keys[0]] = value
+
+    def get(self, key: str, default: str = None):
+        if self._config is None:
+            self.load_config(self._path)
+        keys = key.split(".")
+        config = self._config
+        for k in keys:
+            if isinstance(config, dict) and k in config:
+                config = config[k]
+            else:
+                return default
+        return config
+
+    def get_object(self, key: str, default: dict = None):
+        val = self.get(key)
+        if not val:
+            return default
+        if isinstance(val, dict):
+            return val
+        try:
+            return yaml.safe_load(val)
+        except yaml.YAMLError as e:
+            print(f"Error loading YAML object: {e}")
+            return default
+    def get_bool(self, key: str) -> bool:
+        val = self.get(key, "0")
+        if isinstance(val, bool):
+            return val
+        val_str = str(val)
+        return True if val_str.lower() == "true" or val_str == "1" else False
+
+    def get_int(self, key: str, default: int = 0) -> int:
+        val = self.get(key)
+        if not val:
+            return default
+        try:
+            return int(val)
+        except ValueError:
+            return default
+
+    def get_all(self):
+        if self._config is None:
+            self.load_config(self._path)
+        return self._config

+ 186 - 0
SourceCode/IntelligentRailwayCosting/app/tools/utils/file_helper.py

@@ -0,0 +1,186 @@
+import os
+import shutil
+import tools.utils as utils
+from datetime import datetime, timedelta
+from urllib.parse import urlparse
+
+import pandas as pd
+import requests
+import mimetypes
+import base64
+
+
+class FileHelper:
+
+    DEFAULT_ATTACH_PATH = "./temp_files/attaches/"
+    DEFAULT_REPORT_PATH = "./temp_files/report/"
+
+    def __init__(self):
+        attach_path = utils.get_config_value(
+            "save.attach_file_path", self.DEFAULT_ATTACH_PATH
+        )
+        attach_path = attach_path.replace("\\", "/")
+        attach_path = attach_path.replace("//", "/")
+        self._attach_file_path = attach_path
+        report_path = utils.get_config_value(
+            "save.report_file_path", self.DEFAULT_REPORT_PATH
+        )
+        report_path = report_path.replace("\\", "/")
+        report_path = report_path.replace("//", "/")
+        self._report_file_path = report_path
+
+    def download_remote_file(self, file_url: str, file_name: str) -> str | None:
+        utils.get_logger().info(f"下载远程文件: {file_url}  文件名:{file_name}")
+        current_timestamp = datetime.now().strftime("%H%M%S%f")[:-3]  # 取前三位毫秒
+        file_name = f"{current_timestamp}@{file_name}"
+        file_path = os.path.join(
+            self._attach_file_path, f'{datetime.now().strftime("%Y-%m-%d")}'
+        )
+        if not os.path.exists(file_path):
+            os.makedirs(file_path)
+        path = os.path.join(file_path, file_name)
+        path = path.replace("\\", "/")
+        path = path.replace("//", "/")
+        # 10个不同的 User-Agent
+        user_agents = [
+            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
+            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.1.2 Safari/605.1.15",
+            "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
+            "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0",
+            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) CriOS/91.0.4472.124 Safari/605.1.15",
+            "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:89.0) Gecko/20100101 Firefox/89.0",
+            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Edge/91.0.864.59 Safari/537.36",
+            "Mozilla/5.0 (iPhone; CPU iPhone OS 14_6 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.1.2 Mobile/15E148 Safari/604.1",
+            "Mozilla/5.0 (iPad; CPU OS 14_6 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.1.2 Mobile/15E148 Safari/604.1",
+            "Mozilla/5.0 (Linux; Android 11; SM-G973F) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Mobile Safari/537.36",
+        ]
+
+        # 根据文件名长度选择一个 User-Agent
+        ua_index = len(file_name) % len(user_agents)
+        # 解析 file_url 获取 Referer
+        parsed_url = urlparse(file_url)
+        referer = f"{parsed_url.scheme}://{parsed_url.netloc}/".replace(
+            "//download.", "//www."
+        )
+        headers = {
+            "User-Agent": user_agents[ua_index],
+            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
+            "Accept-Encoding": "gzip, deflate, br",
+            "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8,zh-TW;q=0.7",
+            "Referer": referer,
+        }
+
+        try:
+            response = requests.get(file_url, headers=headers, allow_redirects=True)
+            response.raise_for_status()
+            with open(path, "wb") as f:
+                f.write(response.content)
+            utils.get_logger().info(f"文件下载成功: {file_name}")
+            return path
+        except requests.exceptions.HTTPError as http_err:
+            utils.get_logger().error(f"HTTP 错误: {http_err}")
+        except Exception as e:
+            utils.get_logger().error(f"文件下载失败: {file_name}。Exception: {e}")
+            return None
+
+    def clean_attach_file(self, day: int) -> None:
+        try:
+            current_time = datetime.now()
+            cutoff_time = current_time - timedelta(days=day)
+            for root, dirs, _ in os.walk(self._attach_file_path):
+                for dir_name in dirs:
+                    path = os.path.join(root, dir_name)
+                    dir_path = (
+                        str(path).replace(self._attach_file_path, "").replace("\\", "/")
+                    )
+                    if dir_path.count("/") > 0:
+                        continue
+                    try:
+                        dir_date = datetime.strptime(dir_path, "%Y-%m-%d")
+                        if dir_date < cutoff_time:
+                            try:
+                                shutil.rmtree(path)
+                                utils.get_logger().info(
+                                    f"  删除目录及其内容: {dir_path}"
+                                )
+                            except PermissionError:
+                                utils.get_logger().error(
+                                    f"  权限错误,无法删除目录: {dir_path}"
+                                )
+                            except Exception as e:
+                                utils.get_logger().error(
+                                    f"  删除目录失败: {dir_path}。Exception: {e}"
+                                )
+                    except ValueError:
+                        # 如果目录名称不符合 %Y-%m/%d 格式,跳过
+                        continue
+        except Exception as e:
+            utils.get_logger().error(f"attach 文件清理失败。Exception: {e}")
+
+    def save_report_excel(self, data, file_name: str = None) -> str:
+        try:
+            df = pd.DataFrame(data)
+            file_path = os.path.join(
+                self._report_file_path, f'{datetime.now().strftime("%Y-%m-%d")}'
+            )
+            if not os.path.exists(file_path):
+                os.makedirs(file_path)
+            file_name = f"{file_name}_{datetime.now().strftime('%H%M%S')}.xlsx"
+            path = os.path.join(file_path, file_name)
+            path = path.replace("\\", "/")
+            path = path.replace("//", "/")
+            df.to_excel(path, index=False)
+            utils.get_logger().debug(f"Report报存成功: {file_name}")
+            return path
+        except Exception as e:
+            utils.get_logger().error(f"保存 Report Excel 文件失败。Exception: {e}")
+            return ""
+
+    def clean_report_file(self, day: int) -> None:
+        try:
+            current_time = datetime.now()
+            cutoff_time = current_time - timedelta(days=day)
+            for root, dirs, _ in os.walk(self._report_file_path):
+                for dir_name in dirs:
+                    path = os.path.join(root, dir_name)
+                    dir_path = (
+                        str(path).replace(self._report_file_path, "").replace("\\", "/")
+                    )
+                    if dir_path.count("/") > 0:
+                        continue
+                    try:
+                        dir_date = datetime.strptime(dir_path, "%Y-%m-%d")
+                        if dir_date < cutoff_time:
+                            try:
+                                shutil.rmtree(path)
+                                utils.get_logger().info(
+                                    f"  Report 删除目录及其内容: {dir_path}"
+                                )
+                            except PermissionError:
+                                utils.get_logger().error(
+                                    f"  Report 权限错误,无法删除目录: {dir_path}"
+                                )
+                            except Exception as e:
+                                utils.get_logger().error(
+                                    f"  Report 删除目录失败: {dir_path}。Exception: {e}"
+                                )
+                    except ValueError:
+                        # 如果目录名称不符合 %Y-%m/%d 格式,跳过
+                        continue
+        except Exception as e:
+            utils.get_logger().error(f"Report 文件清理失败。Exception: {e}")
+
+    @staticmethod
+    def encode_file(file_path: str):
+        if not os.path.exists(file_path):
+            utils.get_logger().error(f"文件不存在: {file_path}")
+            raise FileNotFoundError(f"文件不存在: {file_path}")
+        # 根据文件扩展名获取 MIME 类型
+        mime_type, _ = mimetypes.guess_type(file_path)
+        if mime_type is None:
+            mime_type = 'image/jpeg'  # 默认使用 jpeg 类型
+        # 将图片编码为 base64 字符串
+        with open(file_path, "rb") as image_file:
+            encoded_string = base64.b64encode(image_file.read())
+            base64_str = encoded_string.decode("utf-8")
+            return f"data:{mime_type};base64,{base64_str}"

+ 113 - 0
SourceCode/IntelligentRailwayCosting/app/tools/utils/logger_helper.py

@@ -0,0 +1,113 @@
+import logging
+import os
+from datetime import datetime
+from logging.handlers import TimedRotatingFileHandler
+
+from tools.utils.config_helper import ConfigHelper
+
+
+class LoggerHelper:
+    """
+    日志辅助类,用于创建和提供日志记录器实例
+    该类实现了单例模式,确保在整个应用程序中只有一个日志记录器实例被创建和使用
+    """
+
+    _instance = None
+    config = ConfigHelper()
+    _log_file_name = f"{config.get("logger.file_name", "log")}.log"
+    _log_file_path = config.get("logger.file_path", "./logs")
+    _log_level_string = config.get("logger.level", "INFO")
+
+    def __new__(cls, *args, **kwargs):
+        """
+        实现单例模式,确保日志记录器仅被创建一次
+        如果尚未创建实例,则创建并初始化日志记录器
+        """
+        if not cls._instance:
+            cls._instance = super(LoggerHelper, cls).__new__(cls, *args, **kwargs)
+            try:
+                cls._instance._initialize_logger()
+            except Exception as e:
+                raise Exception(f"配置logger出错: {e}")
+        return cls._instance
+
+    @property
+    def logger(self):
+        return self._logger
+
+    def _initialize_logger(self):
+        """
+        初始化日志记录器,包括设置日志级别、创建处理器和格式化器,并将它们组合起来
+        """
+        log_level = self._get_log_level()
+        self._logger = logging.getLogger("app_logger")
+        self._logger.setLevel(log_level)
+
+        if not os.path.exists(self._log_file_path):
+            os.makedirs(self._log_file_path)
+
+        # 创建按日期分割的文件处理器
+        file_handler = TimedRotatingFileHandler(
+            os.path.join(self._log_file_path, self._log_file_name),
+            when="midnight",
+            interval=1,
+            backupCount=7,
+            encoding="utf-8",
+        )
+        file_handler.setLevel(log_level)
+
+        # 创建控制台处理器
+        console_handler = logging.StreamHandler()
+        console_handler.setLevel(logging.DEBUG)
+
+        # 创建格式化器
+        formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+
+        # 将格式化器添加到处理器
+        file_handler.setFormatter(formatter)
+        console_handler.setFormatter(formatter)
+
+        # 将处理器添加到日志记录器
+        self._logger.addHandler(file_handler)
+        self._logger.addHandler(console_handler)
+
+    def _get_log_level(self):
+        try:
+            # 尝试将字符串转换为 logging 模块中的日志级别常量
+            log_level = getattr(logging, self._log_level_string.upper())
+            if not isinstance(log_level, int):
+                raise ValueError
+            return log_level
+        except (AttributeError, ValueError):
+            raise ValueError(
+                f"配置logger出错: Unknown level: '{self._log_level_string}'"
+            )
+
+    @classmethod
+    def get_logger(cls):
+        """
+        提供初始化后的日志记录器实例
+        :return: 初始化后的日志记录器实例
+        """
+        if not cls._instance:
+            cls._instance = cls()
+        return cls._instance._logger
+
+    @classmethod
+    def clean_log_file(cls, day: int):
+        if not os.path.exists(cls._log_file_path):
+            return
+        for filename in os.listdir(cls._log_file_path):
+            if filename != cls._log_file_name and filename.startswith(
+                cls._log_file_name
+            ):
+                try:
+                    file_path = os.path.join(cls._log_file_path, filename)
+                    file_time = datetime.strptime(
+                        filename.replace(f"{cls._log_file_name}.", ""), "%Y-%m-%d"
+                    )
+                    if (datetime.now() - file_time).days > day:
+                        os.remove(file_path)
+                        cls.get_logger().info(f"  删除日志文件: {file_path}")
+                except Exception as e:
+                    cls.get_logger().error(f"删除日志文件出错: {filename} {e}")

+ 103 - 0
SourceCode/IntelligentRailwayCosting/app/tools/utils/string_helper.py

@@ -0,0 +1,103 @@
+import json
+class StringHelper:
+
+    @staticmethod
+    def check_empty(s: str, default: str) -> str:
+        """
+        检查字符串是否为空
+        """
+        if s:
+            return s
+        return default
+
+    @staticmethod
+    def to_array(s: str, sep: str = ",") -> list[str]:
+        """
+        将字符串按指定分隔符分割成数组。
+
+        :param s: 要分割的字符串。
+        :param sep: 分隔符,默认为逗号。
+        :return: 分割后的字符串数组。
+        """
+        if not s:
+            return []
+        if sep == ",":
+            s = s.replace(",", ",")
+        return s.split(sep)
+
+    @staticmethod
+    def e_startswith(s: str, prefix: str) -> str:
+        """
+        检查字符串是否以特定前缀开头,如果没有则补全。
+
+        :param s: 要检查的字符串。
+        :param prefix: 前缀。
+        :return: 如果字符串以指定前缀开头,返回原字符串;否则返回补全后的字符串。
+        """
+        if not s.startswith(prefix):
+            return prefix + s
+        return s
+
+    @staticmethod
+    def e_endswith(s: str, suffix: str) -> str:
+        """
+        检查字符串是否以特定后缀结尾,如果没有则补全。
+
+        :param s: 要检查的字符串。
+        :param suffix: 后缀。
+        :return: 如果字符串以指定后缀结尾,返回原字符串;否则返回补全后的字符串。
+        """
+        if not s.endswith(suffix):
+            return s + suffix
+        return s
+
+    @staticmethod
+    def split_and_clean(s: str, sep: str = ",") -> list[str]:
+        """
+        将字符串按指定分隔符分割并去除空字符串。
+
+        :param s: 要分割的字符串。
+        :param sep: 分隔符,默认为逗号。
+        :return: 分割后的字符串数组,去除空字符串。
+        """
+        if not s:
+            return []
+        parts = StringHelper.to_array(s, sep)
+        return [part.strip() for part in parts if part.strip()]
+
+    @staticmethod
+    def remove_extra_spaces(s: str) -> str:
+        """
+        将字符串中的多个连续空格替换为单个空格。
+
+        :param s: 要处理的字符串。
+        :return: 替换后的字符串。
+        """
+        return " ".join(s.split())
+
+    @staticmethod
+    def to_str(data:dict|list|tuple):
+        return json.dumps(data, ensure_ascii=False)
+    @staticmethod
+    def is_email(s: str) -> bool:
+        """
+        验证字符串是否为有效的邮箱地址。
+
+        :param s: 要验证的字符串。
+        :return: 如果是有效的邮箱地址返回True,否则返回False。
+        """
+        import re
+        pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
+        return bool(re.match(pattern, s)) if s else False
+
+    @staticmethod
+    def is_phone(s: str) -> bool:
+        """
+        验证字符串是否为有效的手机号码(中国大陆手机号)。
+
+        :param s: 要验证的字符串。
+        :return: 如果是有效的手机号码返回True,否则返回False。
+        """
+        import re
+        pattern = r'^1[3-9]\d{9}$'
+        return bool(re.match(pattern, s)) if s else False

+ 0 - 0
SourceCode/IntelligentRailwayCosting/app/views/__init__.py


+ 6 - 0
SourceCode/IntelligentRailwayCosting/requirements.txt

@@ -0,0 +1,6 @@
+pymysql==1.1.0
+sqlalchemy==1.4.50
+python-dotenv==1.0.0
+typing-extensions>=4.12.2
+pydantic==2.10.6
+pydantic-core==2.27.2