adapter.go 22 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894
  1. package gormadapter
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "runtime"
  7. "strings"
  8. "github.com/casbin/casbin/v2/model"
  9. "github.com/casbin/casbin/v2/persist"
  10. "gorm.io/gorm"
  11. "gorm.io/gorm/logger"
  12. "gorm.io/plugin/dbresolver"
  13. )
  const (
	// defaultDatabaseName is the database NewAdapter uses (and creates when
	// dbSpecified is false) if the caller does not name one.
	defaultDatabaseName = "casbin"
	// defaultTableName is the policy table name used when none is supplied.
	defaultTableName = "casbin_rule"

	// disableMigrateKey is the context key set by TurnOffAutoMigrate;
	// createTable skips migration whenever this key is present.
	disableMigrateKey = "disableMigrateKey"
	// customTableKey is the context key under which
	// NewAdapterByDBWithCustomTable stores a user-supplied table model.
	customTableKey = "customTableKey"
)
  // CasbinRule is one row of the policy table. PType stores the policy type key
// the rule was saved under, and V0..V7 store the rule's values in order.
// Values a rule does not use are left as empty strings.
type CasbinRule struct {
	ID    uint   `gorm:"primaryKey;autoIncrement"`
	PType string `gorm:"size:100"`
	V0    string `gorm:"size:100"`
	V1    string `gorm:"size:100"`
	V2    string `gorm:"size:100"`
	V3    string `gorm:"size:100"`
	V4    string `gorm:"size:100"`
	V5    string `gorm:"size:100"`
	// NOTE(review): V6/V7 are sized 25 while V0..V5 are 100; confirm this is
	// intentional before storing long values in the last two fields.
	V6 string `gorm:"size:25"`
	V7 string `gorm:"size:25"`
}
  32. func (*CasbinRule) TableName() string {
  33. return "casbin_rule"
  34. }
  // Filter selects which policy rows LoadFilteredPolicy loads. Each non-empty
// slice becomes an "IN (...)" condition on the matching column (p_type,
// v0..v7). The conditions are chained with Where, so they are ANDed together.
// Empty slices add no condition.
type Filter struct {
	PType []string
	V0    []string
	V1    []string
	V2    []string
	V3    []string
	V4    []string
	V5    []string
	V6    []string
	V7    []string
}
  // Adapter is the Gorm-backed storage adapter for casbin policies.
type Adapter struct {
	driverName     string   // driver key handed to openDBConnection
	dataSourceName string   // DSN; Open may append databaseName to it
	databaseName   string   // database created/used when dbSpecified is false
	tablePrefix    string   // optional prefix, joined to tableName with "_"
	tableName      string   // policy table name (defaults to "casbin_rule")
	dbSpecified    bool     // true when dataSourceName already names an existing DB
	db             *gorm.DB // session scoped to the policy table via casbinRuleTable
	isFiltered     bool     // set by NewFilteredAdapter and by LoadFilteredPolicy
}
  57. // finalizer Adapter的析构函数
  58. func finalizer(a *Adapter) {
  59. sqlDB, err := a.db.DB()
  60. if err != nil {
  61. panic(err)
  62. }
  63. err = sqlDB.Close()
  64. if err != nil {
  65. panic(err)
  66. }
  67. }
  68. // 根据表名选择conn(use map store name-index)
  69. type specificPolicy int
  70. func (p *specificPolicy) Resolve(connPools []gorm.ConnPool) gorm.ConnPool {
  71. return connPools[*p]
  72. }
  // DbPool bundles a gorm DB that has the dbresolver plugin registered over
// several sources with a name->index map used to pick one of them
// (see InitDbResolver and NewAdapterByMulDb).
type DbPool struct {
	dbMap  map[string]specificPolicy // database name -> index into the resolver's Sources
	policy *specificPolicy           // shared policy value; writing it switches the active source
	source *gorm.DB                  // DB opened from the first dialector, with dbresolver registered
}
  78. func (dbPool *DbPool) switchDb(dbName string) *gorm.DB {
  79. *dbPool.policy = dbPool.dbMap[dbName]
  80. return dbPool.source.Clauses(dbresolver.Write)
  81. }
  82. // NewAdapter Adapter的构造函数
  83. // Params : databaseName,tableName,dbSpecified
  84. //
  85. // databaseName,{tableName/dbSpecified}
  86. // {database/dbSpecified}
  87. //
  88. // databaseName and tableName are user defined.
  89. // Their default value are "casbin" and "casbin_rule"
  90. //
  91. // dbSpecified is an optional bool parameter. The default value is false.
  92. // 是否在dataSourceName中指定了一个现有的DB。
  93. // 如果dbSpecified==true,则需要确保dataSourceName中的DB存在。
  94. // 如果dbSpecified==false,适配器将自动创建一个名为databaseName的数据库。
  95. func NewAdapter(driverName string, dataSourceName string, params ...interface{}) (*Adapter, error) {
  96. a := &Adapter{}
  97. a.driverName = driverName
  98. a.dataSourceName = dataSourceName
  99. a.tableName = defaultTableName
  100. a.databaseName = defaultDatabaseName
  101. a.dbSpecified = false
  102. if len(params) == 1 {
  103. switch p1 := params[0].(type) {
  104. case bool:
  105. a.dbSpecified = p1
  106. case string:
  107. a.databaseName = p1
  108. default:
  109. return nil, errors.New("wrong format")
  110. }
  111. } else if len(params) == 2 {
  112. switch p2 := params[1].(type) {
  113. case bool:
  114. a.dbSpecified = p2
  115. p1, ok := params[0].(string)
  116. if !ok {
  117. return nil, errors.New("wrong format")
  118. }
  119. a.databaseName = p1
  120. case string:
  121. p1, ok := params[0].(string)
  122. if !ok {
  123. return nil, errors.New("wrong format")
  124. }
  125. a.databaseName = p1
  126. a.tableName = p2
  127. default:
  128. return nil, errors.New("wrong format")
  129. }
  130. } else if len(params) == 3 {
  131. if p3, ok := params[2].(bool); ok {
  132. a.dbSpecified = p3
  133. a.databaseName = params[0].(string)
  134. a.tableName = params[1].(string)
  135. } else {
  136. return nil, errors.New("wrong format")
  137. }
  138. } else if len(params) != 0 {
  139. return nil, errors.New("too many parameters")
  140. }
  141. // Open the DB, create it if not existed.
  142. err := a.Open()
  143. if err != nil {
  144. return nil, err
  145. }
  146. // Call the destructor when the object is released.
  147. runtime.SetFinalizer(a, finalizer)
  148. return a, nil
  149. }
  150. // NewAdapterByDBUseTableName creates gorm-adapter by an existing Gorm instance and the specified table prefix and table name
  151. // Example: gormadapter.NewAdapterByDBUseTableName(&db, "cms", "casbin") Automatically generate table name like this "cms_casbin"
  152. func NewAdapterByDBUseTableName(db *gorm.DB, prefix string, tableName string) (*Adapter, error) {
  153. if len(tableName) == 0 {
  154. tableName = defaultTableName
  155. }
  156. a := &Adapter{
  157. tablePrefix: prefix,
  158. tableName: tableName,
  159. }
  160. a.db = db.Scopes(a.casbinRuleTable()).Session(&gorm.Session{Context: db.Statement.Context})
  161. err := a.createTable()
  162. if err != nil {
  163. return a, err
  164. }
  165. return a, nil
  166. }
  167. // InitDbResolver 多个数据源支持
  168. // Example usage:
  169. // dbPool,err := InitDbResolver([]gorm.Dialector{mysql.Open(dsn),mysql.Open(dsn2)},[]string{"casbin1","casbin2"})
  170. // a := initAdapterWithGormInstanceByMulDb(t,dbPool,"casbin1","","casbin_rule1")
  171. // a = initAdapterWithGormInstanceByMulDb(t,dbPool,"casbin2","","casbin_rule2")/*
  172. func InitDbResolver(dbArr []gorm.Dialector, dbNames []string) (DbPool, error) {
  173. if len(dbArr) == 0 {
  174. panic("dbArr len is 0")
  175. }
  176. source, e := gorm.Open(dbArr[0])
  177. if e != nil {
  178. panic(e.Error())
  179. }
  180. var p specificPolicy
  181. p = 0
  182. err := source.Use(dbresolver.Register(dbresolver.Config{Policy: &p, Sources: dbArr}))
  183. dbMap := make(map[string]specificPolicy)
  184. for i := 0; i < len(dbNames); i++ {
  185. dbMap[dbNames[i]] = specificPolicy(i)
  186. }
  187. return DbPool{dbMap: dbMap, policy: &p, source: source}, err
  188. }
  189. func NewAdapterByMulDb(dbPool DbPool, dbName string, prefix string, tableName string) (*Adapter, error) {
  190. //change DB
  191. dbPool.switchDb(dbName)
  192. return NewAdapterByDBUseTableName(dbPool.source, prefix, tableName)
  193. }
  194. // NewFilteredAdapter FilteredAdapter 的构造函数.
  195. // Casbin will not automatically call LoadPolicy() for a filtered adapter.
  196. // Casbin不会自动为已筛选的适配器调用LoadPolicy()
  197. func NewFilteredAdapter(driverName string, dataSourceName string, params ...interface{}) (*Adapter, error) {
  198. adapter, err := NewAdapter(driverName, dataSourceName, params...)
  199. if err != nil {
  200. return nil, err
  201. }
  202. adapter.isFiltered = true
  203. return adapter, err
  204. }
  205. // NewAdapterByDB creates gorm-adapter by an existing Gorm instance
  206. func NewAdapterByDB(db *gorm.DB) (*Adapter, error) {
  207. return NewAdapterByDBUseTableName(db, "", defaultTableName)
  208. }
  209. func TurnOffAutoMigrate(db *gorm.DB) {
  210. ctx := db.Statement.Context
  211. if ctx == nil {
  212. ctx = context.Background()
  213. }
  214. ctx = context.WithValue(ctx, disableMigrateKey, false)
  215. *db = *db.WithContext(ctx)
  216. }
  217. func NewAdapterByDBWithCustomTable(db *gorm.DB, t interface{}, tableName ...string) (*Adapter, error) {
  218. ctx := db.Statement.Context
  219. if ctx == nil {
  220. ctx = context.Background()
  221. }
  222. ctx = context.WithValue(ctx, customTableKey, t)
  223. curTableName := defaultTableName
  224. if len(tableName) > 0 {
  225. curTableName = tableName[0]
  226. }
  227. return NewAdapterByDBUseTableName(db.WithContext(ctx), "", curTableName)
  228. }
  229. func openDBConnection(driverName, dataSourceName string) (*gorm.DB, error) {
  230. driver, ok := opens[driverName]
  231. if !ok {
  232. return nil, errors.New("database dialect is not supported")
  233. }
  234. return gorm.Open(driver(dataSourceName), &gorm.Config{})
  235. }
  236. func (a *Adapter) createDatabase() error {
  237. var err error
  238. db, err := openDBConnection(a.driverName, a.dataSourceName)
  239. if err != nil {
  240. return err
  241. }
  242. if a.driverName == "postgres" {
  243. if err = db.Exec("CREATE DATABASE " + a.databaseName).Error; err != nil {
  244. // 42P04 is duplicate_database
  245. if strings.Contains(fmt.Sprintf("%s", err), "42P04") {
  246. return nil
  247. }
  248. }
  249. } else if a.driverName != "sqlite3" {
  250. err = db.Exec("CREATE DATABASE IF NOT EXISTS " + a.databaseName).Error
  251. }
  252. if err != nil {
  253. return err
  254. }
  255. return nil
  256. }
  257. func (a *Adapter) Open() error {
  258. var err error
  259. var db *gorm.DB
  260. if a.dbSpecified {
  261. db, err = openDBConnection(a.driverName, a.dataSourceName)
  262. if err != nil {
  263. return err
  264. }
  265. } else {
  266. if err = a.createDatabase(); err != nil {
  267. return err
  268. }
  269. if a.driverName == "postgres" {
  270. db, err = openDBConnection(a.driverName, a.dataSourceName+" dbname="+a.databaseName)
  271. } else if a.driverName == "sqlite3" {
  272. db, err = openDBConnection(a.driverName, a.dataSourceName)
  273. } else {
  274. db, err = openDBConnection(a.driverName, a.dataSourceName+a.databaseName)
  275. }
  276. if err != nil {
  277. return err
  278. }
  279. }
  280. a.db = db.Scopes(a.casbinRuleTable()).Session(&gorm.Session{})
  281. return a.createTable()
  282. }
  283. // AddLogger adds logger to db
  284. func (a *Adapter) AddLogger(l logger.Interface) {
  285. a.db = a.db.Session(&gorm.Session{Logger: l, Context: a.db.Statement.Context})
  286. }
  287. func (a *Adapter) Close() error {
  288. finalizer(a)
  289. return nil
  290. }
  291. // getTableInstance return the dynamic table name
  292. func (a *Adapter) getTableInstance() *CasbinRule {
  293. return &CasbinRule{}
  294. }
  295. func (a *Adapter) getFullTableName() string {
  296. if a.tablePrefix != "" {
  297. return a.tablePrefix + "_" + a.tableName
  298. }
  299. return a.tableName
  300. }
  301. func (a *Adapter) casbinRuleTable() func(db *gorm.DB) *gorm.DB {
  302. return func(db *gorm.DB) *gorm.DB {
  303. tableName := a.getFullTableName()
  304. return db.Table(tableName)
  305. }
  306. }
  307. func (a *Adapter) createTable() error {
  308. disableMigrate := a.db.Statement.Context.Value(disableMigrateKey)
  309. if disableMigrate != nil {
  310. return nil
  311. }
  312. t := a.db.Statement.Context.Value(customTableKey)
  313. if t != nil {
  314. return a.db.AutoMigrate(t)
  315. }
  316. t = a.getTableInstance()
  317. if err := a.db.AutoMigrate(t); err != nil {
  318. return err
  319. }
  320. tableName := a.getFullTableName()
  321. index := strings.ReplaceAll("idx_"+tableName, ".", "_")
  322. hasIndex := a.db.Migrator().HasIndex(t, index)
  323. if !hasIndex {
  324. if err := a.db.Exec(fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (p_type,v0,v1,v2,v3,v4,v5,v6,v7)", index, tableName)).Error; err != nil {
  325. return err
  326. }
  327. }
  328. return nil
  329. }
  330. func (a *Adapter) dropTable() error {
  331. t := a.db.Statement.Context.Value(customTableKey)
  332. if t == nil {
  333. return a.db.Migrator().DropTable(a.getTableInstance())
  334. }
  335. return a.db.Migrator().DropTable(t)
  336. }
  337. func (a *Adapter) truncateTable() error {
  338. if a.driverName == "sqlite3" {
  339. return a.db.Exec(fmt.Sprintf("DELETE FROM %s", a.getFullTableName())).Error
  340. }
  341. return a.db.Exec(fmt.Sprintf("TRUNCATE TABLE %s", a.getFullTableName())).Error
  342. }
  343. func loadPolicyLine(line CasbinRule, model model.Model) {
  344. var p = []string{line.PType,
  345. line.V0, line.V1, line.V2,
  346. line.V3, line.V4, line.V5,
  347. line.V6, line.V7}
  348. index := len(p) - 1
  349. for p[index] == "" {
  350. index--
  351. }
  352. index += 1
  353. p = p[:index]
  354. err := persist.LoadPolicyArray(p, model)
  355. if err != nil {
  356. return
  357. }
  358. }
  359. // LoadPolicy 从数据库加载策略
  360. func (a *Adapter) LoadPolicy(model model.Model) error {
  361. var lines []CasbinRule
  362. if err := a.db.Order("Id").Find(&lines).Error; err != nil {
  363. return err
  364. }
  365. for _, line := range lines {
  366. loadPolicyLine(line, model)
  367. }
  368. return nil
  369. }
  370. // LoadFilteredPolicy 仅加载与筛选器匹配的策略规则
  371. func (a *Adapter) LoadFilteredPolicy(model model.Model, filter interface{}) error {
  372. var lines []CasbinRule
  373. filterValue, ok := filter.(Filter)
  374. if !ok {
  375. return errors.New("invalid filter type")
  376. }
  377. if err := a.db.Scopes(a.filterQuery(a.db, filterValue)).Order("Id").Find(&lines).Error; err != nil {
  378. return err
  379. }
  380. for _, line := range lines {
  381. loadPolicyLine(line, model)
  382. }
  383. a.isFiltered = true
  384. return nil
  385. }
  386. // IsFiltered returns true if the loaded policy has been filtered.
  387. func (a *Adapter) IsFiltered() bool {
  388. return a.isFiltered
  389. }
  390. // filterQuery builds the gorm query to match the rule filter to use within a scope.
  391. func (a *Adapter) filterQuery(db *gorm.DB, filter Filter) func(db *gorm.DB) *gorm.DB {
  392. return func(db *gorm.DB) *gorm.DB {
  393. if len(filter.PType) > 0 {
  394. db = db.Where("p_type in (?)", filter.PType)
  395. }
  396. if len(filter.V0) > 0 {
  397. db = db.Where("v0 in (?)", filter.V0)
  398. }
  399. if len(filter.V1) > 0 {
  400. db = db.Where("v1 in (?)", filter.V1)
  401. }
  402. if len(filter.V2) > 0 {
  403. db = db.Where("v2 in (?)", filter.V2)
  404. }
  405. if len(filter.V3) > 0 {
  406. db = db.Where("v3 in (?)", filter.V3)
  407. }
  408. if len(filter.V4) > 0 {
  409. db = db.Where("v4 in (?)", filter.V4)
  410. }
  411. if len(filter.V5) > 0 {
  412. db = db.Where("v5 in (?)", filter.V5)
  413. }
  414. if len(filter.V6) > 0 {
  415. db = db.Where("v6 in (?)", filter.V6)
  416. }
  417. if len(filter.V7) > 0 {
  418. db = db.Where("v7 in (?)", filter.V7)
  419. }
  420. return db
  421. }
  422. }
  423. func (a *Adapter) savePolicyLine(pType string, rule []string) CasbinRule {
  424. line := a.getTableInstance()
  425. line.PType = pType
  426. if len(rule) > 0 {
  427. line.V0 = rule[0]
  428. }
  429. if len(rule) > 1 {
  430. line.V1 = rule[1]
  431. }
  432. if len(rule) > 2 {
  433. line.V2 = rule[2]
  434. }
  435. if len(rule) > 3 {
  436. line.V3 = rule[3]
  437. }
  438. if len(rule) > 4 {
  439. line.V4 = rule[4]
  440. }
  441. if len(rule) > 5 {
  442. line.V5 = rule[5]
  443. }
  444. if len(rule) > 6 {
  445. line.V6 = rule[6]
  446. }
  447. if len(rule) > 7 {
  448. line.V7 = rule[7]
  449. }
  450. return *line
  451. }
  452. // SavePolicy 将策略保存到数据库
  453. func (a *Adapter) SavePolicy(model model.Model) error {
  454. err := a.truncateTable()
  455. if err != nil {
  456. return err
  457. }
  458. var lines []CasbinRule
  459. flushEvery := 1000
  460. for pType, ast := range model["p"] {
  461. for _, rule := range ast.Policy {
  462. lines = append(lines, a.savePolicyLine(pType, rule))
  463. if len(lines) > flushEvery {
  464. if err := a.db.Create(&lines).Error; err != nil {
  465. return err
  466. }
  467. lines = nil
  468. }
  469. }
  470. }
  471. for pType, ast := range model["g"] {
  472. for _, rule := range ast.Policy {
  473. lines = append(lines, a.savePolicyLine(pType, rule))
  474. if len(lines) > flushEvery {
  475. if err := a.db.Create(&lines).Error; err != nil {
  476. return err
  477. }
  478. lines = nil
  479. }
  480. }
  481. }
  482. if len(lines) > 0 {
  483. if err := a.db.Create(&lines).Error; err != nil {
  484. return err
  485. }
  486. }
  487. return nil
  488. }
  489. // AddPolicy 将策略规则添加到存储中
  490. func (a *Adapter) AddPolicy(sec string, pType string, rule []string) error {
  491. line := a.savePolicyLine(pType, rule)
  492. err := a.db.Create(&line).Error
  493. return err
  494. }
  495. // RemovePolicy 从存储中删除策略规则。
  496. func (a *Adapter) RemovePolicy(sec string, pType string, rule []string) error {
  497. line := a.savePolicyLine(pType, rule)
  498. err := a.rawDelete(a.db, line) //can't use db.Delete as we're not using primary key http://jinzhu.me/gorm/crud.html#delete
  499. return err
  500. }
  501. // AddPolicies adds multiple policy rules to the storage.
  502. func (a *Adapter) AddPolicies(sec string, pType string, rules [][]string) error {
  503. var lines []CasbinRule
  504. for _, rule := range rules {
  505. line := a.savePolicyLine(pType, rule)
  506. lines = append(lines, line)
  507. }
  508. return a.db.Create(&lines).Error
  509. }
  510. // RemovePolicies removes multiple policy rules from the storage.
  511. func (a *Adapter) RemovePolicies(sec string, pType string, rules [][]string) error {
  512. return a.db.Transaction(func(tx *gorm.DB) error {
  513. for _, rule := range rules {
  514. line := a.savePolicyLine(pType, rule)
  515. if err := a.rawDelete(tx, line); err != nil { //can't use db.Delete as we're not using primary key http://jinzhu.me/gorm/crud.html#delete
  516. return err
  517. }
  518. }
  519. return nil
  520. })
  521. }
  522. // RemoveFilteredPolicy removes policy rules that match the filter from the storage.
  523. func (a *Adapter) RemoveFilteredPolicy(sec string, pType string, fieldIndex int, fieldValues ...string) error {
  524. line := a.getTableInstance()
  525. line.PType = pType
  526. if fieldIndex == -1 {
  527. return a.rawDelete(a.db, *line)
  528. }
  529. err := checkQueryField(fieldValues)
  530. if err != nil {
  531. return err
  532. }
  533. if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {
  534. line.V0 = fieldValues[0-fieldIndex]
  535. }
  536. if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) {
  537. line.V1 = fieldValues[1-fieldIndex]
  538. }
  539. if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) {
  540. line.V2 = fieldValues[2-fieldIndex]
  541. }
  542. if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) {
  543. line.V3 = fieldValues[3-fieldIndex]
  544. }
  545. if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) {
  546. line.V4 = fieldValues[4-fieldIndex]
  547. }
  548. if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
  549. line.V5 = fieldValues[5-fieldIndex]
  550. }
  551. if fieldIndex <= 6 && 6 < fieldIndex+len(fieldValues) {
  552. line.V6 = fieldValues[6-fieldIndex]
  553. }
  554. if fieldIndex <= 7 && 7 < fieldIndex+len(fieldValues) {
  555. line.V7 = fieldValues[7-fieldIndex]
  556. }
  557. err = a.rawDelete(a.db, *line)
  558. return err
  559. }
  // checkQueryField returns an error unless at least one of fieldValues is
// a non-empty string. A nil or empty slice therefore also fails.
func checkQueryField(fieldValues []string) error {
	allEmpty := true
	for i := range fieldValues {
		if fieldValues[i] != "" {
			allEmpty = false
			break
		}
	}
	if allEmpty {
		return errors.New("the query field cannot all be empty string (\"\"), please check")
	}
	return nil
}
  569. func (a *Adapter) rawDelete(db *gorm.DB, line CasbinRule) error {
  570. queryArgs := []interface{}{line.PType}
  571. queryStr := "p_type = ?"
  572. if line.V0 != "" {
  573. queryStr += " and v0 = ?"
  574. queryArgs = append(queryArgs, line.V0)
  575. }
  576. if line.V1 != "" {
  577. queryStr += " and v1 = ?"
  578. queryArgs = append(queryArgs, line.V1)
  579. }
  580. if line.V2 != "" {
  581. queryStr += " and v2 = ?"
  582. queryArgs = append(queryArgs, line.V2)
  583. }
  584. if line.V3 != "" {
  585. queryStr += " and v3 = ?"
  586. queryArgs = append(queryArgs, line.V3)
  587. }
  588. if line.V4 != "" {
  589. queryStr += " and v4 = ?"
  590. queryArgs = append(queryArgs, line.V4)
  591. }
  592. if line.V5 != "" {
  593. queryStr += " and v5 = ?"
  594. queryArgs = append(queryArgs, line.V5)
  595. }
  596. if line.V6 != "" {
  597. queryStr += " and v6 = ?"
  598. queryArgs = append(queryArgs, line.V6)
  599. }
  600. if line.V7 != "" {
  601. queryStr += " and v7 = ?"
  602. queryArgs = append(queryArgs, line.V7)
  603. }
  604. args := append([]interface{}{queryStr}, queryArgs...)
  605. err := db.Delete(a.getTableInstance(), args...).Error
  606. return err
  607. }
  608. func appendWhere(line CasbinRule) (string, []interface{}) {
  609. queryArgs := []interface{}{line.PType}
  610. queryStr := "p_type = ?"
  611. if line.V0 != "" {
  612. queryStr += " and v0 = ?"
  613. queryArgs = append(queryArgs, line.V0)
  614. }
  615. if line.V1 != "" {
  616. queryStr += " and v1 = ?"
  617. queryArgs = append(queryArgs, line.V1)
  618. }
  619. if line.V2 != "" {
  620. queryStr += " and v2 = ?"
  621. queryArgs = append(queryArgs, line.V2)
  622. }
  623. if line.V3 != "" {
  624. queryStr += " and v3 = ?"
  625. queryArgs = append(queryArgs, line.V3)
  626. }
  627. if line.V4 != "" {
  628. queryStr += " and v4 = ?"
  629. queryArgs = append(queryArgs, line.V4)
  630. }
  631. if line.V5 != "" {
  632. queryStr += " and v5 = ?"
  633. queryArgs = append(queryArgs, line.V5)
  634. }
  635. if line.V6 != "" {
  636. queryStr += " and v6 = ?"
  637. queryArgs = append(queryArgs, line.V6)
  638. }
  639. if line.V7 != "" {
  640. queryStr += " and v7 = ?"
  641. queryArgs = append(queryArgs, line.V7)
  642. }
  643. return queryStr, queryArgs
  644. }
  645. // UpdatePolicy updates a new policy rule to DB.
  646. func (a *Adapter) UpdatePolicy(sec string, pType string, oldRule, newPolicy []string) error {
  647. oldLine := a.savePolicyLine(pType, oldRule)
  648. newLine := a.savePolicyLine(pType, newPolicy)
  649. return a.db.Model(&oldLine).Where(&oldLine).Updates(newLine).Error
  650. }
  651. func (a *Adapter) UpdatePolicies(sec string, pType string, oldRules, newRules [][]string) error {
  652. oldPolicies := make([]CasbinRule, 0, len(oldRules))
  653. newPolicies := make([]CasbinRule, 0, len(oldRules))
  654. for _, oldRule := range oldRules {
  655. oldPolicies = append(oldPolicies, a.savePolicyLine(pType, oldRule))
  656. }
  657. for _, newRule := range newRules {
  658. newPolicies = append(newPolicies, a.savePolicyLine(pType, newRule))
  659. }
  660. tx := a.db.Begin()
  661. for i := range oldPolicies {
  662. if err := tx.Model(&oldPolicies[i]).Where(&oldPolicies[i]).Updates(newPolicies[i]).Error; err != nil {
  663. tx.Rollback()
  664. return err
  665. }
  666. }
  667. return tx.Commit().Error
  668. }
  669. func (a *Adapter) UpdateFilteredPolicies(sec string, pType string, newPolicies [][]string, fieldIndex int, fieldValues ...string) ([][]string, error) {
  670. // UpdateFilteredPolicies deletes old rules and adds new rules.
  671. line := a.getTableInstance()
  672. line.PType = pType
  673. if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {
  674. line.V0 = fieldValues[0-fieldIndex]
  675. }
  676. if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) {
  677. line.V1 = fieldValues[1-fieldIndex]
  678. }
  679. if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) {
  680. line.V2 = fieldValues[2-fieldIndex]
  681. }
  682. if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) {
  683. line.V3 = fieldValues[3-fieldIndex]
  684. }
  685. if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) {
  686. line.V4 = fieldValues[4-fieldIndex]
  687. }
  688. if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
  689. line.V5 = fieldValues[5-fieldIndex]
  690. }
  691. if fieldIndex <= 6 && 6 < fieldIndex+len(fieldValues) {
  692. line.V6 = fieldValues[6-fieldIndex]
  693. }
  694. if fieldIndex <= 7 && 7 < fieldIndex+len(fieldValues) {
  695. line.V7 = fieldValues[7-fieldIndex]
  696. }
  697. newP := make([]CasbinRule, 0, len(newPolicies))
  698. oldP := make([]CasbinRule, 0)
  699. for _, newRule := range newPolicies {
  700. newP = append(newP, a.savePolicyLine(pType, newRule))
  701. }
  702. tx := a.db.Begin()
  703. for i := range newP {
  704. str, args := line.queryString()
  705. if err := tx.Where(str, args...).Find(&oldP).Error; err != nil {
  706. tx.Rollback()
  707. return nil, err
  708. }
  709. if err := tx.Where(str, args...).Delete([]CasbinRule{}).Error; err != nil {
  710. tx.Rollback()
  711. return nil, err
  712. }
  713. if err := tx.Create(&newP[i]).Error; err != nil {
  714. tx.Rollback()
  715. return nil, err
  716. }
  717. }
  718. // return deleted rulues
  719. oldPolicies := make([][]string, 0)
  720. for _, v := range oldP {
  721. oldPolicy := v.toStringPolicy()
  722. oldPolicies = append(oldPolicies, oldPolicy)
  723. }
  724. return oldPolicies, tx.Commit().Error
  725. }
  726. func (c *CasbinRule) queryString() (interface{}, []interface{}) {
  727. queryArgs := []interface{}{c.PType}
  728. queryStr := "p_type = ?"
  729. if c.V0 != "" {
  730. queryStr += " and v0 = ?"
  731. queryArgs = append(queryArgs, c.V0)
  732. }
  733. if c.V1 != "" {
  734. queryStr += " and v1 = ?"
  735. queryArgs = append(queryArgs, c.V1)
  736. }
  737. if c.V2 != "" {
  738. queryStr += " and v2 = ?"
  739. queryArgs = append(queryArgs, c.V2)
  740. }
  741. if c.V3 != "" {
  742. queryStr += " and v3 = ?"
  743. queryArgs = append(queryArgs, c.V3)
  744. }
  745. if c.V4 != "" {
  746. queryStr += " and v4 = ?"
  747. queryArgs = append(queryArgs, c.V4)
  748. }
  749. if c.V5 != "" {
  750. queryStr += " and v5 = ?"
  751. queryArgs = append(queryArgs, c.V5)
  752. }
  753. if c.V6 != "" {
  754. queryStr += " and v6 = ?"
  755. queryArgs = append(queryArgs, c.V6)
  756. }
  757. if c.V7 != "" {
  758. queryStr += " and v7 = ?"
  759. queryArgs = append(queryArgs, c.V7)
  760. }
  761. return queryStr, queryArgs
  762. }
  763. func (c *CasbinRule) toStringPolicy() []string {
  764. policy := make([]string, 0)
  765. if c.PType != "" {
  766. policy = append(policy, c.PType)
  767. }
  768. if c.V0 != "" {
  769. policy = append(policy, c.V0)
  770. }
  771. if c.V1 != "" {
  772. policy = append(policy, c.V1)
  773. }
  774. if c.V2 != "" {
  775. policy = append(policy, c.V2)
  776. }
  777. if c.V3 != "" {
  778. policy = append(policy, c.V3)
  779. }
  780. if c.V4 != "" {
  781. policy = append(policy, c.V4)
  782. }
  783. if c.V5 != "" {
  784. policy = append(policy, c.V5)
  785. }
  786. if c.V6 != "" {
  787. policy = append(policy, c.V6)
  788. }
  789. if c.V7 != "" {
  790. policy = append(policy, c.V7)
  791. }
  792. return policy
  793. }