| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410 |
from typing import Generic, TypeVar, Optional, List, Any, Dict

# ``sqlalchemy.Sequence`` (imported below) is the database sequence construct,
# not a container type, so the generic sequence annotation gets its own alias.
from typing import Sequence as TypingSequence

from sqlalchemy import select, func, desc, asc, Sequence
from sqlalchemy.ext.asyncio import AsyncSession

from core.current_user import CurrentUserProvider
from core.enums import DeleteTypeEnum
from core.exceptions import ServiceWarning
from domain.dtos import DtoBase, PageDto, PageResultDto
from utils import DBUtil, logger
class ServiceBase:
    """Common base for service classes.

    Stores the target database name, the current user captured at
    construction time and the shared logger, and offers helpers that turn a
    search keyword, filter specifications and an order string into clauses on
    a ``select()``-style query object.
    """

    __abstract__ = True

    # Suffix operators accepted in the ``filters`` dict, e.g. {"age_gte": 18}.
    # Each entry is (suffix, required value type or None, clause builder).
    _SUFFIX_FILTERS = (
        ("_like", str, lambda column, value: column.like(f"%{value}%")),
        ("_in", list, lambda column, value: column.in_(value)),
        ("_gte", None, lambda column, value: column >= value),
        ("_lte", None, lambda column, value: column <= value),
        ("_gt", None, lambda column, value: column > value),
        ("_lt", None, lambda column, value: column < value),
    )

    # Scalar operators accepted in ``filter_conditions`` (word and symbol forms).
    _CONDITION_OPERATORS = {
        "eq": lambda column, value: column == value,
        "==": lambda column, value: column == value,
        "ne": lambda column, value: column != value,
        "!=": lambda column, value: column != value,
        "gt": lambda column, value: column > value,
        ">": lambda column, value: column > value,
        "lt": lambda column, value: column < value,
        "<": lambda column, value: column < value,
        "gte": lambda column, value: column >= value,
        ">=": lambda column, value: column >= value,
        "lte": lambda column, value: column <= value,
        "<=": lambda column, value: column <= value,
        "like": lambda column, value: column.like(f"%{value}%"),
    }

    def __init__(self, db_name: Optional[str] = None):
        """Remember the database name and capture the current user and logger.

        :param db_name: name passed to ``DBUtil`` when a session is requested;
            ``None`` is forwarded unchanged.
        """
        self._db_name = db_name
        self._current_user = CurrentUserProvider.get_current_user()
        self._current_user_name = ""
        self._logger = logger

    async def _get_async_db(self) -> "AsyncSession":
        """Return an async database session for ``self._db_name``.

        Any failure is logged and re-raised unchanged.

        NOTE(review): the session is returned from inside the ``async with``
        block, i.e. after that context manager has already exited, and callers
        enter it again with ``async with await self._get_async_db()``. Whether
        the session is still usable depends on ``DBUtil`` -- confirm.
        """
        try:
            async with DBUtil()(db_name=self._db_name) as conn:
                return conn
        except Exception as e:
            self._logger.error(f"获取数据库连接失败: {str(e)}")
            raise

    @staticmethod
    def _apply_search_base(model_class, query, search: str):
        """Restrict ``query`` to rows whose ``name`` contains ``search``.

        :param model_class: model class whose ``name`` attribute is searched
        :param query: query object to extend
        :param search: keyword; an empty value leaves the query untouched
        :return: the (possibly extended) query object
        """
        if search and hasattr(model_class, "name"):
            query = query.where(model_class.name.contains(search))
        return query

    def _apply_filter_base(
        self,
        model_class,
        query,
        filters: Optional[Dict[str, Any]] = None,
        filter_conditions: Optional[List[Dict[str, Any]]] = None,
    ):
        """Apply filter clauses and the default soft-delete restriction.

        :param model_class: model class the filters refer to
        :param query: query object to extend
        :param filters: ``{field: value}`` filters, see ``_build_filter``
        :param filter_conditions: operator based filters, see ``_build_filter``
        :return: the extended query object
        """
        caller_filters_is_del = False
        for clause in self._build_filter(model_class, filters, filter_conditions):
            query = query.where(clause)
            # Read ``left.key`` defensively: not every clause object is
            # guaranteed to expose a ``left`` operand.
            if getattr(getattr(clause, "left", None), "key", None) == "is_del":
                caller_filters_is_del = True

        # Only add the default "not deleted" restriction when the caller did
        # not filter on ``is_del`` explicitly.
        # NOTE(review): this compares against ``DeleteTypeEnum.NORMAL.value``
        # while ``CurdServiceBase.get`` uses ``.key`` -- confirm which one
        # matches the stored column value.
        if not caller_filters_is_del and hasattr(model_class, "is_del"):
            query = query.where(model_class.is_del == DeleteTypeEnum.NORMAL.value)
        return query

    @staticmethod
    def _apply_order_base(model_class, query, order: str):
        """Append ``ORDER BY`` clauses described by ``order`` to ``query``.

        ``order`` is a comma separated list of ``"<field> <direction>"`` items,
        for example ``"sort asc, id desc"``. Any direction other than ``desc``
        (case-insensitive) sorts ascending, and a bare field name is treated as
        ascending. Items naming an unknown field or having more than two
        tokens are ignored.

        :param model_class: model class providing the sortable attributes
        :param query: query object to extend
        :param order: order specification; when empty, ``__default_order__`` of
            an ``OrderModelBase`` subclass is used, otherwise ``"id desc"``
        :return: the extended query object
        """
        if not order:
            # Function-level import kept from the original code
            # (presumably to avoid a circular import -- TODO confirm).
            from domain.models.base_model import OrderModelBase

            if issubclass(model_class, OrderModelBase):
                order = getattr(model_class, "__default_order__", "id desc")
            else:
                order = "id desc"

        for order_item in order.split(","):
            # ``split()`` without arguments tolerates repeated whitespace.
            parts = order_item.split()
            if len(parts) == 1:
                order_by, order_direction = parts[0], "asc"
            elif len(parts) == 2:
                order_by, order_direction = parts
            else:
                continue
            if not hasattr(model_class, order_by):
                continue
            order_column = getattr(model_class, order_by)
            if order_direction.lower() == "desc":
                query = query.order_by(desc(order_column))
            else:
                query = query.order_by(asc(order_column))
        return query

    @staticmethod
    def _build_dict_clause(model_class, key: str, value: Any):
        """Translate one ``filters`` entry into a clause, or ``None`` to skip it."""
        # Suffix operators first: the key itself ("age_gte") is normally not a
        # model attribute, only its base name ("age") is.
        for suffix, required_type, build in ServiceBase._SUFFIX_FILTERS:
            if not key.endswith(suffix):
                continue
            base_key = key[: -len(suffix)]
            if required_type is not None and not isinstance(value, required_type):
                continue
            if base_key and hasattr(model_class, base_key):
                return build(getattr(model_class, base_key), value)

        # ``start_time`` / ``end_time`` bound the ``created_at`` column.
        if key == "start_time" and hasattr(model_class, "created_at"):
            return model_class.created_at >= value
        if key == "end_time" and hasattr(model_class, "created_at"):
            return model_class.created_at <= value

        # Plain equality on an attribute that carries exactly this name.
        if hasattr(model_class, key):
            return getattr(model_class, key) == value
        return None

    @staticmethod
    def _build_condition_clause(column, operator: Any, value: Any):
        """Translate one operator based condition into a clause, or ``None``."""
        if not isinstance(operator, str):
            return None
        build = ServiceBase._CONDITION_OPERATORS.get(operator)
        if build is not None:
            return build(column, value)
        # The remaining operators need a list value.
        if isinstance(value, list):
            if operator == "in":
                return column.in_(value)
            if operator == "not_in":
                return ~column.in_(value)
            if operator == "between" and len(value) == 2:
                return column.between(value[0], value[1])
        return None

    @staticmethod
    def _build_filter(
        model_class,
        filter_dict: Optional[Dict[str, Any]] = None,
        filter_conditions: Optional[List[Dict[str, Any]]] = None,
    ) -> List:
        """Build a list of filter clauses for ``model_class``.

        Entries whose value is ``None``, whose field does not exist on the
        model, or whose operator is unknown are skipped silently.

        Args:
            model_class: model class providing the column attributes.
            filter_dict: ``{field: value}`` filters. A key may carry one of the
                suffixes ``_like`` (string value), ``_in`` (list value),
                ``_gt``, ``_lt``, ``_gte``, ``_lte`` to select an operator on
                the base field; ``start_time`` / ``end_time`` bound
                ``created_at``; any other key means equality.
            filter_conditions: list of
                ``{"field": name, "operator": op, "value": value}`` dicts; the
                operator defaults to ``"eq"``. Supported operators: ``eq``,
                ``ne``, ``gt``, ``lt``, ``gte``, ``lte`` (or their symbols),
                ``like``, ``in``, ``not_in`` and ``between`` (two-item list).

        Returns:
            The clauses in input order, dict filters first.
        """
        filters_list = []

        if filter_dict:
            # ``.items()`` is required here: iterating the dict directly
            # yields only the keys.
            for key, value in filter_dict.items():
                if value is None:
                    continue
                clause = ServiceBase._build_dict_clause(model_class, key, value)
                # Compare with ``is not None``: the truth value of a SQL
                # expression is not a reliable "present" test.
                if clause is not None:
                    filters_list.append(clause)

        if filter_conditions:
            for condition in filter_conditions:
                field = condition.get("field")
                operator = condition.get("operator", "eq")
                value = condition.get("value")
                if value is None or not field:
                    continue
                if not hasattr(model_class, field):
                    continue
                clause = ServiceBase._build_condition_clause(
                    getattr(model_class, field), operator, value
                )
                if clause is not None:
                    filters_list.append(clause)

        return filters_list
# Generic type variables used to parameterise CurdServiceBase.

# Model class handled by a service (unbounded).
ModelType = TypeVar("ModelType")

# DTO types for reading, creating and updating; each must derive from DtoBase.
GetSchemaType = TypeVar("GetSchemaType", bound=DtoBase)
CreateSchemaType = TypeVar("CreateSchemaType", bound=DtoBase)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=DtoBase)
class CurdServiceBase(
    Generic[ModelType, GetSchemaType, CreateSchemaType, UpdateSchemaType], ServiceBase
):
    """Generic create/read/update/delete service for one model class.

    Subclasses supply the model class and the DTO classes used for reading,
    creating and updating. The ``_apply_*`` hooks can be overridden to
    customise querying and model/DTO mapping. Every public operation reports
    failures as ``ServiceWarning``.
    """

    __abstract__ = True

    def __init__(
        self,
        model_class,
        get_schema_class,
        create_schema_class,
        update_schema_class,
        db_name: Optional[str] = None,
    ):
        """Store the model and DTO classes this service works with.

        :param model_class: model class queried and persisted by the service
        :param get_schema_class: DTO class returned to callers
        :param create_schema_class: DTO class accepted by ``create``
        :param update_schema_class: DTO class accepted by ``update``
        :param db_name: database name forwarded to ``ServiceBase``
        """
        super().__init__(db_name)
        self._model_class = model_class
        self._get_schema_class = get_schema_class
        self._create_schema_class = create_schema_class
        self._update_schema_class = update_schema_class

    async def get_page_list(
        self,
        page_dto: PageDto,
    ) -> PageResultDto[GetSchemaType]:
        """Return one page of entities together with the total match count.

        :param page_dto: paging request; ``search``, ``filters``,
            ``filter_conditions``, ``order``, ``offset``, ``limit`` and
            ``page_size`` are read from it
        :return: page result holding the total count and the mapped rows
        :raises ServiceWarning: if building or executing the query fails
        """
        try:
            async with await self._get_async_db() as db:
                query = self._apply_query()
                query = self._apply_search(query, page_dto.search)
                query = self._apply_filter(
                    query, page_dto.filters, page_dto.filter_conditions
                )
                # Count over an explicit subquery: a SELECT has to be turned
                # into a FROM clause explicitly before it can be selected from.
                count_query = select(func.count()).select_from(query.subquery())
                total = await db.scalar(count_query)

                # Ordering and paging only apply to the row query, not the count.
                query = self._apply_order(query, page_dto.order)
                query = query.offset(page_dto.offset).limit(page_dto.limit)
                result = await db.execute(query)
                models = result.scalars().all()
                return PageResultDto(
                    total=total,
                    # Go through the mapping hook so subclass overrides of
                    # ``_apply_map_get_dto`` also apply to list results.
                    rows=self._apply_map_get_dto_list(models),
                    page_size=page_dto.page_size,
                )
        except Exception as e:
            raise ServiceWarning(f"获取实体列表失败: {str(e)}") from e

    def _apply_query(self):
        """Return the base ``select`` statement for the model class."""
        return select(self._model_class)

    def _apply_search(self, query, search: str):
        """Apply the keyword search to ``query``.

        :param query: query object to extend
        :param search: keyword
        :return: the extended query object
        """
        return self._apply_search_base(self._model_class, query, search)

    def _apply_filter(
        self,
        query,
        filters: Optional[Dict[str, Any]] = None,
        filter_conditions: Optional[List[Dict[str, Any]]] = None,
    ):
        """Apply dict and operator based filters to ``query``.

        :param query: query object to extend
        :param filters: ``{field: value}`` filters
        :param filter_conditions: operator based filter dicts
        :return: the extended query object
        """
        return self._apply_filter_base(
            self._model_class, query, filters, filter_conditions
        )

    def _apply_order(self, query, order: str):
        """Apply the order specification to ``query``.

        :param query: query object to extend
        :param order: order specification string
        :return: the extended query object
        """
        return self._apply_order_base(self._model_class, query, order)

    def _build_get_query(self, id: int, include_deleted: bool = False):
        """Build the single-entity lookup statement used by get/update/delete."""
        query = select(self._model_class).where(self._model_class.id == id)  # type: ignore
        # Models without an ``is_del`` column have no soft-delete state to filter.
        if not include_deleted and hasattr(self._model_class, "is_del"):
            # NOTE(review): ``.key`` is used here while
            # ``ServiceBase._apply_filter_base`` compares against ``.value`` --
            # confirm which one matches the stored column value.
            query = query.where(
                self._model_class.is_del == DeleteTypeEnum.NORMAL.key
            )
        return query

    async def _get_in_session(self, db, id: int, include_deleted: bool = False):
        """Load one entity through the given session, or return ``None``."""
        result = await db.execute(self._build_get_query(id, include_deleted))
        return result.scalar_one_or_none()

    async def get(self, id: int, include_deleted: bool = False) -> ModelType | None:
        """Load one entity by primary key.

        :param id: primary key value
        :param include_deleted: when true, the soft-delete filter is not applied
        :return: the model instance, or ``None`` when no row matches
        :raises ServiceWarning: if the lookup fails
        """
        try:
            async with await self._get_async_db() as db:
                return await self._get_in_session(db, id, include_deleted)
        except Exception as e:
            raise ServiceWarning(f"获取实体失败: {str(e)}") from e

    async def get_dto(
        self, id: int, include_deleted: bool = False
    ) -> GetSchemaType | None:
        """Load one entity by primary key and map it to its DTO.

        :param id: primary key value
        :param include_deleted: when true, the soft-delete filter is not applied
        :return: the DTO, or ``None`` when no row matches
        :raises ServiceWarning: if the lookup or the mapping fails
        """
        model = await self.get(id, include_deleted)
        try:
            if model:
                return self._apply_map_get_dto(model)
            return None
        except Exception as e:
            raise ServiceWarning(f"获取实体DTO失败: {str(e)}") from e

    async def create(self, obj_in: CreateSchemaType) -> GetSchemaType:
        """Persist a new entity.

        :param obj_in: creation DTO
        :return: DTO of the stored entity, refreshed after the commit
        :raises ServiceWarning: if mapping or persisting fails
        """
        try:
            db_obj = self._apply_map_create_model(obj_in)
            async with await self._get_async_db() as db:
                db.add(db_obj)
                await db.commit()
                await db.refresh(db_obj)
                return self._apply_map_get_dto(db_obj)
        except Exception as e:
            raise ServiceWarning(f"创建实体失败: {str(e)}") from e

    async def update(self, obj_in: UpdateSchemaType) -> GetSchemaType:
        """Update an existing entity identified by ``obj_in.id``.

        :param obj_in: update DTO; must expose ``id``
        :return: DTO of the updated entity, refreshed after the commit
        :raises ServiceWarning: if the entity does not exist or the update fails
        """
        try:
            async with await self._get_async_db() as db:
                # Load through the same session that commits, so the change is
                # tracked by the session whose transaction is committed.
                db_obj = await self._get_in_session(db, obj_in.id)
                if db_obj is None:
                    raise ValueError(
                        f"{self._model_class.__name__} with id {obj_in.id} not found"
                    )
                db_obj = self._apply_map_update_model(db_obj, obj_in)
                await db.commit()
                await db.refresh(db_obj)
                return self._apply_map_get_dto(db_obj)
        except Exception as e:
            raise ServiceWarning(f"更新实体失败: {str(e)}") from e

    async def delete(self, id: int) -> None:
        """Delete the entity with the given primary key.

        :param id: primary key value
        :raises ServiceWarning: if the entity does not exist or deletion fails
        """
        try:
            async with await self._get_async_db() as db:
                # Load through the same session that performs the delete.
                db_obj = await self._get_in_session(db, id)
                if db_obj is None:
                    raise ValueError(
                        f"{self._model_class.__name__} with id {id} not found"
                    )
                await db.delete(db_obj)
                # No refresh after the commit: the row no longer exists.
                await db.commit()
        except Exception as e:
            raise ServiceWarning(f"删除实体失败: {str(e)}") from e

    def _apply_map_get_dto_list(
        self, models: TypingSequence[ModelType]
    ) -> List[GetSchemaType]:
        """Map a sequence of model instances to DTOs.

        :param models: model instances
        :return: list of DTOs in the same order
        """
        return [self._apply_map_get_dto(model) for model in models]

    def _apply_map_get_dto(self, model: ModelType) -> GetSchemaType:
        """Map one model instance to its DTO via ``from_model`` of the DTO class.

        :param model: model instance
        :return: the DTO
        """
        return self._get_schema_class.from_model(model)

    def _apply_map_create_model(self, obj_in: CreateSchemaType) -> ModelType:
        """Build a new model instance from a creation DTO.

        Uses ``obj_in.to_dict()`` and the model class' ``from_dict``.
        """
        dict_obj = obj_in.to_dict()
        return self._model_class.from_dict(dict_obj)

    @staticmethod
    def _apply_map_update_model(
        db_obj: ModelType, obj_in: UpdateSchemaType
    ) -> ModelType:
        """Copy the update DTO's values onto an existing model instance.

        Calls ``db_obj.update`` with ``obj_in.to_dict()`` and returns the same
        instance.
        """
        dict_obj = obj_in.to_dict()
        db_obj.update(dict_obj)
        return db_obj
|