YueYunyun 5 maanden geleden
bovenliggende
commit
b7fa868ec6
21 gewijzigde bestanden met toevoegingen van 645 en 438 verwijderingen
  1. 1 1
      SourceCode/IntelligentRailwayCosting/app/config.yml
  2. 7 1
      SourceCode/IntelligentRailwayCosting/app/core/configs/app_config.py
  3. 88 3
      SourceCode/IntelligentRailwayCosting/app/executor/collector.py
  4. 9 4
      SourceCode/IntelligentRailwayCosting/app/executor/processor.py
  5. 11 4
      SourceCode/IntelligentRailwayCosting/app/executor/sender.py
  6. 2 2
      SourceCode/IntelligentRailwayCosting/app/routes/project.py
  7. 2 2
      SourceCode/IntelligentRailwayCosting/app/routes/project_task.py
  8. 18 27
      SourceCode/IntelligentRailwayCosting/app/services/project.py
  9. 2 1
      SourceCode/IntelligentRailwayCosting/app/services/project_quota.py
  10. 4 17
      SourceCode/IntelligentRailwayCosting/app/services/project_task.py
  11. 100 86
      SourceCode/IntelligentRailwayCosting/app/stores/budget.py
  12. 43 43
      SourceCode/IntelligentRailwayCosting/app/stores/log.py
  13. 62 62
      SourceCode/IntelligentRailwayCosting/app/stores/project.py
  14. 95 97
      SourceCode/IntelligentRailwayCosting/app/stores/project_quota.py
  15. 67 65
      SourceCode/IntelligentRailwayCosting/app/stores/project_task.py
  16. 10 8
      SourceCode/IntelligentRailwayCosting/app/stores/user.py
  17. 96 7
      SourceCode/IntelligentRailwayCosting/app/tools/db_helper/__init__.py
  18. 3 3
      SourceCode/IntelligentRailwayCosting/app/tools/db_helper/base.py
  19. 1 1
      SourceCode/IntelligentRailwayCosting/app/tools/db_helper/mysql_helper.py
  20. 1 1
      SourceCode/IntelligentRailwayCosting/app/tools/db_helper/sqlserver_helper.py
  21. 23 3
      SourceCode/IntelligentRailwayCosting/app/views/static/project/budget_info.js

+ 1 - 1
SourceCode/IntelligentRailwayCosting/app/config.yml

@@ -3,6 +3,7 @@ app:
   name: '铁路造价智能化工具'
   version: '2024' # 应用版本 2020|2024
   user_version: true
+  collect_api_url: 'http://192.168.0.104:8020/api'
 db:
   # SQL Server 配置
   # SQL Server 2008:'{SQL Server}' 或 '{SQL Server Native Client 10.0}'
@@ -77,6 +78,5 @@ fastgpt_ai:
     app_02_2024:
       api_url: http://192.168.0.104:8020/api
       api_key: fastgpt-o4CF7Pu1FRTvHjWFqeNcClBS6ApyflNfkBGXo9p51fuBMAX1L0erU8yz8
-
 file:
   source_path: './temp_files'

+ 7 - 1
SourceCode/IntelligentRailwayCosting/app/core/configs/app_config.py

@@ -3,6 +3,7 @@ class AppConfig:
     _name = None
     _version = None
     _user_version = None
+    _collect_api_url = None
 
     @property
     def name(self):
@@ -16,6 +17,10 @@ class AppConfig:
     def user_version(self)->bool:
         return self._user_version
 
+    @property
+    def collect_api_url(self)->str:
+        return self._collect_api_url
+
     def update_config(self, config):
         """更新应用配置
         
@@ -24,4 +29,5 @@ class AppConfig:
         """
         self._name = config.get('name')
         self._version = config.get('version')
-        self._user_version = config.get('user_version')
+        self._user_version = config.get('user_version')
+        self._collect_api_url = config.get('collect_api_url')

+ 88 - 3
SourceCode/IntelligentRailwayCosting/app/executor/collector.py

@@ -1,16 +1,101 @@
 import tools.utils as utils
-from core.dtos import ProjectTaskDto
+from core.dtos import ProjectTaskDto, ProjectQuotaDto
+from stores import ProjectTaskStore,ProjectQuotaStore,BudgetStore
+from tools import db_helper
 
 
 class Collector:
     def __init__(self):
         self._logger = utils.get_logger()
+        self._budget_store = BudgetStore()
+        self._task_store = ProjectTaskStore()
+        self._quota_store = ProjectQuotaStore()
 
     def collect(self,task:ProjectTaskDto):
         try:
             self._logger.info(f"开始采集任务:{task.task_name}")
+            self._task_store.update_collect_status(task.id,1)
+            if not task.file_path:
+                raise Exception("任务文件不存在")
+            chapters,msg= self._get_chapters(task)
+            if not chapters:
+                raise Exception(msg)
+            files,msg = self._read_files(task.file_path)
+            if not files:
+                raise Exception(msg)
+            result,msg = self._call_api(task,chapters,files)
+            if not result:
+                raise Exception(msg)
+            self._insert_data(task,result)
+            self._task_store.update_collect_status(task.id,2)
             self._logger.info(f"采集任务:{task.task_name}完成")
             return None
         except Exception as e:
-            self._logger.error(f"采集任务:{task.task_name}失败,原因:{e}")
-            return f"采集失败,原因:{e}"
+            msg = f"任务采集失败,原因:{e}"
+            self._logger.error(f"采集任务:{task.task_name}, {msg}")
+            self._task_store.update_collect_status(task.id,3, msg)
+            return msg
+    def _read_files(self,paths:str):
+        try:
+            files=[]
+            self._logger.debug(f"开始读取文件:{paths}")
+            path_list= paths.split(",")
+            for path in path_list:
+                file = utils.encode_file(path)
+                files.append(file)
+            self._logger.debug(f"读取文件完成:{paths}")
+            return files, ''
+        except Exception as e:
+            msg = f"读取文件失败,原因:{e}"
+            self._logger.error(f"读取文件失败,原因:{e}")
+            return None,msg
+
+    def _get_chapters(self,task:ProjectTaskDto):
+        try:
+            self._logger.debug(f"开始调用接口:{task.task_name}")
+
+            data = self._budget_store.get_all_budget_items_not_children(task.project_id,task.budget_id, task.item_code)
+            # print(len(data))
+            # for item in data:
+            #     print(item)
+            self._logger.debug(f"调用接口:{task.task_name}完成")
+            return data,''
+        except Exception as e:
+            msg = f"调用接口失败,原因:{e}"
+            self._logger.error(f"调用接口:{task.task_name}, {msg}")
+            return None,msg
+
+    def _call_api(self,task:ProjectTaskDto,chapters:list,files:list):
+       try:
+           self._logger.debug(f"开始调用接口:{task.task_name}")
+           self._task_store.update_process_status(task.id, 1)
+           self._task_store.update_process_status(task.id, 2)
+           self._logger.debug(f"调用接口:{task.task_name}完成")
+           return [], ""
+       except Exception as e:
+           msg = f"调用接口失败,原因:{e}"
+           self._logger.error(f"调用接口:{task.task_name}, {msg}")
+           return None,msg
+
+    def _insert_data(self,task:ProjectTaskDto,data:list):
+        try:
+            self._logger.debug(f"开始插入数据:{task.task_name}")
+            for item in data:
+                quota =  ProjectQuotaDto(
+                    budget_id=task.budget_id,
+                    project_id=task.project_id,
+                    item_code=item['item_code'],
+                    item_id=item['item_id'],
+                    project_name=item['project_name'],
+                    project_quantity=item['project_quantity'],
+                    project_quantity_input=item['project_quantity_input'],
+                    unit=item['unit'],
+                    unit_weight=item['unit_weight'],
+                )
+                self._quota_store.create_quota(quota)
+            self._logger.debug(f"插入数据完成:{task.task_name}")
+            return True
+        except Exception as e:
+            msg = f"插入数据失败,原因:{e}"
+            self._logger.error(f"插入数据失败,原因:{e}")
+            return False,msg

+ 9 - 4
SourceCode/IntelligentRailwayCosting/app/executor/processor.py

@@ -1,16 +1,18 @@
 
 import tools.utils as utils
 from core.dtos import ProjectTaskDto, ProjectQuotaDto
-from stores import ProjectQuotaStore
+from stores import ProjectTaskStore,ProjectQuotaStore
 
 class Processor:
     def __init__(self):
         self._logger = utils.get_logger()
+        self._task_store = ProjectTaskStore()
         self._quota_store = ProjectQuotaStore()
 
     def process(self,task:ProjectTaskDto):
         try:
             self._logger.info(f"开始处理任务:{task.task_name}")
+            self._task_store.update_process_status(task.id,1)
             data_list = self._quota_store.get_quotas_by_task_id(task.id)
             error_count = 0
             for quota in data_list:
@@ -19,11 +21,14 @@ class Processor:
                 if msg:
                     error_count += 1
                     continue
+            self._task_store.update_process_status(task.id,2)
             self._logger.info(f"处理任务:{task.task_name}完成,{error_count}项错误/共{len(data_list)}项")
             return None
         except Exception as e:
-            self._logger.error(f"处理任务:{task.task_name}失败,原因:{e}")
-            return f"处理失败,原因:{e}"
+            msg = f"任务处理失败,原因:{e}"
+            self._logger.error(f"处理任务:{task.task_name}, {msg}")
+            self._task_store.update_process_status(task.id,3,msg)
+            return msg
 
     def process_quota(self,quota:ProjectQuotaDto):
         try:
@@ -34,7 +39,7 @@ class Processor:
             self._logger.info(f"处理定额:{quota.id}完成")
             return None
         except Exception as e:
-            msg = f"处理失败,原因:{e}"
+            msg = f"定额处理失败,原因:{e}"
             self._logger.error(f"处理定额:{quota.id},{msg}")
             self._quota_store.update_process_status(quota.id,3, msg)
             return msg

+ 11 - 4
SourceCode/IntelligentRailwayCosting/app/executor/sender.py

@@ -1,16 +1,19 @@
 import tools.utils as utils
 from core.dtos import ProjectTaskDto, ProjectQuotaDto
-from stores import ProjectQuotaStore
+from stores import ProjectQuotaStore, ProjectTaskStore
 
 
 class Sender:
     def __init__(self):
         self._logger = utils.get_logger()
+        self._task_store = ProjectTaskStore()
         self._quota_store = ProjectQuotaStore()
 
     def send(self,task:ProjectTaskDto):
         try:
             self._logger.info(f"开始发送任务:{task.task_name}")
+            self._task_store.update_send_status(task.id,1)
+
             error_count = 0
             data_list = self._quota_store.get_quotas_by_task_id(task.id,True)
             for data in data_list:
@@ -18,11 +21,15 @@ class Sender:
                 if msg:
                     error_count+=1
                     continue
+
+            self._task_store.update_send_status(task.id,2)
             self._logger.info(f"发送任务:{task.task_name}完成,{error_count}项错误/共{len(data_list)}项")
             return None
         except Exception as e:
-            self._logger.error(f"发送任务:{task.task_name}失败,原因:{e}")
-            return f"发送失败,原因:{e}"
+            msg = f"任务发送失败,原因:{e}"
+            self._logger.error(f"发送任务:{task.task_name},{msg}")
+            self._task_store.update_send_status(task.id, 3, msg)
+            return msg
 
 
     def send_quota(self,quota:ProjectQuotaDto):
@@ -34,7 +41,7 @@ class Sender:
             self._logger.info(f"发送定额:{quota.id}完成")
             return None
         except Exception as e:
-            msg = f"发送失败,原因:{e}"
+            msg = f"定额发送失败,原因:{e}"
             self._logger.error(f"发送定额:{quota.id},{msg}")
             self._quota_store.update_send_status(quota.id, 3, msg)
             return msg

+ 2 - 2
SourceCode/IntelligentRailwayCosting/app/routes/project.py

@@ -39,9 +39,9 @@ def get_budget_info(project_id:str):
     except Exception as e:
         return ResponseBase.error(f'获取项目概算信息失败:{str(e)}')
 
-@project_api.route('/budget-item/top/<budget_id>/<project_id>', methods=['POST'])
+@project_api.route('/budget-item/top/<int:budget_id>/<project_id>', methods=['POST'])
 @Permission.authorize
-def get_budget_top_items(budget_id:str,project_id:str):
+def get_budget_top_items(budget_id:int,project_id:str):
     try:
         data,msg = project_srvice.get_top_budget_items(budget_id, project_id)
         if not data:

+ 2 - 2
SourceCode/IntelligentRailwayCosting/app/routes/project_task.py

@@ -45,7 +45,7 @@ def save_task(task_id:int):
         item_code = form_data.get('item_code')
         task_name = form_data.get('task_name')
         task_desc = form_data.get('task_desc')
-        delete_file = form_data.get('delete_file', 'false').lower() == 'true'
+        delete_old = form_data.get('delete_old', 'false').lower() == 'true'
         # 获取上传的文件
         files = request.files.getlist('files')
         # 验证必要参数
@@ -63,7 +63,7 @@ def save_task(task_id:int):
         )
         
         # 保存任务
-        task = task_service.save_task(task_id, task_dto, files, delete_file)
+        task = task_service.save_task(task_id, task_dto, files, delete_old)
         return ResponseBase.success(task.to_dict())
     except ValueError as ve:
         return ResponseBase.error(f'参数格式错误:{str(ve)}')

+ 18 - 27
SourceCode/IntelligentRailwayCosting/app/services/project.py

@@ -6,14 +6,13 @@ from core.dtos.project import ProjectDto
 from core.dtos.tree import TreeDto
 from core.user_session import UserSession
 from stores import ProjectStore,BudgetStore
-from tools import db_helper
 
 
 class ProjectService:
 
     def __init__(self):
         self._project_store = ProjectStore()
-        self._budget_store =None
+        self._budget_store = BudgetStore()
 
     def get_projects_paginated(self, page: int, page_size: int, keyword: Optional[str] = None,
         start_time: Optional[str] = None,
@@ -36,30 +35,25 @@ class ProjectService:
         return [ProjectDto.from_model(item).to_dict() for item in data.get('data',[])],data.get('total',0)
 
     def get_budget_info(self, project_id: str):
-        db_session, msg = self._create_project_db_session(project_id)
-        if not db_session:
+        msg = self._check_project_db_exit(project_id)
+        if msg:
             return None, msg
-        budget_store = BudgetStore(db_session)
-        data = budget_store.get_budget_info()
+        data =  self._budget_store.get_budget_info(project_id)
         return [TotalBudgetInfoDto.from_model(item).to_dict() for item in data],""
 
-    def get_top_budget_items(self, budget_id: str, project_id: str):
-        if not budget_id:
-            return None,' budget_id不能为空'
-        db_session,msg = self._create_project_db_session(project_id)
-        if not db_session:
+    def get_top_budget_items(self, budget_id: int, project_id: str):
+        msg = self._check_project_db_exit(project_id)
+        if msg:
             return None, msg
-        budget_store = BudgetStore(db_session)
-        items = budget_store.get_top_budget_items(budget_id)
+        items = self._budget_store.get_top_budget_items(project_id,budget_id)
         return [TotalBudgetItemDto.from_model(item).to_dict() for item in items],""
 
     def get_budget_items(self, budget_id: str, project_id: str, item_code:str):
         if not budget_id:
             return None,'budget_id不能为空'
-        db_session,msg = self._create_project_db_session(project_id)
-        if not db_session:
+        msg = self._check_project_db_exit(project_id)
+        if msg:
             return None, msg
-        budget_store = BudgetStore(db_session)
         data_list = []
         if not item_code:
             team_item_code = None
@@ -68,26 +62,23 @@ class ProjectService:
                 team_item_code_str = self._project_store.get_team_project_item_code(project_id,current_user.username)
                 if team_item_code_str:
                     team_item_code = None if team_item_code_str == 'None' or team_item_code_str == '0' or team_item_code_str == ''  else team_item_code_str.split(',')
-            items =  budget_store.get_top_budget_items(budget_id,team_item_code)
+            items =  self._budget_store.get_top_budget_items(project_id,budget_id,team_item_code)
         else:
-            items = budget_store.get_child_budget_items(budget_id,item_code)
+            items = self._budget_store.get_child_budget_items(project_id,budget_id,item_code)
         parent = "#"
         if item_code:
-            item = budget_store.get_budget_item_by_item_code(budget_id,item_code)
+            item = self._budget_store.get_budget_item_by_item_code(project_id,budget_id,item_code)
             parent = item.item_id
         for item in items:
             text = f"第{item.chapter}章、{item.project_name}" if item.chapter else ( f"{item.section}  {item.project_name}" if item.section else item.project_name)
             data_list.append(TreeDto(item.item_id,parent,text,item.children_count>0,item).to_dict())
         return data_list,""
 
-    def _create_project_db_session(self,project_id:str):
+    def _check_project_db_exit(self, project_id:str):
         if not project_id:
-            return None,'project_id不能为空'
+            return 'project_id不能为空'
         if not self._project_store.get(project_id):
-            return None,'项目不存在'
+            return '项目不存在'
         if not project_id.startswith('Reco'):
-            return None,'项目id格式错误'
-        db_session = db_helper.create_sqlServer_session(project_id)
-        if not db_session:
-            return None,'数据库连接失败'
-        return db_session,""
+            return '项目id格式错误'
+        return None

+ 2 - 1
SourceCode/IntelligentRailwayCosting/app/services/project_quota.py

@@ -76,6 +76,8 @@ class ProjectQuotaService:
                 quota_dto = self.create_quota(quota_dto)
             else:
                 quota_dto = self.update_quota(quota_dto)
+                self.update_process_status(quota_dto.id,4)
+                self.update_send_status(quota_dto.id,4)
             if need_process:
                 self.start_process(quota_dto.id)
             return quota_dto
@@ -115,7 +117,6 @@ class ProjectQuotaService:
             # 业务验证
             if not quota_dto.id:
                 raise ValueError("定额ID不能为空")
-
             return self.store.update_quota(quota_dto)
         except Exception as e:
             self._logger.error(f"更新项目定额失败: {str(e)}")

+ 4 - 17
SourceCode/IntelligentRailwayCosting/app/services/project_task.py

@@ -140,7 +140,7 @@ class ProjectTaskService:
                         delete_paths.append(target_path)
                 if len(delete_paths) > 0:
                     LogRecordHelper.log_success(OperationType.DELETE, OperationModule.TASK,
-                                          f"删除任务文件:{task.sub_project_name}", utils.to_str(delete_paths))
+                                          f"删除任务文件:{task.task_name}", utils.to_str(delete_paths))
         file_paths = [] if delete_old or not task.file_path else task.file_path.split(',')
         if files and len(files) > 0:
             for file in files:
@@ -271,7 +271,6 @@ class ProjectTaskService:
                 return '没有上传文件'
             if task.collect_status == 1:
                 return  '正在采集中'
-            self.update_collect_status(task_id,1)
             thread = threading.Thread(target=self._collect_task, args=(task,))
             thread.start()
             return None
@@ -281,10 +280,7 @@ class ProjectTaskService:
     def _collect_task(self,task:ProjectTaskDto):
         try:
             msg = executor.collect_task(task)
-            if msg:
-                self.update_collect_status(task.id,3,msg)
-            else:
-                self.update_collect_status(task.id,2)
+            if not msg:
                 self.start_process(task.id)
         except Exception as e:
             self._logger.error(f"采集项目任务失败: {str(e)}")
@@ -297,7 +293,6 @@ class ProjectTaskService:
                 return '还未采集完成'
             if task.process_status == 1:
                 return  '正在处理中'
-            self.update_process_status(task_id,1)
             thread = threading.Thread(target=self._process_task, args=(task,))
             thread.start()
             return None
@@ -307,10 +302,7 @@ class ProjectTaskService:
     def _process_task(self,task:ProjectTaskDto):
         try:
            msg = executor.process_task(task)
-           if msg:
-                self.update_process_status(task.id,3,msg)
-           else:
-                self.update_process_status(task.id,2)
+           if not msg:
                 self.start_send(task.id)
         except Exception as e:
             self._logger.error(f"处理项目任务失败: {str(e)}")
@@ -322,7 +314,6 @@ class ProjectTaskService:
                 return '还未处理完成'
             if task.send_status == 1:
                 return  '正在发送中'
-            self.update_send_status(task_id,1)
             thread = threading.Thread(target=self._send_task, args=(task,))
             thread.start()
             return None
@@ -330,11 +321,7 @@ class ProjectTaskService:
             return '没有查询到任务'
     def _send_task(self,task:ProjectTaskDto):
         try:
-            msg = executor.send_task(task)
-            if msg:
-                self.update_send_status(task.id,3,msg)
-            else:
-                self.update_send_status(task.id,2)
+            executor.send_task(task)
         except Exception as e:
             self._logger.error(f"发送项目任务失败: {str(e)}")
             raise

+ 100 - 86
SourceCode/IntelligentRailwayCosting/app/stores/budget.py

@@ -1,47 +1,51 @@
-from typing import Optional
 from sqlalchemy.orm import Session, aliased
-from sqlalchemy import and_, or_, asc, case, func
+from sqlalchemy import and_, or_,  func
 
 from core.models import TotalBudgetInfoModel,TotalBudgetItemModel,ChapterModel
-from core.dtos import TotalBudgetInfoDto, TotalBudgetItemDto
+from tools import db_helper
 
 
 class BudgetStore:
-    def __init__(self, db_session: Session ):
-        if db_session is None:
-            raise Exception("db_session is None")
-        self.db_session = db_session
-
-    def get_budget_info(self):
-        budgets = self.db_session.query(TotalBudgetInfoModel).all()
-        if budgets is None:
-            return None
-        return budgets
-
-    def get_budget_item_by_item_code(self, budget_id: str,item_code: str):
-        budget = self.db_session.query( 
-            TotalBudgetItemModel.budget_id,
-            TotalBudgetItemModel.item_id,
-            ChapterModel.item_code,
-            ChapterModel.chapter,
-            ChapterModel.section,
-            ChapterModel.project_name,
-            ChapterModel.item_type,
-            ChapterModel.unit,)\
-            .join(ChapterModel,ChapterModel.item_id == TotalBudgetItemModel.item_id)\
-            .filter(and_(TotalBudgetItemModel.budget_id == budget_id,ChapterModel.item_code == item_code))\
-            .first()
-        if budget is None:
-            return None
-        return budget
-
-    def _build_budget_items_query(self, budget_id: str):
+    def __init__(self):
+        self._database = None
+        self._db_session = None
+        pass
+
+
+    def get_budget_info(self, project_id: str):
+        self._database=project_id
+        with db_helper.sqlserver_query_session(self._database) as db_session:
+            budgets = db_session.query(TotalBudgetInfoModel).all()
+            if budgets is None:
+                return None
+            return budgets
+
+    def get_budget_item_by_item_code(self, project_id: str, budget_id: str,item_code: str):
+        self._database=project_id
+        with db_helper.sqlserver_query_session(self._database) as db_session:
+            budget = db_session.query(
+                TotalBudgetItemModel.budget_id,
+                TotalBudgetItemModel.item_id,
+                ChapterModel.item_code,
+                ChapterModel.chapter,
+                ChapterModel.section,
+                ChapterModel.project_name,
+                ChapterModel.item_type,
+                ChapterModel.unit,)\
+                .join(ChapterModel,ChapterModel.item_id == TotalBudgetItemModel.item_id)\
+                .filter(and_(TotalBudgetItemModel.budget_id == budget_id,ChapterModel.item_code == item_code))\
+                .first()
+            if budget is None:
+                return None
+            return budget
+
+    def _build_children_count_subquery(self, model_class):
         # 创建父节点和子节点的别名
-        parent = aliased(ChapterModel, name='parent')
-        child = aliased(ChapterModel, name='child')
+        parent = aliased(model_class, name='parent')
+        child = aliased(model_class, name='child')
 
         # 子查询:计算每个节点的直接子节点数量
-        children_count = self.db_session.query(
+        return self.db_session.query(
             parent.item_code.label('parent_code'),
             func.count(child.item_code).label('child_count')
         ).outerjoin(
@@ -54,6 +58,10 @@ class BudgetStore:
             )
         ).group_by(parent.item_code).subquery()
 
+    def _build_budget_items_query(self, budget_id: int):
+        # 子查询:计算每个节点的直接子节点数量
+        children_count = self._build_children_count_subquery(ChapterModel)
+
         return (self.db_session.query(
             TotalBudgetItemModel.budget_id,
             TotalBudgetItemModel.item_id,
@@ -64,65 +72,71 @@ class BudgetStore:
             ChapterModel.item_type,
             ChapterModel.unit,
             func.coalesce(children_count.c.child_count, 0).label('children_count')
-        ).distinct()
+        )
         .join(ChapterModel, ChapterModel.item_id == TotalBudgetItemModel.item_id)
         .outerjoin(children_count, children_count.c.parent_code == ChapterModel.item_code)
-        ).filter(TotalBudgetItemModel.budget_id == budget_id)
+        .filter(TotalBudgetItemModel.budget_id == budget_id)
+        .distinct()
+        )
 
-    def get_top_budget_items(self, budget_id: str, item_code: list[str]=None):
-        query = self._build_budget_items_query(budget_id)
+    def get_top_budget_items(self, project_id: str, budget_id: int, item_code: list[str]=None):
+        self._database=project_id
+        with db_helper.sqlserver_query_session(self._database) as self.db_session:
+            query = self._build_budget_items_query(budget_id)
 
-        if item_code:
-            query = query.filter(ChapterModel.item_code.in_(item_code))
-        else:
-            query = query.filter(ChapterModel.item_code.like('__'))\
-                .filter(ChapterModel.chapter.is_not(None))
-        query = query.order_by(ChapterModel.item_code)
-        items = query.all()
-        return items
+            if item_code:
+                query = query.filter(ChapterModel.item_code.in_(item_code))
+            else:
+                query = query.filter(ChapterModel.item_code.like('__'))\
+                    .filter(ChapterModel.chapter.is_not(None))
+            query = query.order_by(ChapterModel.item_code)
+            items = query.all()
+            return items
 
-    def get_child_budget_items(self, budget_id: str, parent_item_code: str):
+    def get_child_budget_items(self, project_id: str, budget_id: int, parent_item_code: str):
         # 构建子节点的模式:支持两种格式
         # 1. 父级编号后跟-和两位数字(如:01-01)
         # 2. 父级编号直接跟两位数字(如:0101)
         pattern_with_dash = f'{parent_item_code}-__'
         pattern_without_dash = f'{parent_item_code}__'
-        
-        query = self._build_budget_items_query(budget_id)\
-            .filter(or_(ChapterModel.item_code.like(pattern_with_dash),
-                       ChapterModel.item_code.like(pattern_without_dash)))\
-            .order_by(ChapterModel.item_code)
-        items = query.all()
-        return items
-
-    def get_top_budget_items_by_budget_id(self, budget_id: str):
-        query = ((self.db_session.query(
-            TotalBudgetItemModel.budget_id,
-            TotalBudgetItemModel.item_id,
-            ChapterModel.item_code,
-            ChapterModel.chapter,
-            ChapterModel.section,
-            ChapterModel.project_name,
-            ChapterModel.item_type,
-            ChapterModel.unit,
-        ).distinct().join(ChapterModel, ChapterModel.item_id == TotalBudgetItemModel.item_id))
-             .filter(TotalBudgetItemModel.budget_id == budget_id)
-             .filter(or_(ChapterModel.item_code==0,ChapterModel.chapter is not None,ChapterModel.section is not None))
-        )
-        items = query.order_by(TotalBudgetItemModel.item_id.asc()).all()
-        return items
+        self._database = project_id
+        with db_helper.sqlserver_query_session(self._database) as self.db_session:
+            query = self._build_budget_items_query(budget_id)\
+                .filter(or_(ChapterModel.item_code.like(pattern_with_dash),
+                           ChapterModel.item_code.like(pattern_without_dash)))\
+                .order_by(ChapterModel.item_code)
+            items = query.all()
+            return items
+
+    def get_all_budget_items_not_children(self, project_id: str, budget_id: int, item_code: str):
+        self._database = project_id
+        with db_helper.sqlserver_query_session(self._database) as self.db_session:
+            # 添加叶子节点过滤条件,使用复用的子查询方法
+            children_count = self._build_children_count_subquery(ChapterModel)
+            query =  (self.db_session.query(
+                    TotalBudgetItemModel.budget_id,
+                    TotalBudgetItemModel.item_id,
+                    ChapterModel.item_code,
+                    ChapterModel.chapter,
+                    ChapterModel.section,
+                    ChapterModel.project_name,
+                    ChapterModel.item_type,
+                    ChapterModel.unit,
+                    func.coalesce(children_count.c.child_count, 0).label('children_count')
+                )
+                .join(ChapterModel, ChapterModel.item_id == TotalBudgetItemModel.item_id)
+                .outerjoin(children_count, children_count.c.parent_code == ChapterModel.item_code)
+                .filter(TotalBudgetItemModel.budget_id == budget_id)
+                .distinct()
+                )
+
+            query = query.filter(func.coalesce(children_count.c.child_count, 0) == 0)
+            # 如果指定了item_code,添加前缀匹配条件
+            if item_code:
+                pattern_with_dash = f'{item_code}-%'
+                query = query.filter(ChapterModel.item_code.like(pattern_with_dash))
+            query = query.order_by(ChapterModel.item_code)
+            items = query.all()
+            return items
+
 
-    def get_budget_items_children(self, budget_id: str, item_code: str):
-        query = ((self.db_session.query(
-            TotalBudgetItemModel.budget_id,
-            TotalBudgetItemModel.item_id,
-            ChapterModel.item_code,
-            ChapterModel.chapter,
-            ChapterModel.section,
-            ChapterModel.project_name,
-            ChapterModel.item_type,
-            ChapterModel.unit,
-        ).distinct().join(ChapterModel, ChapterModel.item_id == TotalBudgetItemModel.item_id))
-             .filter(TotalBudgetItemModel.budget_id == budget_id)
-             .filter(or_(ChapterModel.item_code.startswith(item_code)))
-        )

+ 43 - 43
SourceCode/IntelligentRailwayCosting/app/stores/log.py

@@ -1,14 +1,13 @@
 from typing import List, Optional, Dict, Any
 from datetime import datetime
 from sqlalchemy import and_,  desc
-from sqlalchemy.orm import Session
 import tools.db_helper as db_helper
 
 from core.models import LogModel
 
 class LogStore:
-    def __init__(self, db_session: Session=None):
-        self.db_session = db_session or db_helper.create_mysql_session()
+    def __init__(self):
+        self._database= None
 
     def query_logs_paginated(
         self,
@@ -33,39 +32,41 @@ class LogStore:
         :param end_time: 结束时间
         :return: 包含总记录数和日志列表的字典
         """
-        query = self.db_session.query(LogModel)
-        
-        # 构建查询条件
-        conditions = []
-        if username:
-            conditions.append(LogModel.username.like(f'%{username}%'))
-        if operation_type:
-            conditions.append(LogModel.operation_type == operation_type)
-        if operation_module:
-            conditions.append(LogModel.operation_module == operation_module)
-        if operation_result is not None:
-            conditions.append(LogModel.operation_result == operation_result)
-        if start_time:
-            conditions.append(LogModel.created_at >= start_time)
-        if end_time:
-            conditions.append(LogModel.created_at < end_time)
-        
-        if conditions:
-            query = query.filter(and_(*conditions))
-        
-        # 计算总记录数
-        total = query.count()
-        
-        # 分页并按创建时间倒序排序
-        logs = query.order_by(desc(LogModel.created_at))\
-            .offset((page - 1) * page_size)\
-            .limit(page_size)\
-            .all()
-        
-        return {
-            'total': total,
-            'data': logs
-        }
+        with db_helper.mysql_query_session(self._database) as db_session:
+            pass
+            query = db_session.query(LogModel)
+
+            # 构建查询条件
+            conditions = []
+            if username:
+                conditions.append(LogModel.username.like(f'%{username}%'))
+            if operation_type:
+                conditions.append(LogModel.operation_type == operation_type)
+            if operation_module:
+                conditions.append(LogModel.operation_module == operation_module)
+            if operation_result is not None:
+                conditions.append(LogModel.operation_result == operation_result)
+            if start_time:
+                conditions.append(LogModel.created_at >= start_time)
+            if end_time:
+                conditions.append(LogModel.created_at < end_time)
+
+            if conditions:
+                query = query.filter(and_(*conditions))
+
+            # 计算总记录数
+            total = query.count()
+
+            # 分页并按创建时间倒序排序
+            logs = query.order_by(desc(LogModel.created_at))\
+                .offset((page - 1) * page_size)\
+                .limit(page_size)\
+                .all()
+
+            return {
+                'total': total,
+                'data': logs
+            }
     
     def insert_log(
         self,
@@ -100,10 +101,9 @@ class LogStore:
             data_changes=data_changes,
             operation_ip=operation_ip
         )
-        
-        self.db_session.add(log)
-        self.db_session.commit()
-        return log
+        with db_helper.mysql_session(self._database) as db_session:
+            db_session.add(log)
+            return log
     
     def batch_insert_logs(self, logs: List[Dict[str, Any]]) -> List[LogModel]:
         """
@@ -112,6 +112,6 @@ class LogStore:
         :return: 创建的日志记录列表
         """
         log_models = [LogModel(**log) for log in logs]
-        self.db_session.add_all(log_models)
-        self.db_session.commit()
-        return log_models
+        with db_helper.mysql_session(self._database) as db_session:
+            db_session.add_all(log_models)
+            return log_models

+ 62 - 62
SourceCode/IntelligentRailwayCosting/app/stores/project.py

@@ -1,5 +1,4 @@
 from sqlalchemy import and_, or_
-from sqlalchemy.orm import Session
 from datetime import datetime
 from typing import Optional
 
@@ -10,8 +9,8 @@ from core.models.team import TeamModel
 from core.user_session import UserSession
 
 class ProjectStore:
-    def __init__(self, db_session: Session = None):
-        self.db_session = db_session or db_helper.create_sqlServer_session()
+    def __init__(self):
+        self._database= None
 
     def get_user_projects_paginated(
         self,
@@ -36,71 +35,72 @@ class ProjectStore:
         Returns:
             Tuple[total_count, projects]
         """
+
         # 构建基础查询
-        query = (self.db_session.query(
-            ProjectModel.project_id,
-            ProjectModel.project_name,
-            ProjectModel.project_manager,
-            ProjectModel.design_stage,
-            ProjectModel.project_description,
-            ProjectModel.short_name,
-            ProjectModel.project_version,
-            ProjectModel.project_type,
-            ProjectModel.unit,
-            ProjectModel.create_time,
-        ) .distinct())
-        user = UserSession.get_current_user()
-        if not user.is_admin:
-            query = query.outerjoin(TeamModel, ProjectModel.project_id == TeamModel.project_id)
-            if can_edit:
-                query = query.filter(
-                    or_(ProjectModel.project_manager == user.username,
-                        and_(TeamModel.name == user.username, TeamModel.compilation_status == can_edit)
+        with db_helper.sqlserver_query_session(self._database) as db_session:
+            query = (db_session.query(
+                ProjectModel.project_id,
+                ProjectModel.project_name,
+                ProjectModel.project_manager,
+                ProjectModel.design_stage,
+                ProjectModel.project_description,
+                ProjectModel.short_name,
+                ProjectModel.project_version,
+                ProjectModel.project_type,
+                ProjectModel.unit,
+                ProjectModel.create_time,
+            ) .distinct())
+            user = UserSession.get_current_user()
+            if not user.is_admin:
+                query = query.outerjoin(TeamModel, ProjectModel.project_id == TeamModel.project_id)
+                if can_edit:
+                    query = query.filter(
+                        or_(ProjectModel.project_manager == user.username,
+                            and_(TeamModel.name == user.username, TeamModel.compilation_status == can_edit)
+                            )
                         )
-                    )
-            else:
-                query = query.filter(or_(ProjectModel.project_manager == user.username,TeamModel.name == user.username))
-        
-        # 添加编辑权限过滤
-        # if can_edit:
-        #     query = query.filter(or_(ProjectModel.project_manager == user.username,TeamModel.compilation_status == 1))
-        #
-        # 添加过滤条件
-        if keyword:
-            query = query.filter(or_(
-                ProjectModel.project_id.like(f'%{keyword}%'),
-                ProjectModel.project_name.like(f'%{keyword}%'),
-                ProjectModel.project_manager.like(f'%{keyword}%'),
-                ProjectModel.project_description.like(f'%{keyword}%'),
-                ProjectModel.short_name.like(f'%{keyword}%')
-            ))
-        
-        if start_time:
-            query = query.filter(ProjectModel.create_time >= start_time)
-            
-        if end_time:
-            query = query.filter(ProjectModel.create_time < end_time)
-            
-        # 获取总记录数
-        total_count = query.count()
-        
-        # 分页并按创建时间倒序排序
-        projects = query.order_by(ProjectModel.create_time.desc())\
-            .offset((page - 1) * page_size)\
-            .limit(page_size)\
-            .all()
+                else:
+                    query = query.filter(or_(ProjectModel.project_manager == user.username,TeamModel.name == user.username))
+
+            # 添加过滤条件
+            if keyword:
+                query = query.filter(or_(
+                    ProjectModel.project_id.like(f'%{keyword}%'),
+                    ProjectModel.project_name.like(f'%{keyword}%'),
+                    ProjectModel.project_manager.like(f'%{keyword}%'),
+                    ProjectModel.project_description.like(f'%{keyword}%'),
+                    ProjectModel.short_name.like(f'%{keyword}%')
+                ))
+
+            if start_time:
+                query = query.filter(ProjectModel.create_time >= start_time)
+
+            if end_time:
+                query = query.filter(ProjectModel.create_time < end_time)
+
+            # 获取总记录数和数据
+            total_count = query.count()
+            projects = query.order_by(ProjectModel.create_time.desc())\
+                .offset((page - 1) * page_size)\
+                .limit(page_size)\
+                .all()
+
+            return {
+                'total': total_count,
+                'data': projects
+            }
 
-        return {
-            'total': total_count,
-            'data': projects
-        }
 
     def get_team_project_item_code(self,project_id:str, user_name:str):
-        data = self.db_session.query(TeamModel.item_code).filter(and_(TeamModel.project_id == project_id,TeamModel.name == user_name)).first()
-        return data[0] if data else None
+        with db_helper.sqlserver_query_session(self._database) as session:
+            db_session = session
+            data = db_session.query(TeamModel.item_code).filter(and_(TeamModel.project_id == project_id,TeamModel.name == user_name)).first()
+            return data[0] if data else None
 
 
     def get(self,project_id:str):
-        data = self.db_session.query(ProjectModel).filter(ProjectModel.project_id == project_id).first()
-        return ProjectDto.from_model(data).to_dict()
+        with db_helper.sqlserver_query_session(self._database) as session:
+            db_session = session
+            data = db_session.query(ProjectModel).filter(ProjectModel.project_id == project_id).first()
+            return ProjectDto.from_model(data).to_dict()
 

+ 95 - 97
SourceCode/IntelligentRailwayCosting/app/stores/project_quota.py

@@ -1,7 +1,6 @@
 from sqlalchemy import and_, or_
-from sqlalchemy.orm import Session
 from datetime import datetime
-from typing import Optional, List, Tuple
+from typing import Optional
 
 import tools.db_helper as db_helper
 from core.dtos import ProjectQuotaDto
@@ -9,8 +8,8 @@ from core.models import ProjectQuotaModel
 from core.user_session import UserSession
 
 class ProjectQuotaStore:
-    def __init__(self, db_session: Session = None):
-        self.db_session = db_helper.create_mysql_session()
+    def __init__(self):
+        self._database = None
         self._current_user = None
 
     @property
@@ -30,8 +29,6 @@ class ProjectQuotaStore:
         process_status: Optional[int] = None,
         send_status: Optional[int] = None,
     ):
-        # 每次查询创建新的会话
-        session = db_helper.create_mysql_session()
         """分页查询定额列表
 
         Args:
@@ -47,8 +44,8 @@ class ProjectQuotaStore:
         Returns:
             Tuple[total_count, quotas]
         """
-        try:
-            query = session.query(ProjectQuotaModel)
+        with db_helper.mysql_query_session(self._database) as db_session:
+            query = db_session.query(ProjectQuotaModel)
 
             # 构建查询条件
             conditions = [
@@ -57,16 +54,16 @@ class ProjectQuotaStore:
                 ProjectQuotaModel.budget_id == budget_id,
                 ProjectQuotaModel.item_code.like(f"{item_code}%")
             ]
+
+            if process_status is not None:
+                conditions.append(ProjectQuotaModel.process_status == process_status)
+            if send_status is not None:
+                conditions.append(ProjectQuotaModel.send_status == send_status)
             if keyword:
                 conditions.append(or_(
                     ProjectQuotaModel.quota_code.like(f"%{keyword}%"),
                     ProjectQuotaModel.project_name.like(f"%{keyword}%"),
                 ))
-            if process_status is not None:
-                conditions.append(ProjectQuotaModel.process_status == process_status)
-            if send_status is not None:
-                conditions.append(ProjectQuotaModel.send_status == send_status)
-
             query = query.filter(and_(*conditions))
 
             # 计算总数
@@ -81,20 +78,20 @@ class ProjectQuotaStore:
                 'total': total_count,
                 'data': quotas
             }
-        finally:
-            session.close()
+
 
     def get_quotas_by_task_id(self,task_id:int, with_quota_code:bool=False):
-        query = self.db_session.query(ProjectQuotaModel).filter(
-            and_(
-                ProjectQuotaModel.task_id == task_id,
-                ProjectQuotaModel.is_del == 0
+        with db_helper.mysql_query_session(self._database) as db_session:
+            query = db_session.query(ProjectQuotaModel).filter(
+                and_(
+                    ProjectQuotaModel.task_id == task_id,
+                    ProjectQuotaModel.is_del == 0
+                )
             )
-        )
-        if with_quota_code:
-            query = query.filter(and_(ProjectQuotaModel.quota_code!=None,ProjectQuotaModel.quota_code!='') )
-        quotas = query.all()
-        return quotas
+            if with_quota_code:
+                query = query.filter(and_(ProjectQuotaModel.quota_code!=None,ProjectQuotaModel.quota_code!='') )
+            quotas = query.all()
+            return quotas
 
 
 
@@ -107,13 +104,14 @@ class ProjectQuotaStore:
         Returns:
             Optional[ProjectQuotaDto]
         """
-        quota = self.db_session.query(ProjectQuotaModel).filter(
-            and_(
-                ProjectQuotaModel.id == quota_id,
-                ProjectQuotaModel.is_del == 0
-            )
-        ).first()
-        return quota
+        with db_helper.mysql_query_session(self._database) as db_session:
+            quota = db_session.query(ProjectQuotaModel).filter(
+                and_(
+                    ProjectQuotaModel.id == quota_id,
+                    ProjectQuotaModel.is_del == 0
+                )
+            ).first()
+            return quota
     def get_quota_dto(self, quota_id: int) -> Optional[ProjectQuotaDto]:
         """根据ID查询定额
 
@@ -136,36 +134,36 @@ class ProjectQuotaStore:
         Returns:
             ProjectQuotaDto
         """
-        quota = ProjectQuotaModel(
-            project_id=quota_dto.project_id,
-            budget_id=quota_dto.budget_id,
-            item_id=quota_dto.item_id,
-            item_code=quota_dto.item_code,
-            quota_code=quota_dto.quota_code,
-            project_name=quota_dto.project_name,
-            unit=quota_dto.unit,
-            project_quantity=quota_dto.project_quantity,
-            project_quantity_input=quota_dto.project_quantity_input,
-            quota_adjustment=quota_dto.quota_adjustment,
-            unit_price=quota_dto.unit_price,
-            total_price=quota_dto.total_price,
-            unit_weight=quota_dto.unit_weight,
-            total_weight=quota_dto.total_weight,
-            labor_cost=quota_dto.labor_cost,
-            process_status=quota_dto.process_status,
-            process_time=quota_dto.process_time,
-            process_error=quota_dto.process_error,
-            send_status=quota_dto.send_status,
-            send_time=quota_dto.send_time,
-            send_error=quota_dto.send_error,
-            created_by=self.current_user.username,
-            created_at=datetime.now(),
-        )
-
-        self.db_session.add(quota)
-        self.db_session.commit()
-
-        return ProjectQuotaDto.from_model(quota)
+        with db_helper.mysql_session(self._database) as db_session:
+            quota = ProjectQuotaModel(
+                project_id=quota_dto.project_id,
+                budget_id=quota_dto.budget_id,
+                item_id=quota_dto.item_id,
+                item_code=quota_dto.item_code,
+                quota_code=quota_dto.quota_code,
+                project_name=quota_dto.project_name,
+                unit=quota_dto.unit,
+                project_quantity=quota_dto.project_quantity,
+                project_quantity_input=quota_dto.project_quantity_input,
+                quota_adjustment=quota_dto.quota_adjustment,
+                unit_price=quota_dto.unit_price,
+                total_price=quota_dto.total_price,
+                unit_weight=quota_dto.unit_weight,
+                total_weight=quota_dto.total_weight,
+                labor_cost=quota_dto.labor_cost,
+                process_status=quota_dto.process_status,
+                process_time=quota_dto.process_time,
+                process_error=quota_dto.process_error,
+                send_status=quota_dto.send_status,
+                send_time=quota_dto.send_time,
+                send_error=quota_dto.send_error,
+                created_by=self.current_user.username,
+                created_at=datetime.now(),
+            )
+
+            db_session.add(quota)
+            db_session.flush()
+            return ProjectQuotaDto.from_model(quota)
 
     def update_quota(self, quota_dto: ProjectQuotaDto) -> Optional[ProjectQuotaDto]:
         """更新定额
@@ -180,23 +178,23 @@ class ProjectQuotaStore:
 
         if not quota:
             return None
-        quota.quota_code = quota_dto.quota_code
-        quota.project_name = quota_dto.project_name
-        quota.unit = quota_dto.unit
-        quota.project_quantity = quota_dto.project_quantity
-        quota.project_quantity_input = quota_dto.project_quantity_input
-        quota.quota_adjustment = quota_dto.quota_adjustment
-        quota.unit_price = quota_dto.unit_price
-        quota.total_price = quota_dto.total_price
-        quota.unit_weight = quota_dto.unit_weight
-        quota.total_weight = quota_dto.total_weight
-        quota.labor_cost = quota_dto.labor_cost
-        quota.updated_by = self.current_user.username
-        quota.updated_at = datetime.now()
-
-        self.db_session.commit()
-
-        return ProjectQuotaDto.from_model(quota)
+        with db_helper.mysql_session(self._database) as db_session:
+            quota.quota_code = quota_dto.quota_code
+            quota.project_name = quota_dto.project_name
+            quota.unit = quota_dto.unit
+            quota.project_quantity = quota_dto.project_quantity
+            quota.project_quantity_input = quota_dto.project_quantity_input
+            quota.quota_adjustment = quota_dto.quota_adjustment
+            quota.unit_price = quota_dto.unit_price
+            quota.total_price = quota_dto.total_price
+            quota.unit_weight = quota_dto.unit_weight
+            quota.total_weight = quota_dto.total_weight
+            quota.labor_cost = quota_dto.labor_cost
+            quota.updated_by = self.current_user.username
+            quota.updated_at = datetime.now()
+
+            quota = db_session.merge(quota)
+            return ProjectQuotaDto.from_model(quota)
 
     def delete_quota(self, quota_id: int) -> bool:
         """删除定额
@@ -207,21 +205,17 @@ class ProjectQuotaStore:
         Returns:
             bool
         """
-        quota = self.db_session.query(ProjectQuotaModel).filter(
-            and_(
-                ProjectQuotaModel.id == quota_id,
-                ProjectQuotaModel.is_del == 0
-            )
-        ).first()
 
+        quota = self.get_quota(quota_id)
         if not quota:
             return False
 
-        quota.is_del = 1
-        quota.deleted_by = self.current_user.username
-        quota.deleted_at = datetime.now()
-        self.db_session.commit()
-        return True
+        with db_helper.mysql_session(self._database) as db_session:
+            quota.is_del = 1
+            quota.deleted_by = self.current_user.username
+            quota.deleted_at = datetime.now()
+            quota = db_session.merge(quota)
+            return True
 
     def update_process_status(self,quota_id:int, status:int, err:str = None):
         """
@@ -236,10 +230,12 @@ class ProjectQuotaStore:
         quota = self.get_quota(quota_id)
         if not quota:
             return False
-        quota.process_status = status
-        quota.process_error = err
-        quota.process_time = datetime.now()
-        self.db_session.commit()
+        with db_helper.mysql_session(self._database) as db_session:
+            quota.process_status = status
+            quota.process_error = err
+            quota.process_time = datetime.now()
+            quota = db_session.merge(quota)
+            return True
 
     def update_send_status(self,quota_id:int, status:int, err:str = None) -> bool:
         """
@@ -254,8 +250,10 @@ class ProjectQuotaStore:
         quota = self.get_quota(quota_id)
         if not quota:
             return False
-        quota.send_status = status
-        quota.send_error = err
-        quota.send_time = datetime.now()
-        self.db_session.commit()
+        with db_helper.mysql_session(self._database) as db_session:
+            quota.send_status = status
+            quota.send_error = err
+            quota.send_time = datetime.now()
+            quota = db_session.merge(quota)
+            return True
 

+ 67 - 65
SourceCode/IntelligentRailwayCosting/app/stores/project_task.py

@@ -1,5 +1,4 @@
 from sqlalchemy import and_
-from sqlalchemy.orm import Session
 from datetime import datetime
 from typing import Optional
 
@@ -10,9 +9,9 @@ from core.user_session import UserSession
 
 
 class ProjectTaskStore:
-    def __init__(self, db_session: Session = None):
-        self.db_session = db_session or db_helper.create_mysql_session()
+    def __init__(self):
         self._current_user = None
+        self._database=None
 
     @property
     def current_user(self):
@@ -48,8 +47,8 @@ class ProjectTaskStore:
         Returns:
 
         """
-        try:
-            query = self.db_session.query(ProjectTaskModel)
+        with db_helper.mysql_query_session(self._database) as db_session:
+            query = db_session.query(ProjectTaskModel)
 
             # 构建查询条件
             conditions = [
@@ -81,19 +80,15 @@ class ProjectTaskStore:
                 'total': total_count,
                 'data': tasks
             }
-        except Exception as e:
-            self.db_session.rollback()
-            raise e
-        finally:
-            self.db_session.close()
 
     def get_task(self, task_id: int) -> Optional[ProjectTaskModel]:
-        task = self.db_session.query(ProjectTaskModel).filter(
-            and_(
-                ProjectTaskModel.id == task_id,
-                ProjectTaskModel.is_del == 0
-            )).first()
-        return task
+        with db_helper.mysql_query_session(self._database) as db_session:
+            task = db_session.query(ProjectTaskModel).filter(
+                and_(
+                    ProjectTaskModel.id == task_id,
+                    ProjectTaskModel.is_del == 0
+                )).first()
+            return task
 
 
     def get_task_dto(self, task_id: int) -> Optional[ProjectTaskDto]:
@@ -129,10 +124,10 @@ class ProjectTaskStore:
             created_at=datetime.now(),
         )
 
-        self.db_session.add(task)
-        self.db_session.commit()
-
-        return ProjectTaskDto.from_model(task)
+        with db_helper.mysql_session(self._database) as db_session:
+            db_session.add(task)
+            db_session.flush()
+            return ProjectTaskDto.from_model(task)
 
     def update_task(self, task_dto: ProjectTaskDto) -> Optional[ProjectTaskDto]:
         """更新任务
@@ -147,33 +142,37 @@ class ProjectTaskStore:
 
         if not task:
             return None
-
-        task.task_name = task_dto.task_name
-        task.task_desc = task_dto.task_desc
-        # task.project_id = task_dto.project_id
-        # task.budget_id = task_dto.budget_id
-        # task.item_id = task_dto.item_id
-        # task.item_code = task_dto.item_code
-        # task.file_path = task_dto.file_path
-        task.updated_by=self.current_user.username
-        task.updated_at=datetime.now()
-
-        self.db_session.commit()
-
-        return ProjectTaskDto.from_model(task)
+        with db_helper.mysql_session(self._database) as db_session:
+            task.task_name = task_dto.task_name
+            task.task_desc = task_dto.task_desc
+            # task.project_id = task_dto.project_id
+            # task.budget_id = task_dto.budget_id
+            # task.item_id = task_dto.item_id
+            # task.item_code = task_dto.item_code
+            # task.file_path = task_dto.file_path
+            task.updated_by=self.current_user.username
+            task.updated_at=datetime.now()
+            task = db_session.merge(task)
+            return ProjectTaskDto.from_model(task)
 
 
     def update_task_files(self, task_id: int,files: str):
         task = self.get_task(task_id)
         if not task:
             return None
-        task.file_path = files
-        task.collect_status=0
-        task.process_status=0
-        task.send_status=0
-        task.updated_by=self.current_user.username
-        task.updated_at=datetime.now()
-        self.db_session.commit()
+        with db_helper.mysql_session(self._database) as db_session:
+            task.file_path = files
+            if task.collect_status != 0:
+                task.collect_status = 4
+            if task.process_status != 0:
+                task.process_status = 4
+            if task.send_status != 0:
+                task.send_status = 4
+            task.updated_by=self.current_user.username
+            task.updated_at=datetime.now()
+            task = db_session.merge(task)
+            return ProjectTaskDto.from_model(task)
+
     def delete_task(self, task_id: int) -> bool:
         """删除任务
 
@@ -187,44 +186,47 @@ class ProjectTaskStore:
         if not task:
             return False
 
-        task.is_del = 1
-        task.deleted_by = self.current_user.username
-        task.deleted_at = datetime.now()
-
-        self.db_session.commit()
 
-        return True
+        with db_helper.mysql_session(self._database) as db_session:
+            task.is_del = 1
+            task.deleted_by = self.current_user.username
+            task.deleted_at = datetime.now()
+            task = db_session.merge(task)
+            return True
 
     def update_collect_status(self,task_id:int, status:int, err:str = None):
         task = self.get_task(task_id)
         if not task:
             return False
-        task.collect_status = status
-        if err:
-            task.collect_error = err
-        task.collect_time = datetime.now()
-        self.db_session.commit()
-        return True
+        with db_helper.mysql_session(self._database) as db_session:
+            task.collect_status = status
+            if err:
+                task.collect_error = err
+            task.collect_time = datetime.now()
+            task = db_session.merge(task)
+            return True
 
     def update_process_status(self,task_id:int, status:int, err:str = None):
         task = self.get_task(task_id)
         if not task:
             return False
-        task.process_status = status
-        if err:
-            task.process_error = err
-        task.process_time = datetime.now()
-        self.db_session.commit()
-        return True
+        with db_helper.mysql_session(self._database) as db_session:
+            task.process_status = status
+            if err:
+                task.process_error = err
+            task.process_time = datetime.now()
+            task = db_session.merge(task)
+            return True
 
     def update_send_status(self,task_id:int, status:int, err:str = None):
         task = self.get_task(task_id)
         if not task:
             return False
-        task.send_status = status
-        if err:
-            task.send_error = err
-        task.send_time = datetime.now()
-        self.db_session.commit()
-        return True
+        with db_helper.mysql_session(self._database) as db_session:
+            task.send_status = status
+            if err:
+                task.send_error = err
+            task.send_time = datetime.now()
+            task = db_session.merge(task)
+            return True
 

+ 10 - 8
SourceCode/IntelligentRailwayCosting/app/stores/user.py

@@ -1,26 +1,28 @@
-from sqlalchemy.orm import Session
 from typing import Optional, List
 
 import tools.db_helper  as db_helper
 from core.models import UserModel
 
 class UserStore:
-    def __init__(self, db_session: Session = None):
-        self.db_session = db_session or db_helper.create_sqlServer_session()
-
+    def __init__(self):
+        self._database = None
 
     def get_user_by_id(self, user_id: int) -> Optional[UserModel]:
         """根据用户ID获取用户信息"""
-        return self.db_session.query(UserModel).filter(UserModel.id == user_id).first()
+        with db_helper.sqlserver_query_session(self._database) as db_session:
+            user = db_session.query(UserModel).filter(UserModel.id == user_id).first()
+            return user
 
     def get_user_by_username(self, username: str) -> Optional[UserModel]:
         """根据用户名获取用户信息"""
-        return self.db_session.query(UserModel).filter(UserModel.username == username).first()
+        with db_helper.sqlserver_query_session(self._database) as db_session:
+            user = db_session.query(UserModel).filter(UserModel.username == username).first()
+            return user
 
     def get_all_users(self) -> List[UserModel]:
         """获取所有用户列表"""
-        return self.db_session.query(UserModel)
-
+        with db_helper.sqlserver_query_session(self._database) as db_session:
+            return db_session.query(UserModel)
 
 
     def authenticate_user(self, username: str, password: str) -> Optional[UserModel]:

+ 96 - 7
SourceCode/IntelligentRailwayCosting/app/tools/db_helper/__init__.py

@@ -1,18 +1,107 @@
 from sqlalchemy.orm import Session
+from contextlib import contextmanager
+from typing import Generator, Optional, Dict, Any
 
 from .mysql_helper import MySQLHelper
 from .sqlserver_helper import SQLServerHelper
 
+# def get_sqlServer_main_db():
+#     return SQLServerHelper().main_database_name
 
-def create_sqlServer_session(database:str=None)->Session:
-    return SQLServerHelper().get_session_maker(database)()
+# def get_mysql_main_db():
+#     return MySQLHelper().main_database_name
 
-def create_mysql_session(database:str=None)->Session:
-    return MySQLHelper().get_session_maker(database)()
+# def create_sqlServer_session(database:str=None)->Session:
+#     return SQLServerHelper().get_session_maker(database)()
+
+# def create_mysql_session(database:str=None)->Session:
+#     return MySQLHelper().get_session_maker(database)()
+
+@contextmanager
+def sqlserver_session(database: str=None, config: Optional[Dict[str, Any]]=None):
+    """SQLServer数据库会话的上下文管理器
+    
+    Args:
+        database: 数据库名称
+        config: 数据库配置信息
+        
+    Yields:
+        数据库会话
+    """
+    session = SQLServerHelper().get_session_maker(database, config)()
+    try:
+        yield session
+        session.commit()
+    except:
+        session.rollback()
+        raise
+    finally:
+        session.close()
+
+@contextmanager
+def mysql_session(database: str=None, config: Optional[Dict[str, Any]]=None):
+    """MySQL数据库会话的上下文管理器
+    
+    Args:
+        database: 数据库名称
+        config: 数据库配置信息
+        
+    Yields:
+        数据库会话
+    """
+    session = MySQLHelper().get_session_maker(database, config)()
+    try:
+        yield session
+        session.commit()
+    except:
+        session.rollback()
+        raise
+    finally:
+        session.close()
+
+@contextmanager
+def sqlserver_query_session(database: str=None, config: Optional[Dict[str, Any]]=None):
+    """SQLServer数据库会话的上下文管理器(只读查询专用)
+    
+    Args:
+        database: 数据库名称
+        config: 数据库配置信息
+        
+    Yields:
+        数据库会话
+    """
+    session = SQLServerHelper().get_session_maker(database, config)()
+    try:
+        yield session
+    finally:
+        session.close()
+
+@contextmanager
+def mysql_query_session(database: str=None, config: Optional[Dict[str, Any]]=None):
+    """MySQL数据库会话的上下文管理器(只读查询专用)
+    
+    Args:
+        database: 数据库名称
+        config: 数据库配置信息
+        
+    Yields:
+        数据库会话
+    """
+    session = MySQLHelper().get_session_maker(database, config)()
+    try:
+        yield session
+    finally:
+        session.close()
 
 __all__ = [
     'MySQLHelper',
     'SQLServerHelper',
-    'create_sqlServer_session',
-    'create_mysql_session'
-]
+    # 'create_sqlServer_session',
+    # 'create_mysql_session',
+    'sqlserver_session',
+    'mysql_session',
+    'sqlserver_query_session',
+    'mysql_query_session'
+]
+
+

+ 3 - 3
SourceCode/IntelligentRailwayCosting/app/tools/db_helper/base.py

@@ -10,7 +10,7 @@ Base = declarative_base()
 class DBHelper:
     _instance = None
     _lock = threading.Lock()
-    _main_database_name = ""
+    main_database_name = ""
     def __new__(cls, *args, **kwargs):
         with cls._lock:
             if cls._instance is None:
@@ -63,7 +63,7 @@ class DBHelper:
             self._config_cache[database] = db_config
             return db_config
         
-        main_config = configs.database[self._main_database_name]
+        main_config = configs.database[self.main_database_name]
         if not main_config:
             raise Exception(f"未找到数据库 {database} 的配置,且main_config配置不存在")
         main_config['database'] = database
@@ -145,7 +145,7 @@ class DBHelper:
         Returns:
             会话工厂实例
         """
-        database = database or self._main_database_name
+        database = database or self.main_database_name
         if database in self._sessions:
             return self._sessions[database]
         

+ 1 - 1
SourceCode/IntelligentRailwayCosting/app/tools/db_helper/mysql_helper.py

@@ -22,7 +22,7 @@ class MySQLHelper(DBHelper):
             'pool_timeout': 30,
             'pool_recycle': 3600
         }
-        self._main_database_name = "mysql_main"
+        self.main_database_name = "mysql_main"
 
     def get_engine(self, database: str, config: Optional[Dict[str, Any]] = None) -> Engine:
         """获取或创建数据库引擎

+ 1 - 1
SourceCode/IntelligentRailwayCosting/app/tools/db_helper/sqlserver_helper.py

@@ -35,7 +35,7 @@ class SQLServerHelper(DBHelper):
             }
         }
 
-        self._main_database_name = f"sqlserver_mian_{configs.app.version}" if configs.app.user_version else "sqlserver_mian"
+        self.main_database_name = f"sqlserver_mian_{configs.app.version}" if configs.app.user_version else "sqlserver_mian"
     def _build_connection_string(self, database: str, config: Optional[Dict[str, str]] = None) -> str:
         """构建连接字符串"""
         conn_config = self._default_config.copy()

+ 23 - 3
SourceCode/IntelligentRailwayCosting/app/views/static/project/budget_info.js

@@ -51,6 +51,7 @@ const nav_tab_template = `
 															<option value="1">采集中</option>
 															<option value="2">已采集</option>
 															<option value="3">采集失败</option>
+															<option value="4">数据变更</option>
 														</select>
 														<select class="form-select form-select-sm me-5" name="process_status">
 															<option value="">全部处理状态</option>
@@ -58,6 +59,7 @@ const nav_tab_template = `
 															<option value="1">处理中</option>
 															<option value="2">已处理</option>
 															<option value="3">处理失败</option>
+															<option value="4">数据变更</option>
 														</select>
 														<select class="form-select form-select-sm me-5" name="send_status">
 															<option value="">全部发送状态</option>
@@ -65,6 +67,7 @@ const nav_tab_template = `
 															<option value="1">发送中</option>
 															<option value="2">已发送</option>
 															<option value="3">发送失败</option>
+															<option value="4">数据变更</option>
 														</select>
 														<input type="text" class="form-control form-control-sm w-200px" placeholder="请输入关键字" name="keyword" />
 													</div>
@@ -320,8 +323,9 @@ function RenderTabCondent(data) {
 								}
 							} else if (row.collect_status === 3){
 								str += `<span class="badge badge-light-danger">采集失败</span>`
+							} else if (row.collect_status === 4){
+								str += `<span class="badge badge-light-info">数据变更</span>`
 							}
-
 							return str
 						}
 					},
@@ -345,12 +349,18 @@ function RenderTabCondent(data) {
 										str += `<button type="button" class="btn btn-icon btn-sm btn-light-warning" data-bs-toggle="tooltip" data-bs-placement="top" title="重新发送" onclick="ReStartSendTask(${row.id}, ${data.budget_id})"><i class="ki-duotone ki-send fs-1"><span class="path1"></span><span class="path2"></span></i></button>`
 									} else if (row.send_status === 3) {
 										str += `<button type="button" class="btn btn-icon btn-sm btn-light-danger" data-bs-toggle="tooltip" data-bs-placement="top" title="重新发送" onclick="ReStartSendTask(${row.id}, ${data.budget_id})"><i class="ki-duotone ki-send fs-1"><span class="path1"></span><span class="path2"></span></i></button>`
+									} else if (row.send_status === 4) {
+										str += `<button type="button" class="btn btn-icon btn-sm btn-light-info" data-bs-toggle="tooltip" data-bs-placement="top" title="重新发送" onclick="ReStartSendTask(${row.id}, ${data.budget_id})"><i class="ki-duotone ki-send fs-1"><span class="path1"></span><span class="path2"></span></i></button>`
 									}
 								} else if (row.process_status === 3) {
 									str += `<button type="button" class="btn btn-icon btn-sm btn-light-danger" data-bs-toggle="tooltip" data-bs-placement="top" title="重新处理" onclick="ReStartProcessTask(${row.id}, ${data.budget_id})"><i class="ki-duotone ki-book-square fs-1"><span class="path1"></span><span class="path2"></span><span class="path3"></span></i></button>`
+								} else if (row.process_status === 4) {
+									str += `<button type="button" class="btn btn-icon btn-sm btn-light-info" data-bs-toggle="tooltip" data-bs-placement="top" title="重新处理" onclick="ReStartProcessTask(${row.id}, ${data.budget_id})"><i class="ki-duotone ki-book-square fs-1"><span class="path1"></span><span class="path2"></span><span class="path3"></span></i></button>`
 								}
 							} else if (row.collect_status === 3) {
 								str += `<button type="button" class="btn btn-icon btn-sm btn-light-danger" data-bs-toggle="tooltip" data-bs-placement="top" title="重新采集" onclick="ReStartCollectTask(${row.id}, ${data.budget_id})"><i class="ki-duotone ki-add-notepad fs-1"><span class="path1"></span><span class="path2"></span><span class="path3"></span><span class="path4"></span></i></button>`
+							}else if (row.collect_status === 4) {
+								str += `<button type="button" class="btn btn-icon btn-sm btn-light-info" data-bs-toggle="tooltip" data-bs-placement="top" title="重新采集" onclick="ReStartCollectTask(${row.id}, ${data.budget_id})"><i class="ki-duotone ki-add-notepad fs-1"><span class="path1"></span><span class="path2"></span><span class="path3"></span><span class="path4"></span></i></button>`
 							}
 							str+=`<button type="button" class="btn btn-icon btn-sm btn-light-primary" data-bs-toggle="tooltip" data-bs-placement="top" title="编辑" onclick="Edit(${row.id})"><i class="ki-duotone ki-message-edit fs-1"><span class="path1"></span><span class="path2"></span></i></button>`
 							str+=`<button type="button" class="btn btn-icon btn-sm btn-light-danger"  data-bs-toggle="tooltip" data-bs-placement="top" title="删除" onclick="Delete(${row.id})"><i class="ki-duotone ki-trash-square fs-1"><span class="path1"></span><span class="path2"></span><span class="path3"></span><span class="path4"></span></i></button>`
@@ -407,6 +417,8 @@ function RenderTabCondent(data) {
 							str+= `<span class="badge badge-success">已处理</span>`
 						}else if (row.process_status === 3){
 							str+= `<span class="badge badge-danger">处理失败</span>`
+						}else if (row.process_status === 4){
+							str+= `<span class="badge badge-danger">数据变更</span>`
 						}
 						if(row.send_status === 0){
 							str+= `<span class="badge badge-primary ms-3">未发送</span>`
@@ -416,6 +428,8 @@ function RenderTabCondent(data) {
 							str+= `<span class="badge badge-success ms-3">已发送</span>`
 						}else if (row.send_status === 3){
 							str+= `<span class="badge badge-danger ms-3">发送失败</span>`
+						}else if (row.send_status === 4){
+							str+= `<span class="badge badge-danger ms-3">数据变更</span>`
 						}
 
 						return str
@@ -437,9 +451,13 @@ function RenderTabCondent(data) {
 								str += `<button type="button" class="btn btn-icon btn-sm btn-light-warning" data-bs-toggle="tooltip" data-bs-placement="top" title="重新发送" onclick="ReStartSendQuota(${row.id}, ${data.budget_id})"><i class="ki-duotone ki-send fs-1"><span class="path1"></span><span class="path2"></span></i></button>`
 							} else if (row.send_status === 3) {
 								str += `<button type="button" class="btn btn-icon btn-sm btn-light-danger" data-bs-toggle="tooltip" data-bs-placement="top" title="重新发送" onclick="ReStartSendQuota(${row.id}, ${data.budget_id})"><i class="ki-duotone ki-send fs-1"><span class="path1"></span><span class="path2"></span></i></button>`
+							} else if (row.send_status === 4) {
+								str += `<button type="button" class="btn btn-icon btn-sm btn-light-info" data-bs-toggle="tooltip" data-bs-placement="top" title="重新发送" onclick="ReStartSendQuota(${row.id}, ${data.budget_id})"><i class="ki-duotone ki-send fs-1"><span class="path1"></span><span class="path2"></span></i></button>`
 							}
 						} else if (row.process_status === 3) {
 							str += `<button type="button" class="btn btn-icon btn-sm btn-light-danger" data-bs-toggle="tooltip" data-bs-placement="top" title="重新处理" onclick="ReStartProcessQuota(${row.id}, ${data.budget_id})"><i class="ki-duotone ki-book-square fs-1"><span class="path1"></span><span class="path2"></span><span class="path3"></span></i></button>`
+						} else if (row.process_status === 4) {
+							str += `<button type="button" class="btn btn-icon btn-sm btn-light-info" data-bs-toggle="tooltip" data-bs-placement="top" title="重新处理" onclick="ReStartProcessQuota(${row.id}, ${data.budget_id})"><i class="ki-duotone ki-book-square fs-1"><span class="path1"></span><span class="path2"></span><span class="path3"></span></i></button>`
 						}
 						str+=`<button type="button" class="btn btn-icon btn-sm btn-light-primary" data-bs-toggle="tooltip" data-bs-placement="top" title="编辑" onclick="Edit_Quota(${row.id})"><i class="ki-duotone ki-message-edit fs-1"><span class="path1"></span><span class="path2"></span></i></button>`
 						str+=`<button type="button" class="btn btn-icon btn-sm btn-light-danger"  data-bs-toggle="tooltip" data-bs-placement="top" title="删除" onclick="Delete_Quota(${row.id})"><i class="ki-duotone ki-trash-square fs-1"><span class="path1"></span><span class="path2"></span><span class="path3"></span><span class="path4"></span></i></button>`
@@ -457,12 +475,13 @@ function Add(budget_id) {
 	AddModal($modal, () => {
 		$modal.find('[name="task_id"]').val('0');
 		$modal.find('#delete_file_box').hide();
-
+		$modal.find('[name="delete_file"]').prop('checked',false)
 		SetBudgetData($modal,budget_id)
 	})
 }
 
 function Edit(id) {
+
 	_fileUploadDropzone.removeAllFiles()
     EditModal($modal,()=>{
         IwbAjax_1({
@@ -482,6 +501,7 @@ function Edit(id) {
 				$modal.find('[name="item_code"]').val(data.item_code);
                 $modal.find('[name="task_name"]').val(data.task_name);
                 $modal.find('[name="task_desc"]').val(data.task_desc);
+				$modal.find('[name="delete_file"]').prop('checked',false)
             }
         })
     })
@@ -504,7 +524,7 @@ function SaveProject() {
 		task_id=  $modal.find('[name="task_id"]').val(),
 		task_name = $modal.find('[name="task_name"]').val(),
 		task_desc = $modal.find('[name="task_desc"]').val(),
-		delete_file = $modal.find('[name="delete_file"]').checked ? 'true':'false',
+		delete_file = $modal.find('[name="delete_file"]').prop('checked')? 'true':'false',
 		files = _fileUploadDropzone.getAcceptedFiles();
 	// console.log("FILES",files)