base_services.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. from typing import Generic, TypeVar, Optional, List, Any, Dict
  2. from sqlalchemy import select, func, desc, asc, Sequence
  3. from sqlalchemy.ext.asyncio import AsyncSession
  4. from core.current_user import CurrentUserProvider
  5. from core.enums import DeleteTypeEnum
  6. from core.exceptions import ServiceWarning
  7. from domain.dtos import DtoBase, PageDto, PageResultDto
  8. from utils import DBUtil, logger
  9. class ServiceBase:
  10. __abstract__ = True
  11. def __init__(self, db_name: Optional[str] = None):
  12. self._db_name = db_name
  13. self._current_user = CurrentUserProvider.get_current_user()
  14. self._current_user_name = ""
  15. self._logger = logger
  16. async def _get_async_db(self) -> AsyncSession:
  17. """根据db_name获取异步数据库连接"""
  18. try:
  19. # 使用DBUtil获取异步连接
  20. async with DBUtil()(db_name=self._db_name) as conn:
  21. return conn
  22. except Exception as e:
  23. self._logger.error(f"获取数据库连接失败: {str(e)}")
  24. raise
  25. @staticmethod
  26. def _apply_search_base(model_class, query, search: str):
  27. """
  28. 根据关键字构建查询对象
  29. :param query: 查询对象
  30. :param search: 关键字
  31. :return: 查询对象
  32. """
  33. if search and hasattr(model_class, "name"):
  34. query = query.where(model_class.name.contains(search))
  35. return query
  36. def _apply_filter_base(
  37. self,
  38. model_class,
  39. query,
  40. filters: Optional[Dict[str, Any]] = None,
  41. filter_conditions: Optional[List[Dict[str, Any]]] = None,
  42. ):
  43. """
  44. 根据过滤条件构建查询对象
  45. :param query: 查询对象
  46. :param filters: 过滤条件
  47. :param filter_conditions: 过滤条件
  48. :return: 查询对象
  49. """
  50. _del = True
  51. if filters or filter_conditions:
  52. filters_list = self._build_filter(model_class, filters, filter_conditions)
  53. for f in filters_list:
  54. query = query.where(f)
  55. if f.left.key == "is_del":
  56. _del = False
  57. # 处理软删除条件
  58. if _del and hasattr(model_class, "is_del"):
  59. query = query.where(model_class.is_del == DeleteTypeEnum.NORMAL.value)
  60. return query
  61. @staticmethod
  62. def _apply_order_base(model_class, query, order: str):
  63. """
  64. 根据排序字段构建查询对象
  65. :param query: 查询对象
  66. :param order: 排序字段
  67. :return: 查询对象
  68. """
  69. if not order:
  70. from domain.models.base_model import OrderModelBase
  71. if issubclass(model_class, OrderModelBase):
  72. # 使用OrderModelBase的默认排序字段
  73. order = getattr(model_class, "__default_order__", "id desc")
  74. else:
  75. order = "id desc"
  76. order_list = order.split(",")
  77. for order_item in order_list:
  78. try:
  79. order_by, order_direction = order_item.strip().split(" ")
  80. if order_by and hasattr(model_class, order_by):
  81. order_column = getattr(model_class, order_by)
  82. if order_direction.lower() == "desc":
  83. query = query.order_by(desc(order_column))
  84. else:
  85. query = query.order_by(asc(order_column))
  86. except ValueError:
  87. continue
  88. return query
  89. @staticmethod
  90. def _build_filter(
  91. model_class,
  92. filter_dict: Dict[str, Any] = None,
  93. filter_conditions: List[Dict[str, Any]] = None,
  94. ) -> List:
  95. """构建过滤条件
  96. Args:
  97. model_class: 模型类
  98. filter_dict: 过滤条件字典,格式为{字段名: 值}
  99. filter_conditions: 过滤条件列表,格式为[{"field": "字段名", "operator": "操作符", "value": 值}]
  100. Returns:
  101. 过滤条件列表
  102. """
  103. filters_list = []
  104. if filter_dict:
  105. for key, value in filter_dict:
  106. if value is None:
  107. continue
  108. if not hasattr(model_class, key):
  109. continue
  110. column = getattr(model_class, key)
  111. # 处理特殊的过滤条件
  112. if key.endswith("_like") and isinstance(value, str):
  113. base_key = key[:-5] # 移除 _like 后缀
  114. if hasattr(model_class, base_key):
  115. column = getattr(model_class, base_key)
  116. filters_list.append(column.like(f"%{value}%"))
  117. elif key.endswith("_in") and isinstance(value, list):
  118. base_key = key[:-3] # 移除 _in 后缀
  119. if hasattr(model_class, base_key):
  120. column = getattr(model_class, base_key)
  121. filters_list.append(column.in_(value))
  122. elif key.endswith("_gt"):
  123. base_key = key[:-3] # 移除 _gt 后缀
  124. if hasattr(model_class, base_key):
  125. column = getattr(model_class, base_key)
  126. filters_list.append(column > value)
  127. elif key.endswith("_lt"):
  128. base_key = key[:-3] # 移除 _lt 后缀
  129. if hasattr(model_class, base_key):
  130. column = getattr(model_class, base_key)
  131. filters_list.append(column < value)
  132. elif key.endswith("_gte"):
  133. base_key = key[:-4] # 移除 _gte 后缀
  134. if hasattr(model_class, base_key):
  135. column = getattr(model_class, base_key)
  136. filters_list.append(column >= value)
  137. elif key.endswith("_lte"):
  138. base_key = key[:-4] # 移除 _lte 后缀
  139. if hasattr(model_class, base_key):
  140. column = getattr(model_class, base_key)
  141. filters_list.append(column <= value)
  142. elif key == "start_time" and hasattr(model_class, "created_at"):
  143. filters_list.append(model_class.created_at >= value)
  144. elif key == "end_time" and hasattr(model_class, "created_at"):
  145. filters_list.append(model_class.created_at <= value)
  146. else:
  147. filters_list.append(column == value)
  148. if filter_conditions:
  149. for condition in filter_conditions:
  150. field = condition.get("field")
  151. operator = condition.get("operator", "eq")
  152. value = condition.get("value")
  153. if value is None or not field:
  154. continue
  155. if not hasattr(model_class, field):
  156. continue
  157. column = getattr(model_class, field)
  158. # 根据操作符构建过滤条件
  159. if operator == "eq" or operator == "==":
  160. filters_list.append(column == value)
  161. elif operator == "ne" or operator == "!=":
  162. filters_list.append(column != value)
  163. elif operator == "gt" or operator == ">":
  164. filters_list.append(column > value)
  165. elif operator == "lt" or operator == "<":
  166. filters_list.append(column < value)
  167. elif operator == "gte" or operator == ">=":
  168. filters_list.append(column >= value)
  169. elif operator == "lte" or operator == "<=":
  170. filters_list.append(column <= value)
  171. elif operator == "like":
  172. filters_list.append(column.like(f"%{value}%"))
  173. elif operator == "in" and isinstance(value, list):
  174. filters_list.append(column.in_(value))
  175. elif operator == "not_in" and isinstance(value, list):
  176. filters_list.append(~column.in_(value))
  177. elif (
  178. operator == "between"
  179. and isinstance(value, list)
  180. and len(value) == 2
  181. ):
  182. filters_list.append(column.between(value[0], value[1]))
  183. return filters_list
# Generic type variables used to parameterize CurdServiceBase:
#   ModelType        - the model class handled by the service (unbounded)
#   GetSchemaType    - DTO returned by read operations (bound to DtoBase)
#   CreateSchemaType - DTO accepted by create (bound to DtoBase)
#   UpdateSchemaType - DTO accepted by update (bound to DtoBase)
ModelType = TypeVar("ModelType")
GetSchemaType = TypeVar("GetSchemaType", bound=DtoBase)
CreateSchemaType = TypeVar("CreateSchemaType", bound=DtoBase)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=DtoBase)
  189. class CurdServiceBase(
  190. Generic[ModelType, GetSchemaType, CreateSchemaType, UpdateSchemaType], ServiceBase
  191. ):
  192. __abstract__ = True
  193. def __init__(
  194. self,
  195. model_class,
  196. get_schema_class,
  197. create_schema_class,
  198. update_schema_class,
  199. db_name: Optional[str] = None,
  200. ):
  201. super().__init__(db_name)
  202. self._model_class = model_class
  203. self._get_schema_class = get_schema_class
  204. self._create_schema_class = create_schema_class
  205. self._update_schema_class = update_schema_class
  206. async def get_page_list(
  207. self,
  208. page_dto: PageDto,
  209. ) -> PageResultDto[GetSchemaType]:
  210. """
  211. 获取实体列表
  212. :param page_dto: 查询参数
  213. :return: 实体列表
  214. """
  215. try:
  216. async with await self._get_async_db() as db:
  217. query = self._apply_query()
  218. query = self._apply_search(query, page_dto.search)
  219. query = self._apply_filter(
  220. query, page_dto.filters, page_dto.filter_conditions
  221. )
  222. total = await db.scalar(select(func.count()).select_from(query))
  223. query = self._apply_order(query, page_dto.order)
  224. query = query.offset(page_dto.offset).limit(page_dto.limit)
  225. result = await db.execute(query)
  226. models = result.scalars().all()
  227. return PageResultDto(
  228. total=total,
  229. rows=[self._get_schema_class.from_model(model) for model in models],
  230. page_size=page_dto.page_size,
  231. )
  232. except Exception as e:
  233. raise ServiceWarning(f"获取实体列表失败: {str(e)}")
  234. def _apply_query(self):
  235. return select(self._model_class)
  236. def _apply_search(self, query, search: str):
  237. """
  238. 根据关键字构建查询对象
  239. :param query: 查询对象
  240. :param search: 关键字
  241. :return: 查询对象
  242. """
  243. return self._apply_search_base(self._model_class, query, search)
  244. def _apply_filter(
  245. self,
  246. query,
  247. filters: Optional[Dict[str, Any]] = None,
  248. filter_conditions: Optional[List[Dict[str, Any]]] = None,
  249. ):
  250. """
  251. 根据过滤条件构建查询对象
  252. :param query: 查询对象
  253. :param filters: 过滤条件
  254. :return: 查询对象
  255. """
  256. return self._apply_filter_base(
  257. self._model_class, query, filters, filter_conditions
  258. )
  259. def _apply_order(self, query, order: str):
  260. """
  261. 根据排序字段构建查询对象
  262. :param query: 查询对象
  263. :param order: 排序字段
  264. :return: 查询对象
  265. """
  266. return self._apply_order_base(self._model_class, query, order)
  267. async def get(self, id: int, include_deleted: bool = False) -> ModelType | None:
  268. try:
  269. async with await self._get_async_db() as db:
  270. query = select(self._model_class).where(self._model_class.id == id) # type: ignore
  271. if not include_deleted:
  272. query = query.where(
  273. self._model_class.is_del == DeleteTypeEnum.NORMAL.key
  274. )
  275. result = await db.execute(query)
  276. model = result.scalar_one_or_none()
  277. return model
  278. except Exception as e:
  279. raise ServiceWarning(f"获取实体失败: {str(e)}")
  280. async def get_dto(
  281. self, id: int, include_deleted: bool = False
  282. ) -> GetSchemaType | None:
  283. """
  284. 获取实体DTO
  285. :param id:
  286. :param include_deleted: 是否包含已删除的实体
  287. :return: 实体DTO
  288. """
  289. model = await self.get(id, include_deleted)
  290. try:
  291. if model:
  292. return self._apply_map_get_dto(model)
  293. else:
  294. return None
  295. except Exception as e:
  296. raise ServiceWarning(f"获取实体DTO失败: {str(e)}")
  297. async def create(self, obj_in: CreateSchemaType) -> GetSchemaType:
  298. """
  299. 创建实体
  300. :param obj_in: 实体数据
  301. :return: 创建后的实体
  302. """
  303. try:
  304. db_obj = self._apply_map_create_model(obj_in)
  305. async with await self._get_async_db() as db:
  306. db.add(db_obj)
  307. await db.commit()
  308. await db.refresh(db_obj)
  309. return self._apply_map_get_dto(db_obj)
  310. except Exception as e:
  311. raise ServiceWarning(f"创建实体失败: {str(e)}")
  312. async def update(self, obj_in: UpdateSchemaType) -> GetSchemaType:
  313. """
  314. 更新实体
  315. :param obj_in: 实体数据
  316. :return: 创建后的实体
  317. """
  318. try:
  319. async with await self._get_async_db() as db:
  320. db_obj = await self.get(obj_in.id)
  321. db_obj = self._apply_map_update_model(db_obj, obj_in)
  322. await db.commit()
  323. await db.refresh(db_obj)
  324. return self._apply_map_get_dto(db_obj)
  325. except Exception as e:
  326. raise ServiceWarning(f"更新实体失败: {str(e)}")
  327. async def delete(self, id: int) -> None:
  328. """
  329. 删除实体
  330. :param id:
  331. :return:
  332. """
  333. try:
  334. async with await self._get_async_db() as db:
  335. db_obj = await self.get(id)
  336. if not db_obj:
  337. raise ValueError(
  338. f"{self._model_class.__name__} with id {id} not found"
  339. )
  340. await db.delete(db_obj)
  341. await db.commit()
  342. await db.refresh(db_obj)
  343. except Exception as e:
  344. raise ServiceWarning(f"删除实体失败: {str(e)}")
  345. def _apply_map_get_dto_list(
  346. self, models: Sequence[ModelType]
  347. ) -> List[GetSchemaType]:
  348. """
  349. 将模型对象转换为DTO对象
  350. :param models: 模型对象列表
  351. :return: DTO对象
  352. """
  353. return [self._apply_map_get_dto(model) for model in models]
  354. def _apply_map_get_dto(self, model: ModelType) -> GetSchemaType:
  355. """
  356. 将模型对象转换为DTO对象
  357. :param model: 模型对象
  358. :return: DTO对象
  359. """
  360. return self._get_schema_class.from_model(model)
  361. def _apply_map_create_model(self, obj_in: CreateSchemaType) -> ModelType:
  362. dict_obj = obj_in.to_dict()
  363. return self._model_class.from_dict(dict_obj)
  364. @staticmethod
  365. def _apply_map_update_model(
  366. db_obj: ModelType, obj_in: UpdateSchemaType
  367. ) -> ModelType:
  368. dict_obj = obj_in.to_dict()
  369. db_obj.update(dict_obj)
  370. return db_obj