12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485 |
- from typing import TypeVar, Generic, Optional, Dict, Any
- from sqlalchemy.orm import Session
- from sqlalchemy import and_
- from ..models.base_model import BaseModel, DeleteModel
T = TypeVar('T', bound=BaseModel)


class BaseStore(Generic[T]):
    """
    Generic store providing CRUD operations for a single model class.

    Every mutating method (create / update / delete / restore) commits the
    session immediately; if a commit fails the session is rolled back and the
    exception re-raised. When ``model`` is a subclass of ``DeleteModel`` its
    rows are soft-deletable through the ``is_deleted`` column, and the read
    helpers hide soft-deleted rows unless ``include_deleted=True``.
    """

    def __init__(self, db: Session, model: type[T]):
        """
        Args:
            db: Session used for every query and commit made by this store.
            model: Model class this store operates on.
        """
        self.db = db
        self.model = model

    def _soft_deletable(self) -> bool:
        """Return True if the model supports soft deletion (``is_deleted``)."""
        return issubclass(self.model, DeleteModel)

    def _exclude_deleted(self, query):
        """Return ``query`` filtered to non-deleted rows when applicable."""
        if self._soft_deletable():
            # This builds a SQL expression; `is False` would not work here.
            query = query.filter(self.model.is_deleted == False)  # noqa: E712
        return query

    def _commit(self) -> None:
        """Commit the session; roll back and re-raise if the commit fails."""
        try:
            self.db.commit()
        except Exception:
            # A failed flush/commit leaves the Session unusable until
            # rollback() is called explicitly.
            self.db.rollback()
            raise

    def get(self, id: int, include_deleted: bool = False) -> Optional[T]:
        """
        Fetch a single object by its ``id`` column.

        Args:
            id: Value compared against ``model.id``.
            include_deleted: When False (default), soft-deleted rows are
                treated as missing.

        Returns:
            The first matching object, or None if there is no visible match.
        """
        query = self.db.query(self.model).filter(self.model.id == id)
        if not include_deleted:
            query = self._exclude_deleted(query)
        return query.first()

    def get_all(self,
                skip: int = 0,
                limit: int = 100,
                include_deleted: bool = False) -> list[T]:
        """
        Return one page of objects.

        Args:
            skip: Number of rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
            include_deleted: When False (default), soft-deleted rows are
                excluded.

        Returns:
            Up to ``limit`` objects, ordered by ``id``.
        """
        query = self.db.query(self.model)
        if not include_deleted:
            query = self._exclude_deleted(query)
        # OFFSET/LIMIT without ORDER BY gives an unspecified row order, so
        # consecutive pages could overlap or skip rows; order by id instead.
        return query.order_by(self.model.id).offset(skip).limit(limit).all()

    def create(self, obj: T,
               unique_fields: Optional[Dict[str, Any]] = None) -> T:
        """
        Add ``obj`` to the session, commit, and return it refreshed.

        Args:
            obj: Model instance to persist.
            unique_fields: Optional mapping of column name -> value; each
                pair must not already be used by another non-deleted row.

        Returns:
            ``obj`` after commit and refresh.

        Raises:
            ValueError: If a ``unique_fields`` entry is already taken.
        """
        if unique_fields:
            # Runs before obj joins the session, so a failure here leaves
            # the session untouched.
            self._check_unique_constraints(obj, unique_fields)
        self.db.add(obj)
        self._commit()
        self.db.refresh(obj)
        return obj

    def update(self, obj: T,
               unique_fields: Optional[Dict[str, Any]] = None) -> T:
        """
        Commit pending changes made to ``obj`` and return it refreshed.

        Args:
            obj: Instance already tracked by this store's session whose
                attributes the caller has modified.
            unique_fields: Optional mapping of column name -> value; each
                pair must not be used by any other non-deleted row.

        Returns:
            ``obj`` after commit and refresh.

        Raises:
            ValueError: If a ``unique_fields`` entry is already taken. The
                session is rolled back, discarding obj's pending changes.
        """
        try:
            if unique_fields:
                self._check_unique_constraints(obj, unique_fields)
            self.db.commit()
        except Exception:
            # obj's edits are already tracked by the session; without a
            # rollback a later commit on this session would persist the
            # rejected update.
            self.db.rollback()
            raise
        self.db.refresh(obj)
        return obj

    def delete(self, id: int, soft_delete: bool = True) -> None:
        """
        Delete the object with the given ``id``; no-op if it is not found.

        Args:
            id: Value compared against ``model.id``.
            soft_delete: When True and the model is soft-deletable, set
                ``is_deleted`` instead of removing the row. When False (or the
                model is not soft-deletable) the row is removed; with False
                this also purges rows that were already soft-deleted.
        """
        # A hard delete must be able to see soft-deleted rows, otherwise
        # they could never be purged.
        obj = self.get(id, include_deleted=not soft_delete)
        if obj is None:
            return
        if soft_delete and self._soft_deletable():
            obj.is_deleted = True
        else:
            self.db.delete(obj)
        self._commit()

    def _check_unique_constraints(self, obj: T,
                                  unique_fields: Dict[str, Any]) -> None:
        """
        Raise ValueError if any field/value pair is used by another row.

        Each field is checked independently (not as a composite key). ``obj``
        itself is excluded when it already has an id, and soft-deleted rows
        are ignored.

        Raises:
            ValueError: ``"<field> must be unique"`` for the first collision.
        """
        # Disable autoflush: in update() obj's pending edits would otherwise
        # be flushed by this query, and a DB-level unique index would then
        # raise IntegrityError before the friendlier ValueError below.
        with self.db.no_autoflush:
            for field, value in unique_fields.items():
                query = self.db.query(self.model).filter(
                    getattr(self.model, field) == value)
                # `is not None` so that an id of 0 is still excluded.
                if obj.id is not None:
                    query = query.filter(self.model.id != obj.id)
                query = self._exclude_deleted(query)
                if query.first() is not None:
                    raise ValueError(f"{field} must be unique")

    def restore(self, id: int) -> Optional[T]:
        """
        Undo a soft delete.

        Returns:
            The restored object, or None if the model is not soft-deletable,
            the row does not exist, or it is not currently deleted.
        """
        if not self._soft_deletable():
            return None
        obj = self.get(id, include_deleted=True)
        if obj is None or not obj.is_deleted:
            return None
        obj.is_deleted = False
        self._commit()
        return obj
|