base_store.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. from typing import TypeVar, Generic, Optional, Dict, Any
  2. from sqlalchemy.orm import Session
  3. from sqlalchemy import and_
  4. from ..models.base_model import BaseModel, DeleteModel
  5. T = TypeVar('T', bound=BaseModel)
  6. class BaseStore(Generic[T]):
  7. """
  8. 基础存储类,提供CRUD操作
  9. """
  10. def __init__(self, db: Session, model: type[T]):
  11. self.db = db
  12. self.model = model
  13. def get(self, id: int, include_deleted: bool = False) -> Optional[T]:
  14. """根据ID获取单个对象"""
  15. query = self.db.query(self.model).filter(self.model.id == id)
  16. if not include_deleted and issubclass(self.model, DeleteModel):
  17. query = query.filter(self.model.is_deleted == False)
  18. return query.first()
  19. def get_all(self,
  20. skip: int = 0,
  21. limit: int = 100,
  22. include_deleted: bool = False) -> list[T]:
  23. """获取所有对象"""
  24. query = self.db.query(self.model)
  25. if not include_deleted and issubclass(self.model, DeleteModel):
  26. query = query.filter(self.model.is_deleted == False)
  27. return query.offset(skip).limit(limit).all()
  28. def create(self, obj: T, unique_fields: Dict[str, Any] = None) -> T:
  29. """创建新对象"""
  30. if unique_fields:
  31. self._check_unique_constraints(obj, unique_fields)
  32. self.db.add(obj)
  33. self.db.commit()
  34. self.db.refresh(obj)
  35. return obj
  36. def update(self, obj: T, unique_fields: Dict[str, Any] = None) -> T:
  37. """更新对象"""
  38. if unique_fields:
  39. self._check_unique_constraints(obj, unique_fields)
  40. self.db.commit()
  41. self.db.refresh(obj)
  42. return obj
  43. def delete(self, id: int, soft_delete: bool = True) -> None:
  44. """删除对象"""
  45. obj = self.get(id)
  46. if obj:
  47. if soft_delete and issubclass(self.model, DeleteModel):
  48. obj.is_deleted = True
  49. self.db.commit()
  50. else:
  51. self.db.delete(obj)
  52. self.db.commit()
  53. def _check_unique_constraints(self, obj: T,
  54. unique_fields: Dict[str, Any]) -> None:
  55. """检查唯一性约束"""
  56. for field, value in unique_fields.items():
  57. query = self.db.query(
  58. self.model).filter(getattr(self.model, field) == value)
  59. if obj.id:
  60. query = query.filter(self.model.id != obj.id)
  61. if issubclass(self.model, DeleteModel):
  62. query = query.filter(self.model.is_deleted == False)
  63. if query.first():
  64. raise ValueError(f"{field} must be unique")
  65. def restore(self, id: int) -> Optional[T]:
  66. """恢复软删除的对象"""
  67. if not issubclass(self.model, DeleteModel):
  68. return None
  69. obj = self.get(id, include_deleted=True)
  70. if obj and obj.is_deleted:
  71. obj.is_deleted = False
  72. self.db.commit()
  73. return obj
  74. return None