base_dto.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. from datetime import datetime
  2. from enum import Enum
  3. from typing import (
  4. Optional,
  5. TypeVar,
  6. Type,
  7. Dict,
  8. Any,
  9. List,
  10. Union,
  11. ClassVar,
  12. get_type_hints,
  13. get_origin,
  14. get_args,
  15. Generic,
  16. )
  17. from pydantic import BaseModel
  18. ModelType = TypeVar("ModelType")
  19. class DtoBase(BaseModel):
  20. @classmethod
  21. def from_model(cls, model):
  22. """从数据库模型转换为DTO实例(兼容所有ORM模型)"""
  23. # 优先使用模型自带的to_dict方法(SQLAlchemy模型)
  24. if hasattr(model, "to_dict"):
  25. return model.to_dict()
  26. # 其次尝试Pydantic模型的model_dump
  27. elif hasattr(model, "model_dump"):
  28. return model.model_dump()
  29. # 最终回退使用实例字典
  30. return model.__dict__
  31. """基础DTO类,提供通用的转换方法和验证方法"""
  32. __abstract__ = True
  33. id: Optional[int] = None
  34. # 需要格式化的日期时间字段列表
  35. datetime_fields: ClassVar[List[str]] = []
  36. # 日期时间格式
  37. datetime_format: ClassVar[str] = "%Y-%m-%d %H:%M:%S"
  38. @classmethod
  39. def from_dict(cls: Type[ModelType], data: Dict[str, Any]) -> ModelType:
  40. """从字典创建DTO对象,统一的字典转DTO方法"""
  41. # 预处理枚举类型
  42. processed_data = cls.preprocess_enum_fields(data)
  43. return cls(**processed_data)
  44. def to_dict(self) -> Dict[str, Any]:
  45. """转换为字典格式,自动处理日期时间字段"""
  46. data = self.model_dump()
  47. # 处理所有日期时间字段
  48. for field in self.datetime_fields:
  49. if field in data and data[field] is not None:
  50. if isinstance(data[field], datetime):
  51. data[field] = data[field].strftime(self.datetime_format)
  52. # 处理枚举类型字段
  53. for field_name, field_value in data.items():
  54. if isinstance(field_value, Enum):
  55. data[field_name] = field_value.value
  56. return data
  57. @classmethod
  58. def validate_field_length(
  59. cls, field_name: str, value: str, max_length: int
  60. ) -> None:
  61. """验证字段长度"""
  62. if value and len(value) > max_length:
  63. raise ValueError(f"{field_name}长度不能超过{max_length}个字符")
  64. @classmethod
  65. def validate_required_field(cls, field_name: str, value: Any) -> None:
  66. """验证必填字段"""
  67. if value is None:
  68. raise ValueError(f"{field_name}不能为空")
  69. @classmethod
  70. def convert_to_enum(cls, enum_class: Type[Enum], value: Any) -> Enum | None:
  71. """将值转换为枚举类型"""
  72. if value is None:
  73. return None
  74. if isinstance(value, enum_class):
  75. return value
  76. # 尝试使用to_enum方法转换
  77. if hasattr(enum_class, "to_enum"):
  78. return enum_class.to_enum(value)
  79. # 标准转换方式
  80. try:
  81. return enum_class(value)
  82. except ValueError:
  83. # 获取枚举的第一个值作为默认值
  84. return next(iter(enum_class))
  85. @classmethod
  86. def preprocess_enum_fields(cls, data: Dict[str, Any]) -> Dict[str, Any]:
  87. """预处理字典中的枚举类型字段"""
  88. if not data:
  89. return data
  90. result = data.copy()
  91. hints = get_type_hints(cls)
  92. for field_name, field_type in hints.items():
  93. # 跳过不在数据中的字段
  94. if field_name not in result:
  95. continue
  96. # 处理Optional类型
  97. origin = get_origin(field_type)
  98. if origin is Union:
  99. args = get_args(field_type)
  100. for arg in args:
  101. if isinstance(arg, type) and issubclass(arg, Enum):
  102. result[field_name] = cls.convert_to_enum(
  103. arg, result[field_name]
  104. )
  105. break
  106. # 直接处理Enum类型
  107. elif isinstance(field_type, type) and issubclass(field_type, Enum):
  108. result[field_name] = cls.convert_to_enum(field_type, result[field_name])
  109. return result
  110. class Config:
  111. from_attributes = True
  112. class CreateDtoBase(DtoBase):
  113. __abstract__ = True
  114. created_by: Optional[str] = None
  115. created_time: Optional[datetime] = None
  116. datetime_fields: ClassVar[List[str]] = ["created_time"]
  117. class UpdateDtoBase(CreateDtoBase):
  118. __abstract__ = True
  119. updated_by: Optional[str] = None
  120. updated_time: Optional[datetime] = None
  121. datetime_fields: ClassVar[List[str]] = ["created_time", "updated_time"]
  122. class FullDtoBase(UpdateDtoBase):
  123. """完整DTO类,包含所有基础字段"""
  124. __abstract__ = True
  125. is_del: Optional[int] = 0
  126. deleted_by: Optional[str] = None
  127. deleted_time: Optional[datetime] = None
  128. # 默认需要格式化的日期时间字段
  129. datetime_fields: ClassVar[List[str]] = [
  130. "created_time",
  131. "updated_time",
  132. "deleted_time",
  133. ]
  134. class PageDto(BaseModel):
  135. """分页查询DTO"""
  136. page_num: int = 1
  137. page_size: int = 10
  138. order: Optional[str] = None
  139. filters: Optional[Dict[str, Any]] = None
  140. filter_conditions: Optional[Dict[str, Any]] = None
  141. search: Optional[str] = None
  142. @property
  143. def offset(self):
  144. return (self.page_num - 1) * self.page_size
  145. @property
  146. def limit(self):
  147. return self.page_size
  148. GetSchemaType = TypeVar("GetSchemaType", bound=DtoBase)
  149. class PageResultDto(BaseModel, Generic[GetSchemaType]):
  150. """分页查询结果DTO"""
  151. total: int
  152. rows: List[GetSchemaType]
  153. page_size: Optional[int]