from typing import Generic, TypeVar, Optional, List, Any, Dict, Sequence

from sqlalchemy import select, func, desc, asc
from sqlalchemy.ext.asyncio import AsyncSession

from core.current_user import CurrentUserProvider
from core.enums import DeleteTypeEnum
from core.exceptions import ServiceWarning
from domain.dtos import DtoBase, PageDto, PageResultDto
from utils import DBUtil, logger


class ServiceBase:
    """Shared service helpers: async DB session access and SELECT building.

    The ``_apply_*_base`` / ``_build_filter`` helpers take a model class and a
    SQLAlchemy ``select`` and return a new statement with keyword search,
    filters, the default soft-delete restriction, or ordering applied.
    """

    __abstract__ = True

    # Suffix operators understood by the ``filter_dict`` form of ``_build_filter``.
    # Each entry is (suffix, accepts(value), build(column, value)); the first suffix
    # that matches the key, accepts the value and names an existing column wins.
    _FILTER_SUFFIXES = (
        ("_like", lambda v: isinstance(v, str), lambda c, v: c.like(f"%{v}%")),
        (
            "_in",
            lambda v: isinstance(v, (list, tuple, set)),
            lambda c, v: c.in_(list(v)),
        ),
        ("_gte", lambda v: True, lambda c, v: c >= v),
        ("_lte", lambda v: True, lambda c, v: c <= v),
        ("_gt", lambda v: True, lambda c, v: c > v),
        ("_lt", lambda v: True, lambda c, v: c < v),
    )

    def __init__(self, db_name: Optional[str] = None):
        """
        :param db_name: name of the database passed to ``DBUtil``; ``None`` lets
            DBUtil pick its default (behavior defined by DBUtil, not shown here)
        """
        self._db_name = db_name
        self._current_user = CurrentUserProvider.get_current_user()
        self._current_user_name = ""
        self._logger = logger

    async def _get_async_db(self) -> AsyncSession:
        """Return an async database session for ``self._db_name`` from ``DBUtil``.

        Errors are logged and re-raised unchanged.

        NOTE(review): the session is returned from inside DBUtil's ``async with``,
        so DBUtil's exit logic has already run when callers re-enter it with
        ``async with``. Confirm DBUtil leaves the returned session usable.
        """
        try:
            async with DBUtil()(db_name=self._db_name) as conn:
                return conn
        except Exception as e:
            self._logger.error(f"获取数据库连接失败: {str(e)}")
            raise

    @staticmethod
    def _apply_search_base(model_class, query, search: str):
        """Restrict ``query`` to rows whose ``name`` contains ``search``.

        :param model_class: ORM model class; ignored unless it has a ``name`` attribute
        :param query: select statement to extend
        :param search: keyword; empty/None leaves the query unchanged
        :return: the (possibly) filtered select statement
        """
        if search and hasattr(model_class, "name"):
            query = query.where(model_class.name.contains(search))
        return query

    def _apply_filter_base(
        self,
        model_class,
        query,
        filters: Optional[Dict[str, Any]] = None,
        filter_conditions: Optional[List[Dict[str, Any]]] = None,
    ):
        """Apply caller filters plus the default soft-delete restriction.

        :param model_class: ORM model class the filters refer to
        :param query: select statement to extend
        :param filters: ``{key: value}`` filters (see ``_build_filter``)
        :param filter_conditions: ``[{"field", "operator", "value"}]`` filters
        :return: the filtered select statement. Unless a filter targets ``is_del``
            explicitly, rows are restricted to ``is_del == DeleteTypeEnum.NORMAL.value``
            when the model has an ``is_del`` attribute.
        """
        apply_soft_delete = True
        if filters or filter_conditions:
            for condition in self._build_filter(model_class, filters, filter_conditions):
                query = query.where(condition)
                # An explicit is_del filter overrides the default "not deleted"
                # restriction; getattr keeps this safe for expressions without
                # a ``left`` operand.
                if getattr(getattr(condition, "left", None), "key", None) == "is_del":
                    apply_soft_delete = False
        if apply_soft_delete and hasattr(model_class, "is_del"):
            query = query.where(model_class.is_del == DeleteTypeEnum.NORMAL.value)
        return query

    @staticmethod
    def _apply_order_base(model_class, query, order: str):
        """Apply ``ORDER BY`` clauses parsed from ``order``.

        :param model_class: ORM model class the columns belong to
        :param query: select statement to extend
        :param order: comma-separated ``"column [asc|desc]"`` items, e.g.
            ``"sort desc, id"``. A missing direction means ascending. Unknown
            columns and malformed items are skipped. When empty, the model's
            ``__default_order__`` (``OrderModelBase`` subclasses) or ``"id desc"``
            is used.
        :return: the ordered select statement
        """
        if not order:
            from domain.models.base_model import OrderModelBase

            if issubclass(model_class, OrderModelBase):
                # Use the default ordering declared by OrderModelBase subclasses.
                order = getattr(model_class, "__default_order__", "id desc")
            else:
                order = "id desc"
        for order_item in order.split(","):
            # split() with no argument tolerates repeated/leading whitespace.
            parts = order_item.split()
            if not parts or len(parts) > 2:
                continue
            order_by = parts[0]
            direction = parts[1].lower() if len(parts) == 2 else "asc"
            if not hasattr(model_class, order_by):
                continue
            order_column = getattr(model_class, order_by)
            if direction == "desc":
                query = query.order_by(desc(order_column))
            else:
                query = query.order_by(asc(order_column))
        return query

    @staticmethod
    def _build_filter(
        model_class,
        filter_dict: Optional[Dict[str, Any]] = None,
        filter_conditions: Optional[List[Dict[str, Any]]] = None,
    ) -> List:
        """Build a list of SQLAlchemy filter expressions.

        Args:
            model_class: ORM model class.
            filter_dict: ``{key: value}``. ``key`` is a column name (equality) or a
                column name with one of the suffixes ``_like`` (str value),
                ``_in`` (list/tuple/set value), ``_gt``, ``_lt``, ``_gte``, ``_lte``.
                ``start_time`` / ``end_time`` bound ``created_at`` when the model
                has it. ``None`` values and unknown keys are ignored.
            filter_conditions: ``[{"field": name, "operator": op, "value": v}]``
                with ``op`` one of eq/==, ne/!=, gt/>, lt/<, gte/>=, lte/<=, like,
                in, not_in, between (default ``eq``). Entries with a missing field,
                unknown field, ``None`` value or unsupported operator are ignored.

        Returns:
            List of filter expressions, dict filters first.
        """
        filters_list = []
        for key, value in (filter_dict or {}).items():
            if value is None:
                continue
            expression = ServiceBase._build_dict_filter(model_class, key, value)
            if expression is not None:
                filters_list.append(expression)
        for condition in filter_conditions or []:
            expression = ServiceBase._build_condition_filter(model_class, condition)
            if expression is not None:
                filters_list.append(expression)
        return filters_list

    @staticmethod
    def _build_dict_filter(model_class, key: str, value: Any):
        """Translate one ``filter_dict`` entry into an expression, or ``None`` to skip it."""
        for suffix, accepts, build in ServiceBase._FILTER_SUFFIXES:
            if key.endswith(suffix) and accepts(value):
                base_key = key[: -len(suffix)]
                if base_key and hasattr(model_class, base_key):
                    return build(getattr(model_class, base_key), value)
        if key == "start_time" and hasattr(model_class, "created_at"):
            return model_class.created_at >= value
        if key == "end_time" and hasattr(model_class, "created_at"):
            return model_class.created_at <= value
        if hasattr(model_class, key):
            return getattr(model_class, key) == value
        return None

    @staticmethod
    def _build_condition_filter(model_class, condition: Dict[str, Any]):
        """Translate one ``filter_conditions`` entry into an expression, or ``None``."""
        field = condition.get("field")
        operator = condition.get("operator", "eq")
        value = condition.get("value")
        if value is None or not field or not hasattr(model_class, field):
            return None
        column = getattr(model_class, field)
        is_sequence = isinstance(value, (list, tuple))
        if operator in ("eq", "=="):
            return column == value
        if operator in ("ne", "!="):
            return column != value
        if operator in ("gt", ">"):
            return column > value
        if operator in ("lt", "<"):
            return column < value
        if operator in ("gte", ">="):
            return column >= value
        if operator in ("lte", "<="):
            return column <= value
        if operator == "like":
            return column.like(f"%{value}%")
        if operator == "in" and is_sequence:
            return column.in_(list(value))
        if operator == "not_in" and is_sequence:
            return ~column.in_(list(value))
        if operator == "between" and is_sequence and len(value) == 2:
            return column.between(value[0], value[1])
        return None


# Generic type variables
ModelType = TypeVar("ModelType")
GetSchemaType = TypeVar("GetSchemaType", bound=DtoBase)
CreateSchemaType = TypeVar("CreateSchemaType", bound=DtoBase)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=DtoBase)


class CurdServiceBase(
    Generic[ModelType, GetSchemaType, CreateSchemaType, UpdateSchemaType], ServiceBase
):
    """Generic CRUD service for one ORM model and its DTO classes.

    Failures in the public CRUD methods are re-raised as ``ServiceWarning``.
    """

    __abstract__ = True

    def __init__(
        self,
        model_class,
        get_schema_class,
        create_schema_class,
        update_schema_class,
        db_name: Optional[str] = None,
    ):
        """
        :param model_class: ORM model class managed by this service
        :param get_schema_class: DTO class used for reads (must provide ``from_model``)
        :param create_schema_class: DTO class accepted by ``create``
        :param update_schema_class: DTO class accepted by ``update``
        :param db_name: database name forwarded to ``ServiceBase``
        """
        super().__init__(db_name)
        self._model_class = model_class
        self._get_schema_class = get_schema_class
        self._create_schema_class = create_schema_class
        self._update_schema_class = update_schema_class

    async def get_page_list(
        self,
        page_dto: PageDto,
    ) -> PageResultDto[GetSchemaType]:
        """Return one page of entities plus the total matching count.

        :param page_dto: search keyword, filters, order and paging parameters
        :return: page result with ``total``, DTO ``rows`` and ``page_size``
        :raises ServiceWarning: on any failure
        """
        try:
            async with await self._get_async_db() as db:
                query = self._apply_query()
                query = self._apply_search(query, page_dto.search)
                query = self._apply_filter(
                    query, page_dto.filters, page_dto.filter_conditions
                )
                # Count over the filtered statement as an explicit subquery
                # (implicit SELECT-to-FROM coercion is deprecated in SQLAlchemy).
                total = await db.scalar(
                    select(func.count()).select_from(query.subquery())
                )
                query = self._apply_order(query, page_dto.order)
                query = query.offset(page_dto.offset).limit(page_dto.limit)
                result = await db.execute(query)
                models = result.scalars().all()
                return PageResultDto(
                    total=total or 0,
                    rows=self._apply_map_get_dto_list(models),
                    page_size=page_dto.page_size,
                )
        except Exception as e:
            raise ServiceWarning(f"获取实体列表失败: {str(e)}") from e

    def _apply_query(self):
        """Return the base select statement for the page query."""
        return select(self._model_class)

    def _apply_search(self, query, search: str):
        """Apply the keyword search for this service's model (see ``_apply_search_base``)."""
        return self._apply_search_base(self._model_class, query, search)

    def _apply_filter(
        self,
        query,
        filters: Optional[Dict[str, Any]] = None,
        filter_conditions: Optional[List[Dict[str, Any]]] = None,
    ):
        """Apply filters and soft-delete for this service's model (see ``_apply_filter_base``)."""
        return self._apply_filter_base(
            self._model_class, query, filters, filter_conditions
        )

    def _apply_order(self, query, order: str):
        """Apply ordering for this service's model (see ``_apply_order_base``)."""
        return self._apply_order_base(self._model_class, query, order)

    def _build_get_query(self, id: int, include_deleted: bool = False):
        """Build the select statement that loads one entity by primary key ``id``."""
        query = select(self._model_class).where(self._model_class.id == id)  # type: ignore
        if not include_deleted and hasattr(self._model_class, "is_del"):
            # NOTE(review): _apply_filter_base compares is_del with
            # DeleteTypeEnum.NORMAL.value while this uses .key; confirm which
            # accessor yields the stored column value.
            query = query.where(self._model_class.is_del == DeleteTypeEnum.NORMAL.key)
        return query

    async def _get_in_session(
        self, db: AsyncSession, id: int, include_deleted: bool = False
    ) -> ModelType | None:
        """Load one entity through ``db`` so later changes are tracked by that session."""
        result = await db.execute(self._build_get_query(id, include_deleted))
        return result.scalar_one_or_none()

    async def get(self, id: int, include_deleted: bool = False) -> ModelType | None:
        """Return the entity with primary key ``id``, or ``None`` if not found.

        :param id: primary key
        :param include_deleted: when False, soft-deleted rows are excluded
            (only for models with an ``is_del`` attribute)
        :raises ServiceWarning: on any failure
        """
        try:
            async with await self._get_async_db() as db:
                return await self._get_in_session(db, id, include_deleted)
        except Exception as e:
            raise ServiceWarning(f"获取实体失败: {str(e)}") from e

    async def get_dto(
        self, id: int, include_deleted: bool = False
    ) -> GetSchemaType | None:
        """Return the entity with primary key ``id`` as a DTO, or ``None``.

        :param id: primary key
        :param include_deleted: whether soft-deleted entities are included
        :return: the entity DTO, or ``None`` when not found
        :raises ServiceWarning: on load or mapping failure
        """
        model = await self.get(id, include_deleted)
        try:
            if model:
                return self._apply_map_get_dto(model)
            return None
        except Exception as e:
            raise ServiceWarning(f"获取实体DTO失败: {str(e)}") from e

    async def create(self, obj_in: CreateSchemaType) -> GetSchemaType:
        """Insert a new entity built from ``obj_in`` and return it as a DTO.

        :param obj_in: creation data
        :return: DTO of the created entity, refreshed after commit
        :raises ServiceWarning: on any failure
        """
        try:
            db_obj = self._apply_map_create_model(obj_in)
            async with await self._get_async_db() as db:
                db.add(db_obj)
                await db.commit()
                await db.refresh(db_obj)
                return self._apply_map_get_dto(db_obj)
        except Exception as e:
            raise ServiceWarning(f"创建实体失败: {str(e)}") from e

    async def update(self, obj_in: UpdateSchemaType) -> GetSchemaType:
        """Update the entity identified by ``obj_in.id`` and return it as a DTO.

        :param obj_in: update data; its ``id`` selects the entity
        :return: DTO of the updated entity, refreshed after commit
        :raises ServiceWarning: on any failure, including the entity not existing
        """
        try:
            async with await self._get_async_db() as db:
                # Load the row in this session so the mutation is tracked and
                # flushed by db.commit() (and db.refresh finds a persistent instance).
                db_obj = await self._get_in_session(db, obj_in.id)
                if db_obj is None:
                    raise ValueError(
                        f"{self._model_class.__name__} with id {obj_in.id} not found"
                    )
                db_obj = self._apply_map_update_model(db_obj, obj_in)
                await db.commit()
                await db.refresh(db_obj)
                return self._apply_map_get_dto(db_obj)
        except Exception as e:
            raise ServiceWarning(f"更新实体失败: {str(e)}") from e

    async def delete(self, id: int) -> None:
        """Delete the (non-soft-deleted) entity with primary key ``id`` from the DB.

        :param id: primary key
        :raises ServiceWarning: on any failure, including the entity not existing
        """
        try:
            async with await self._get_async_db() as db:
                # Load in the same session that performs the delete.
                db_obj = await self._get_in_session(db, id)
                if db_obj is None:
                    raise ValueError(
                        f"{self._model_class.__name__} with id {id} not found"
                    )
                await db.delete(db_obj)
                # No refresh here: a deleted instance is no longer persistent,
                # so refreshing it would raise after a successful commit.
                await db.commit()
        except Exception as e:
            raise ServiceWarning(f"删除实体失败: {str(e)}") from e

    def _apply_map_get_dto_list(
        self, models: Sequence[ModelType]
    ) -> List[GetSchemaType]:
        """Map each model to its read DTO via ``_apply_map_get_dto``."""
        return [self._apply_map_get_dto(model) for model in models]

    def _apply_map_get_dto(self, model: ModelType) -> GetSchemaType:
        """Map one model to its read DTO using ``get_schema_class.from_model``."""
        return self._get_schema_class.from_model(model)

    def _apply_map_create_model(self, obj_in: CreateSchemaType) -> ModelType:
        """Build a new model instance from the creation DTO's ``to_dict()``."""
        dict_obj = obj_in.to_dict()
        return self._model_class.from_dict(dict_obj)

    @staticmethod
    def _apply_map_update_model(
        db_obj: ModelType, obj_in: UpdateSchemaType
    ) -> ModelType:
        """Apply the update DTO's ``to_dict()`` to ``db_obj`` via its ``update`` method."""
        dict_obj = obj_in.to_dict()
        db_obj.update(dict_obj)
        return db_obj