| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894 |
- package gormadapter
- import (
- "context"
- "errors"
- "fmt"
- "runtime"
- "strings"
- "github.com/casbin/casbin/v2/model"
- "github.com/casbin/casbin/v2/persist"
- "gorm.io/gorm"
- "gorm.io/gorm/logger"
- "gorm.io/plugin/dbresolver"
- )
// Default database and table names used when callers do not supply their own.
const (
	defaultDatabaseName = "casbin"
	defaultTableName    = "casbin_rule"
)

// Context keys read by createTable/dropTable from the adapter's DB context.
const (
	// disableMigrateKey: when any value is stored under it, createTable skips
	// migration entirely (only presence is checked).
	disableMigrateKey = "disableMigrateKey"
	// customTableKey: holds a user-supplied model that is migrated/dropped
	// instead of CasbinRule.
	customTableKey = "customTableKey"
)
// CasbinRule is the GORM model for one stored policy line: an auto-increment
// id, the policy type, and up to eight rule values V0..V7.
type CasbinRule struct {
	ID                     uint   `gorm:"primaryKey;autoIncrement"`
	PType                  string `gorm:"size:100"`
	V0, V1, V2, V3, V4, V5 string `gorm:"size:100"`
	V6, V7                 string `gorm:"size:25"`
}

// TableName reports the default table name for CasbinRule. Adapter queries
// normally pick the actual table through a Table scope (casbinRuleTable).
func (*CasbinRule) TableName() string {
	const name = "casbin_rule"
	return name
}
// Filter restricts LoadFilteredPolicy to matching rows. Every non-empty
// slice becomes an "IN (...)" condition on its column (p_type, v0..v7);
// empty slices impose no constraint.
type Filter struct {
	PType                          []string
	V0, V1, V2, V3, V4, V5, V6, V7 []string
}
- // Adapter 策略存储的Gorm适配器
- type Adapter struct {
- driverName string
- dataSourceName string
- databaseName string
- tablePrefix string
- tableName string
- dbSpecified bool
- db *gorm.DB
- isFiltered bool
- }
- // finalizer Adapter的析构函数
- func finalizer(a *Adapter) {
- sqlDB, err := a.db.DB()
- if err != nil {
- panic(err)
- }
- err = sqlDB.Close()
- if err != nil {
- panic(err)
- }
- }
- // 根据表名选择conn(use map store name-index)
- type specificPolicy int
- func (p *specificPolicy) Resolve(connPools []gorm.ConnPool) gorm.ConnPool {
- return connPools[*p]
- }
- type DbPool struct {
- dbMap map[string]specificPolicy
- policy *specificPolicy
- source *gorm.DB
- }
- func (dbPool *DbPool) switchDb(dbName string) *gorm.DB {
- *dbPool.policy = dbPool.dbMap[dbName]
- return dbPool.source.Clauses(dbresolver.Write)
- }
- // NewAdapter Adapter的构造函数
- // Params : databaseName,tableName,dbSpecified
- //
- // databaseName,{tableName/dbSpecified}
- // {database/dbSpecified}
- //
- // databaseName and tableName are user defined.
- // Their default value are "casbin" and "casbin_rule"
- //
- // dbSpecified is an optional bool parameter. The default value is false.
- // 是否在dataSourceName中指定了一个现有的DB。
- // 如果dbSpecified==true,则需要确保dataSourceName中的DB存在。
- // 如果dbSpecified==false,适配器将自动创建一个名为databaseName的数据库。
- func NewAdapter(driverName string, dataSourceName string, params ...interface{}) (*Adapter, error) {
- a := &Adapter{}
- a.driverName = driverName
- a.dataSourceName = dataSourceName
- a.tableName = defaultTableName
- a.databaseName = defaultDatabaseName
- a.dbSpecified = false
- if len(params) == 1 {
- switch p1 := params[0].(type) {
- case bool:
- a.dbSpecified = p1
- case string:
- a.databaseName = p1
- default:
- return nil, errors.New("wrong format")
- }
- } else if len(params) == 2 {
- switch p2 := params[1].(type) {
- case bool:
- a.dbSpecified = p2
- p1, ok := params[0].(string)
- if !ok {
- return nil, errors.New("wrong format")
- }
- a.databaseName = p1
- case string:
- p1, ok := params[0].(string)
- if !ok {
- return nil, errors.New("wrong format")
- }
- a.databaseName = p1
- a.tableName = p2
- default:
- return nil, errors.New("wrong format")
- }
- } else if len(params) == 3 {
- if p3, ok := params[2].(bool); ok {
- a.dbSpecified = p3
- a.databaseName = params[0].(string)
- a.tableName = params[1].(string)
- } else {
- return nil, errors.New("wrong format")
- }
- } else if len(params) != 0 {
- return nil, errors.New("too many parameters")
- }
- // Open the DB, create it if not existed.
- err := a.Open()
- if err != nil {
- return nil, err
- }
- // Call the destructor when the object is released.
- runtime.SetFinalizer(a, finalizer)
- return a, nil
- }
- // NewAdapterByDBUseTableName creates gorm-adapter by an existing Gorm instance and the specified table prefix and table name
- // Example: gormadapter.NewAdapterByDBUseTableName(&db, "cms", "casbin") Automatically generate table name like this "cms_casbin"
- func NewAdapterByDBUseTableName(db *gorm.DB, prefix string, tableName string) (*Adapter, error) {
- if len(tableName) == 0 {
- tableName = defaultTableName
- }
- a := &Adapter{
- tablePrefix: prefix,
- tableName: tableName,
- }
- a.db = db.Scopes(a.casbinRuleTable()).Session(&gorm.Session{Context: db.Statement.Context})
- err := a.createTable()
- if err != nil {
- return a, err
- }
- return a, nil
- }
- // InitDbResolver 多个数据源支持
- // Example usage:
- // dbPool,err := InitDbResolver([]gorm.Dialector{mysql.Open(dsn),mysql.Open(dsn2)},[]string{"casbin1","casbin2"})
- // a := initAdapterWithGormInstanceByMulDb(t,dbPool,"casbin1","","casbin_rule1")
- // a = initAdapterWithGormInstanceByMulDb(t,dbPool,"casbin2","","casbin_rule2")/*
- func InitDbResolver(dbArr []gorm.Dialector, dbNames []string) (DbPool, error) {
- if len(dbArr) == 0 {
- panic("dbArr len is 0")
- }
- source, e := gorm.Open(dbArr[0])
- if e != nil {
- panic(e.Error())
- }
- var p specificPolicy
- p = 0
- err := source.Use(dbresolver.Register(dbresolver.Config{Policy: &p, Sources: dbArr}))
- dbMap := make(map[string]specificPolicy)
- for i := 0; i < len(dbNames); i++ {
- dbMap[dbNames[i]] = specificPolicy(i)
- }
- return DbPool{dbMap: dbMap, policy: &p, source: source}, err
- }
- func NewAdapterByMulDb(dbPool DbPool, dbName string, prefix string, tableName string) (*Adapter, error) {
- //change DB
- dbPool.switchDb(dbName)
- return NewAdapterByDBUseTableName(dbPool.source, prefix, tableName)
- }
- // NewFilteredAdapter FilteredAdapter 的构造函数.
- // Casbin will not automatically call LoadPolicy() for a filtered adapter.
- // Casbin不会自动为已筛选的适配器调用LoadPolicy()
- func NewFilteredAdapter(driverName string, dataSourceName string, params ...interface{}) (*Adapter, error) {
- adapter, err := NewAdapter(driverName, dataSourceName, params...)
- if err != nil {
- return nil, err
- }
- adapter.isFiltered = true
- return adapter, err
- }
- // NewAdapterByDB creates gorm-adapter by an existing Gorm instance
- func NewAdapterByDB(db *gorm.DB) (*Adapter, error) {
- return NewAdapterByDBUseTableName(db, "", defaultTableName)
- }
- func TurnOffAutoMigrate(db *gorm.DB) {
- ctx := db.Statement.Context
- if ctx == nil {
- ctx = context.Background()
- }
- ctx = context.WithValue(ctx, disableMigrateKey, false)
- *db = *db.WithContext(ctx)
- }
- func NewAdapterByDBWithCustomTable(db *gorm.DB, t interface{}, tableName ...string) (*Adapter, error) {
- ctx := db.Statement.Context
- if ctx == nil {
- ctx = context.Background()
- }
- ctx = context.WithValue(ctx, customTableKey, t)
- curTableName := defaultTableName
- if len(tableName) > 0 {
- curTableName = tableName[0]
- }
- return NewAdapterByDBUseTableName(db.WithContext(ctx), "", curTableName)
- }
- func openDBConnection(driverName, dataSourceName string) (*gorm.DB, error) {
- driver, ok := opens[driverName]
- if !ok {
- return nil, errors.New("database dialect is not supported")
- }
- return gorm.Open(driver(dataSourceName), &gorm.Config{})
- }
- func (a *Adapter) createDatabase() error {
- var err error
- db, err := openDBConnection(a.driverName, a.dataSourceName)
- if err != nil {
- return err
- }
- if a.driverName == "postgres" {
- if err = db.Exec("CREATE DATABASE " + a.databaseName).Error; err != nil {
- // 42P04 is duplicate_database
- if strings.Contains(fmt.Sprintf("%s", err), "42P04") {
- return nil
- }
- }
- } else if a.driverName != "sqlite3" {
- err = db.Exec("CREATE DATABASE IF NOT EXISTS " + a.databaseName).Error
- }
- if err != nil {
- return err
- }
- return nil
- }
- func (a *Adapter) Open() error {
- var err error
- var db *gorm.DB
- if a.dbSpecified {
- db, err = openDBConnection(a.driverName, a.dataSourceName)
- if err != nil {
- return err
- }
- } else {
- if err = a.createDatabase(); err != nil {
- return err
- }
- if a.driverName == "postgres" {
- db, err = openDBConnection(a.driverName, a.dataSourceName+" dbname="+a.databaseName)
- } else if a.driverName == "sqlite3" {
- db, err = openDBConnection(a.driverName, a.dataSourceName)
- } else {
- db, err = openDBConnection(a.driverName, a.dataSourceName+a.databaseName)
- }
- if err != nil {
- return err
- }
- }
- a.db = db.Scopes(a.casbinRuleTable()).Session(&gorm.Session{})
- return a.createTable()
- }
- // AddLogger adds logger to db
- func (a *Adapter) AddLogger(l logger.Interface) {
- a.db = a.db.Session(&gorm.Session{Logger: l, Context: a.db.Statement.Context})
- }
- func (a *Adapter) Close() error {
- finalizer(a)
- return nil
- }
- // getTableInstance return the dynamic table name
- func (a *Adapter) getTableInstance() *CasbinRule {
- return &CasbinRule{}
- }
- func (a *Adapter) getFullTableName() string {
- if a.tablePrefix != "" {
- return a.tablePrefix + "_" + a.tableName
- }
- return a.tableName
- }
- func (a *Adapter) casbinRuleTable() func(db *gorm.DB) *gorm.DB {
- return func(db *gorm.DB) *gorm.DB {
- tableName := a.getFullTableName()
- return db.Table(tableName)
- }
- }
- func (a *Adapter) createTable() error {
- disableMigrate := a.db.Statement.Context.Value(disableMigrateKey)
- if disableMigrate != nil {
- return nil
- }
- t := a.db.Statement.Context.Value(customTableKey)
- if t != nil {
- return a.db.AutoMigrate(t)
- }
- t = a.getTableInstance()
- if err := a.db.AutoMigrate(t); err != nil {
- return err
- }
- tableName := a.getFullTableName()
- index := strings.ReplaceAll("idx_"+tableName, ".", "_")
- hasIndex := a.db.Migrator().HasIndex(t, index)
- if !hasIndex {
- if err := a.db.Exec(fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (p_type,v0,v1,v2,v3,v4,v5,v6,v7)", index, tableName)).Error; err != nil {
- return err
- }
- }
- return nil
- }
- func (a *Adapter) dropTable() error {
- t := a.db.Statement.Context.Value(customTableKey)
- if t == nil {
- return a.db.Migrator().DropTable(a.getTableInstance())
- }
- return a.db.Migrator().DropTable(t)
- }
- func (a *Adapter) truncateTable() error {
- if a.driverName == "sqlite3" {
- return a.db.Exec(fmt.Sprintf("DELETE FROM %s", a.getFullTableName())).Error
- }
- return a.db.Exec(fmt.Sprintf("TRUNCATE TABLE %s", a.getFullTableName())).Error
- }
- func loadPolicyLine(line CasbinRule, model model.Model) {
- var p = []string{line.PType,
- line.V0, line.V1, line.V2,
- line.V3, line.V4, line.V5,
- line.V6, line.V7}
- index := len(p) - 1
- for p[index] == "" {
- index--
- }
- index += 1
- p = p[:index]
- err := persist.LoadPolicyArray(p, model)
- if err != nil {
- return
- }
- }
- // LoadPolicy 从数据库加载策略
- func (a *Adapter) LoadPolicy(model model.Model) error {
- var lines []CasbinRule
- if err := a.db.Order("Id").Find(&lines).Error; err != nil {
- return err
- }
- for _, line := range lines {
- loadPolicyLine(line, model)
- }
- return nil
- }
- // LoadFilteredPolicy 仅加载与筛选器匹配的策略规则
- func (a *Adapter) LoadFilteredPolicy(model model.Model, filter interface{}) error {
- var lines []CasbinRule
- filterValue, ok := filter.(Filter)
- if !ok {
- return errors.New("invalid filter type")
- }
- if err := a.db.Scopes(a.filterQuery(a.db, filterValue)).Order("Id").Find(&lines).Error; err != nil {
- return err
- }
- for _, line := range lines {
- loadPolicyLine(line, model)
- }
- a.isFiltered = true
- return nil
- }
- // IsFiltered returns true if the loaded policy has been filtered.
- func (a *Adapter) IsFiltered() bool {
- return a.isFiltered
- }
- // filterQuery builds the gorm query to match the rule filter to use within a scope.
- func (a *Adapter) filterQuery(db *gorm.DB, filter Filter) func(db *gorm.DB) *gorm.DB {
- return func(db *gorm.DB) *gorm.DB {
- if len(filter.PType) > 0 {
- db = db.Where("p_type in (?)", filter.PType)
- }
- if len(filter.V0) > 0 {
- db = db.Where("v0 in (?)", filter.V0)
- }
- if len(filter.V1) > 0 {
- db = db.Where("v1 in (?)", filter.V1)
- }
- if len(filter.V2) > 0 {
- db = db.Where("v2 in (?)", filter.V2)
- }
- if len(filter.V3) > 0 {
- db = db.Where("v3 in (?)", filter.V3)
- }
- if len(filter.V4) > 0 {
- db = db.Where("v4 in (?)", filter.V4)
- }
- if len(filter.V5) > 0 {
- db = db.Where("v5 in (?)", filter.V5)
- }
- if len(filter.V6) > 0 {
- db = db.Where("v6 in (?)", filter.V6)
- }
- if len(filter.V7) > 0 {
- db = db.Where("v7 in (?)", filter.V7)
- }
- return db
- }
- }
- func (a *Adapter) savePolicyLine(pType string, rule []string) CasbinRule {
- line := a.getTableInstance()
- line.PType = pType
- if len(rule) > 0 {
- line.V0 = rule[0]
- }
- if len(rule) > 1 {
- line.V1 = rule[1]
- }
- if len(rule) > 2 {
- line.V2 = rule[2]
- }
- if len(rule) > 3 {
- line.V3 = rule[3]
- }
- if len(rule) > 4 {
- line.V4 = rule[4]
- }
- if len(rule) > 5 {
- line.V5 = rule[5]
- }
- if len(rule) > 6 {
- line.V6 = rule[6]
- }
- if len(rule) > 7 {
- line.V7 = rule[7]
- }
- return *line
- }
- // SavePolicy 将策略保存到数据库
- func (a *Adapter) SavePolicy(model model.Model) error {
- err := a.truncateTable()
- if err != nil {
- return err
- }
- var lines []CasbinRule
- flushEvery := 1000
- for pType, ast := range model["p"] {
- for _, rule := range ast.Policy {
- lines = append(lines, a.savePolicyLine(pType, rule))
- if len(lines) > flushEvery {
- if err := a.db.Create(&lines).Error; err != nil {
- return err
- }
- lines = nil
- }
- }
- }
- for pType, ast := range model["g"] {
- for _, rule := range ast.Policy {
- lines = append(lines, a.savePolicyLine(pType, rule))
- if len(lines) > flushEvery {
- if err := a.db.Create(&lines).Error; err != nil {
- return err
- }
- lines = nil
- }
- }
- }
- if len(lines) > 0 {
- if err := a.db.Create(&lines).Error; err != nil {
- return err
- }
- }
- return nil
- }
- // AddPolicy 将策略规则添加到存储中
- func (a *Adapter) AddPolicy(sec string, pType string, rule []string) error {
- line := a.savePolicyLine(pType, rule)
- err := a.db.Create(&line).Error
- return err
- }
- // RemovePolicy 从存储中删除策略规则。
- func (a *Adapter) RemovePolicy(sec string, pType string, rule []string) error {
- line := a.savePolicyLine(pType, rule)
- err := a.rawDelete(a.db, line) //can't use db.Delete as we're not using primary key http://jinzhu.me/gorm/crud.html#delete
- return err
- }
- // AddPolicies adds multiple policy rules to the storage.
- func (a *Adapter) AddPolicies(sec string, pType string, rules [][]string) error {
- var lines []CasbinRule
- for _, rule := range rules {
- line := a.savePolicyLine(pType, rule)
- lines = append(lines, line)
- }
- return a.db.Create(&lines).Error
- }
- // RemovePolicies removes multiple policy rules from the storage.
- func (a *Adapter) RemovePolicies(sec string, pType string, rules [][]string) error {
- return a.db.Transaction(func(tx *gorm.DB) error {
- for _, rule := range rules {
- line := a.savePolicyLine(pType, rule)
- if err := a.rawDelete(tx, line); err != nil { //can't use db.Delete as we're not using primary key http://jinzhu.me/gorm/crud.html#delete
- return err
- }
- }
- return nil
- })
- }
- // RemoveFilteredPolicy removes policy rules that match the filter from the storage.
- func (a *Adapter) RemoveFilteredPolicy(sec string, pType string, fieldIndex int, fieldValues ...string) error {
- line := a.getTableInstance()
- line.PType = pType
- if fieldIndex == -1 {
- return a.rawDelete(a.db, *line)
- }
- err := checkQueryField(fieldValues)
- if err != nil {
- return err
- }
- if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {
- line.V0 = fieldValues[0-fieldIndex]
- }
- if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) {
- line.V1 = fieldValues[1-fieldIndex]
- }
- if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) {
- line.V2 = fieldValues[2-fieldIndex]
- }
- if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) {
- line.V3 = fieldValues[3-fieldIndex]
- }
- if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) {
- line.V4 = fieldValues[4-fieldIndex]
- }
- if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
- line.V5 = fieldValues[5-fieldIndex]
- }
- if fieldIndex <= 6 && 6 < fieldIndex+len(fieldValues) {
- line.V6 = fieldValues[6-fieldIndex]
- }
- if fieldIndex <= 7 && 7 < fieldIndex+len(fieldValues) {
- line.V7 = fieldValues[7-fieldIndex]
- }
- err = a.rawDelete(a.db, *line)
- return err
- }
// checkQueryField makes sure at least one of fieldValues is non-empty, so a
// filtered delete cannot degrade into matching on p_type alone. A nil or
// empty slice is rejected as well.
func checkQueryField(fieldValues []string) error {
	for i := range fieldValues {
		if len(fieldValues[i]) > 0 {
			return nil
		}
	}
	return errors.New("the query field cannot all be empty string (\"\"), please check")
}
- func (a *Adapter) rawDelete(db *gorm.DB, line CasbinRule) error {
- queryArgs := []interface{}{line.PType}
- queryStr := "p_type = ?"
- if line.V0 != "" {
- queryStr += " and v0 = ?"
- queryArgs = append(queryArgs, line.V0)
- }
- if line.V1 != "" {
- queryStr += " and v1 = ?"
- queryArgs = append(queryArgs, line.V1)
- }
- if line.V2 != "" {
- queryStr += " and v2 = ?"
- queryArgs = append(queryArgs, line.V2)
- }
- if line.V3 != "" {
- queryStr += " and v3 = ?"
- queryArgs = append(queryArgs, line.V3)
- }
- if line.V4 != "" {
- queryStr += " and v4 = ?"
- queryArgs = append(queryArgs, line.V4)
- }
- if line.V5 != "" {
- queryStr += " and v5 = ?"
- queryArgs = append(queryArgs, line.V5)
- }
- if line.V6 != "" {
- queryStr += " and v6 = ?"
- queryArgs = append(queryArgs, line.V6)
- }
- if line.V7 != "" {
- queryStr += " and v7 = ?"
- queryArgs = append(queryArgs, line.V7)
- }
- args := append([]interface{}{queryStr}, queryArgs...)
- err := db.Delete(a.getTableInstance(), args...).Error
- return err
- }
- func appendWhere(line CasbinRule) (string, []interface{}) {
- queryArgs := []interface{}{line.PType}
- queryStr := "p_type = ?"
- if line.V0 != "" {
- queryStr += " and v0 = ?"
- queryArgs = append(queryArgs, line.V0)
- }
- if line.V1 != "" {
- queryStr += " and v1 = ?"
- queryArgs = append(queryArgs, line.V1)
- }
- if line.V2 != "" {
- queryStr += " and v2 = ?"
- queryArgs = append(queryArgs, line.V2)
- }
- if line.V3 != "" {
- queryStr += " and v3 = ?"
- queryArgs = append(queryArgs, line.V3)
- }
- if line.V4 != "" {
- queryStr += " and v4 = ?"
- queryArgs = append(queryArgs, line.V4)
- }
- if line.V5 != "" {
- queryStr += " and v5 = ?"
- queryArgs = append(queryArgs, line.V5)
- }
- if line.V6 != "" {
- queryStr += " and v6 = ?"
- queryArgs = append(queryArgs, line.V6)
- }
- if line.V7 != "" {
- queryStr += " and v7 = ?"
- queryArgs = append(queryArgs, line.V7)
- }
- return queryStr, queryArgs
- }
- // UpdatePolicy updates a new policy rule to DB.
- func (a *Adapter) UpdatePolicy(sec string, pType string, oldRule, newPolicy []string) error {
- oldLine := a.savePolicyLine(pType, oldRule)
- newLine := a.savePolicyLine(pType, newPolicy)
- return a.db.Model(&oldLine).Where(&oldLine).Updates(newLine).Error
- }
- func (a *Adapter) UpdatePolicies(sec string, pType string, oldRules, newRules [][]string) error {
- oldPolicies := make([]CasbinRule, 0, len(oldRules))
- newPolicies := make([]CasbinRule, 0, len(oldRules))
- for _, oldRule := range oldRules {
- oldPolicies = append(oldPolicies, a.savePolicyLine(pType, oldRule))
- }
- for _, newRule := range newRules {
- newPolicies = append(newPolicies, a.savePolicyLine(pType, newRule))
- }
- tx := a.db.Begin()
- for i := range oldPolicies {
- if err := tx.Model(&oldPolicies[i]).Where(&oldPolicies[i]).Updates(newPolicies[i]).Error; err != nil {
- tx.Rollback()
- return err
- }
- }
- return tx.Commit().Error
- }
- func (a *Adapter) UpdateFilteredPolicies(sec string, pType string, newPolicies [][]string, fieldIndex int, fieldValues ...string) ([][]string, error) {
- // UpdateFilteredPolicies deletes old rules and adds new rules.
- line := a.getTableInstance()
- line.PType = pType
- if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {
- line.V0 = fieldValues[0-fieldIndex]
- }
- if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) {
- line.V1 = fieldValues[1-fieldIndex]
- }
- if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) {
- line.V2 = fieldValues[2-fieldIndex]
- }
- if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) {
- line.V3 = fieldValues[3-fieldIndex]
- }
- if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) {
- line.V4 = fieldValues[4-fieldIndex]
- }
- if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
- line.V5 = fieldValues[5-fieldIndex]
- }
- if fieldIndex <= 6 && 6 < fieldIndex+len(fieldValues) {
- line.V6 = fieldValues[6-fieldIndex]
- }
- if fieldIndex <= 7 && 7 < fieldIndex+len(fieldValues) {
- line.V7 = fieldValues[7-fieldIndex]
- }
- newP := make([]CasbinRule, 0, len(newPolicies))
- oldP := make([]CasbinRule, 0)
- for _, newRule := range newPolicies {
- newP = append(newP, a.savePolicyLine(pType, newRule))
- }
- tx := a.db.Begin()
- for i := range newP {
- str, args := line.queryString()
- if err := tx.Where(str, args...).Find(&oldP).Error; err != nil {
- tx.Rollback()
- return nil, err
- }
- if err := tx.Where(str, args...).Delete([]CasbinRule{}).Error; err != nil {
- tx.Rollback()
- return nil, err
- }
- if err := tx.Create(&newP[i]).Error; err != nil {
- tx.Rollback()
- return nil, err
- }
- }
- // return deleted rulues
- oldPolicies := make([][]string, 0)
- for _, v := range oldP {
- oldPolicy := v.toStringPolicy()
- oldPolicies = append(oldPolicies, oldPolicy)
- }
- return oldPolicies, tx.Commit().Error
- }
- func (c *CasbinRule) queryString() (interface{}, []interface{}) {
- queryArgs := []interface{}{c.PType}
- queryStr := "p_type = ?"
- if c.V0 != "" {
- queryStr += " and v0 = ?"
- queryArgs = append(queryArgs, c.V0)
- }
- if c.V1 != "" {
- queryStr += " and v1 = ?"
- queryArgs = append(queryArgs, c.V1)
- }
- if c.V2 != "" {
- queryStr += " and v2 = ?"
- queryArgs = append(queryArgs, c.V2)
- }
- if c.V3 != "" {
- queryStr += " and v3 = ?"
- queryArgs = append(queryArgs, c.V3)
- }
- if c.V4 != "" {
- queryStr += " and v4 = ?"
- queryArgs = append(queryArgs, c.V4)
- }
- if c.V5 != "" {
- queryStr += " and v5 = ?"
- queryArgs = append(queryArgs, c.V5)
- }
- if c.V6 != "" {
- queryStr += " and v6 = ?"
- queryArgs = append(queryArgs, c.V6)
- }
- if c.V7 != "" {
- queryStr += " and v7 = ?"
- queryArgs = append(queryArgs, c.V7)
- }
- return queryStr, queryArgs
- }
- func (c *CasbinRule) toStringPolicy() []string {
- policy := make([]string, 0)
- if c.PType != "" {
- policy = append(policy, c.PType)
- }
- if c.V0 != "" {
- policy = append(policy, c.V0)
- }
- if c.V1 != "" {
- policy = append(policy, c.V1)
- }
- if c.V2 != "" {
- policy = append(policy, c.V2)
- }
- if c.V3 != "" {
- policy = append(policy, c.V3)
- }
- if c.V4 != "" {
- policy = append(policy, c.V4)
- }
- if c.V5 != "" {
- policy = append(policy, c.V5)
- }
- if c.V6 != "" {
- policy = append(policy, c.V6)
- }
- if c.V7 != "" {
- policy = append(policy, c.V7)
- }
- return policy
- }
|