watcher.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
  1. package rediswatcher
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "log"
  8. "strings"
  9. "sync"
  10. "github.com/casbin/casbin/v2"
  11. "github.com/casbin/casbin/v2/model"
  12. "github.com/casbin/casbin/v2/persist"
  13. "github.com/redis/go-redis/v9"
  14. )
  15. type Watcher struct {
  16. l sync.Mutex
  17. subClient redis.UniversalClient
  18. pubClient redis.UniversalClient
  19. options WatcherOptions
  20. close chan struct{}
  21. callback func(string)
  22. ctx context.Context
  23. }
  24. func DefaultUpdateCallback(e casbin.IEnforcer) func(string) {
  25. return func(msg string) {
  26. msgStruct := &MSG{}
  27. err := msgStruct.UnmarshalBinary([]byte(msg))
  28. if err != nil {
  29. log.Println(err)
  30. return
  31. }
  32. var res bool
  33. switch msgStruct.Method {
  34. case Update, UpdateForSavePolicy:
  35. err = e.LoadPolicy()
  36. res = true
  37. case UpdateForAddPolicy:
  38. res, err = e.SelfAddPolicy(msgStruct.Sec, msgStruct.Ptype, msgStruct.NewRule)
  39. case UpdateForAddPolicies:
  40. res, err = e.SelfAddPolicies(msgStruct.Sec, msgStruct.Ptype, msgStruct.NewRules)
  41. case UpdateForRemovePolicy:
  42. res, err = e.SelfRemovePolicy(msgStruct.Sec, msgStruct.Ptype, msgStruct.NewRule)
  43. case UpdateForRemoveFilteredPolicy:
  44. res, err = e.SelfRemoveFilteredPolicy(msgStruct.Sec, msgStruct.Ptype, msgStruct.FieldIndex, msgStruct.FieldValues...)
  45. case UpdateForRemovePolicies:
  46. res, err = e.SelfRemovePolicies(msgStruct.Sec, msgStruct.Ptype, msgStruct.NewRules)
  47. case UpdateForUpdatePolicy:
  48. res, err = e.SelfUpdatePolicy(msgStruct.Sec, msgStruct.Ptype, msgStruct.OldRule, msgStruct.NewRule)
  49. case UpdateForUpdatePolicies:
  50. res, err = e.SelfUpdatePolicies(msgStruct.Sec, msgStruct.Ptype, msgStruct.OldRules, msgStruct.NewRules)
  51. default:
  52. err = errors.New("unknown update type")
  53. }
  54. if err != nil {
  55. log.Println(err)
  56. }
  57. if !res {
  58. log.Println("callback update policy failed")
  59. }
  60. }
  61. }
  // MSG is the payload published on the watcher's Redis channel. It is encoded
// as JSON with the Go field names as keys (there are no struct tags), so the
// field names and UpdateType values form the wire format shared by every
// instance listening on the channel.
type MSG struct {
	// Method identifies which kind of policy change the message describes.
	Method UpdateType
	// ID is the LocalID of the publishing watcher; subscribers compare it with
	// their own LocalID when IgnoreSelf is set.
	ID string
	// Sec and Ptype name the policy section and policy type affected.
	Sec   string
	Ptype string
	// OldRule/OldRules carry the rule(s) being replaced by an update.
	// NewRule/NewRules carry the added, removed or replacement rule(s).
	OldRule  []string
	OldRules [][]string
	NewRule  []string
	NewRules [][]string
	// FieldIndex and FieldValues describe the filter of a filtered removal.
	FieldIndex  int
	FieldValues []string
}

// UpdateType names the kind of policy change carried by a MSG. The string
// values travel on the wire, so they must not be changed.
type UpdateType string

// UpdateType values, one per Watcher publish method.
const (
	Update                        UpdateType = "Update"
	UpdateForAddPolicy            UpdateType = "UpdateForAddPolicy"
	UpdateForRemovePolicy         UpdateType = "UpdateForRemovePolicy"
	UpdateForRemoveFilteredPolicy UpdateType = "UpdateForRemoveFilteredPolicy"
	UpdateForSavePolicy           UpdateType = "UpdateForSavePolicy"
	UpdateForAddPolicies          UpdateType = "UpdateForAddPolicies"
	UpdateForRemovePolicies       UpdateType = "UpdateForRemovePolicies"
	UpdateForUpdatePolicy         UpdateType = "UpdateForUpdatePolicy"
	UpdateForUpdatePolicies       UpdateType = "UpdateForUpdatePolicies"
)

// MarshalBinary encodes m as JSON. It implements encoding.BinaryMarshaler,
// which is how go-redis encodes the *MSG values passed to Publish.
func (m *MSG) MarshalBinary() ([]byte, error) {
	return json.Marshal(m)
}

// UnmarshalBinary decodes JSON produced by MarshalBinary into m.
// It implements encoding.BinaryUnmarshaler. Fields absent from data keep
// their current values in m, so callers should decode into a fresh MSG.
func (m *MSG) UnmarshalBinary(data []byte) error {
	return json.Unmarshal(data, m)
}
  96. // NewWatcher creates a new Watcher to be used with a Casbin enforcer
  97. // addr is a redis target string in the format "host:port"
  98. // setters allows for inline WatcherOptions
  99. //
  100. // Example:
  101. // w, err := rediswatcher.NewWatcher("127.0.0.1:6379",WatcherOptions{}, nil)
  102. func NewWatcher(addr string, option WatcherOptions) (persist.Watcher, error) {
  103. option.Options.Addr = addr
  104. initConfig(&option)
  105. w := &Watcher{
  106. ctx: context.Background(),
  107. close: make(chan struct{}),
  108. }
  109. if err := w.initConfig(option); err != nil {
  110. return nil, err
  111. }
  112. if err := w.subClient.Ping(w.ctx).Err(); err != nil {
  113. return nil, err
  114. }
  115. if err := w.pubClient.Ping(w.ctx).Err(); err != nil {
  116. return nil, err
  117. }
  118. w.options = option
  119. w.subscribe()
  120. return w, nil
  121. }
  122. // NewWatcherWithCluster creates a new Watcher to be used with a Casbin enforcer
  123. // addrs is a redis-cluster target string in the format "host1:port1,host2:port2,host3:port3"
  124. //
  125. // Example:
  126. // w, err := rediswatcher.NewWatcherWithCluster("127.0.0.1:6379,127.0.0.1:6379,127.0.0.1:6379",WatcherOptions{})
  127. func NewWatcherWithCluster(addrs string, option WatcherOptions) (persist.Watcher, error) {
  128. addrsStr := strings.Split(addrs, ",")
  129. option.ClusterOptions.Addrs = addrsStr
  130. initConfig(&option)
  131. w := &Watcher{
  132. subClient: redis.NewClusterClient(&redis.ClusterOptions{
  133. Addrs: addrsStr,
  134. Password: option.ClusterOptions.Password,
  135. }),
  136. pubClient: redis.NewClusterClient(&redis.ClusterOptions{
  137. Addrs: addrsStr,
  138. Password: option.ClusterOptions.Password,
  139. }),
  140. ctx: context.Background(),
  141. close: make(chan struct{}),
  142. }
  143. err := w.initConfig(option, true)
  144. if err != nil {
  145. return nil, err
  146. }
  147. if err := w.subClient.Ping(w.ctx).Err(); err != nil {
  148. return nil, err
  149. }
  150. if err := w.pubClient.Ping(w.ctx).Err(); err != nil {
  151. return nil, err
  152. }
  153. w.options = option
  154. w.subscribe()
  155. return w, nil
  156. }
  157. func (w *Watcher) initConfig(option WatcherOptions, cluster ...bool) error {
  158. var err error
  159. if option.OptionalUpdateCallback != nil {
  160. err = w.SetUpdateCallback(option.OptionalUpdateCallback)
  161. } else {
  162. err = w.SetUpdateCallback(func(string) {
  163. log.Println("Casbin Redis Watcher callback not set when an update was received")
  164. })
  165. }
  166. if err != nil {
  167. return err
  168. }
  169. if option.SubClient != nil {
  170. w.subClient = option.SubClient
  171. } else {
  172. if len(cluster) > 0 && cluster[0] {
  173. w.subClient = redis.NewClusterClient(&option.ClusterOptions)
  174. } else {
  175. w.subClient = redis.NewClient(&option.Options)
  176. }
  177. }
  178. if option.PubClient != nil {
  179. w.pubClient = option.PubClient
  180. } else {
  181. if len(cluster) > 0 && cluster[0] {
  182. w.pubClient = redis.NewClusterClient(&option.ClusterOptions)
  183. } else {
  184. w.pubClient = redis.NewClient(&option.Options)
  185. }
  186. }
  187. return nil
  188. }
  189. // NewPublishWatcher return a Watcher only publish but not subscribe
  190. func NewPublishWatcher(addr string, option WatcherOptions) (persist.Watcher, error) {
  191. option.Options.Addr = addr
  192. w := &Watcher{
  193. pubClient: redis.NewClient(&option.Options),
  194. ctx: context.Background(),
  195. close: make(chan struct{}),
  196. }
  197. initConfig(&option)
  198. w.options = option
  199. return w, nil
  200. }
  201. // SetUpdateCallback sets the update callback function invoked by the watcher
  202. // when the policy is updated. Defaults to Enforcer.LoadPolicy()
  203. func (w *Watcher) SetUpdateCallback(callback func(string)) error {
  204. w.l.Lock()
  205. w.callback = callback
  206. w.l.Unlock()
  207. return nil
  208. }
  209. // Update publishes a message to all other casbin instances telling them to
  210. // invoke their update callback
  211. func (w *Watcher) Update() error {
  212. return w.logRecord(func() error {
  213. w.l.Lock()
  214. defer w.l.Unlock()
  215. return w.pubClient.Publish(
  216. context.Background(),
  217. w.options.Channel,
  218. &MSG{
  219. Method: Update,
  220. ID: w.options.LocalID,
  221. },
  222. ).Err()
  223. })
  224. }
  225. // UpdateForAddPolicy calls the update callback of other instances to synchronize their policy.
  226. // It is called after Enforcer.AddPolicy()
  227. func (w *Watcher) UpdateForAddPolicy(sec, ptype string, params ...string) error {
  228. return w.logRecord(func() error {
  229. w.l.Lock()
  230. defer w.l.Unlock()
  231. return w.pubClient.Publish(
  232. context.Background(),
  233. w.options.Channel,
  234. &MSG{
  235. Method: UpdateForAddPolicy,
  236. ID: w.options.LocalID,
  237. Sec: sec,
  238. Ptype: ptype,
  239. NewRule: params,
  240. }).Err()
  241. })
  242. }
  243. // UpdateForRemovePolicy calls the update callback of other instances to synchronize their policy.
  244. // It is called after Enforcer.RemovePolicy()
  245. func (w *Watcher) UpdateForRemovePolicy(sec, ptype string, params ...string) error {
  246. return w.logRecord(func() error {
  247. w.l.Lock()
  248. defer w.l.Unlock()
  249. return w.pubClient.Publish(
  250. context.Background(),
  251. w.options.Channel,
  252. &MSG{
  253. Method: UpdateForRemovePolicy,
  254. ID: w.options.LocalID,
  255. Sec: sec,
  256. Ptype: ptype,
  257. NewRule: params,
  258. },
  259. ).Err()
  260. })
  261. }
  262. // UpdateForRemoveFilteredPolicy calls the update callback of other instances to synchronize their policy.
  263. // It is called after Enforcer.RemoveFilteredNamedGroupingPolicy()
  264. func (w *Watcher) UpdateForRemoveFilteredPolicy(sec, ptype string, fieldIndex int, fieldValues ...string) error {
  265. return w.logRecord(func() error {
  266. w.l.Lock()
  267. defer w.l.Unlock()
  268. return w.pubClient.Publish(
  269. context.Background(),
  270. w.options.Channel,
  271. &MSG{
  272. Method: UpdateForRemoveFilteredPolicy,
  273. ID: w.options.LocalID,
  274. Sec: sec,
  275. Ptype: ptype,
  276. FieldIndex: fieldIndex,
  277. FieldValues: fieldValues,
  278. },
  279. ).Err()
  280. })
  281. }
  282. // UpdateForSavePolicy calls the update callback of other instances to synchronize their policy.
  283. // It is called after Enforcer.RemoveFilteredNamedGroupingPolicy()
  284. func (w *Watcher) UpdateForSavePolicy(model model.Model) error {
  285. return w.logRecord(func() error {
  286. w.l.Lock()
  287. defer w.l.Unlock()
  288. return w.pubClient.Publish(
  289. context.Background(),
  290. w.options.Channel,
  291. &MSG{
  292. Method: UpdateForSavePolicy,
  293. ID: w.options.LocalID,
  294. },
  295. ).Err()
  296. })
  297. }
  298. // UpdateForAddPolicies calls the update callback of other instances to synchronize their policies in batch.
  299. // It is called after Enforcer.AddPolicies()
  300. func (w *Watcher) UpdateForAddPolicies(sec string, ptype string, rules ...[]string) error {
  301. return w.logRecord(func() error {
  302. w.l.Lock()
  303. defer w.l.Unlock()
  304. return w.pubClient.Publish(
  305. context.Background(),
  306. w.options.Channel,
  307. &MSG{
  308. Method: UpdateForAddPolicies,
  309. ID: w.options.LocalID,
  310. Sec: sec,
  311. Ptype: ptype,
  312. NewRules: rules,
  313. },
  314. ).Err()
  315. })
  316. }
  317. // UpdateForRemovePolicies calls the update callback of other instances to synchronize their policies in batch.
  318. // It is called after Enforcer.RemovePolicies()
  319. func (w *Watcher) UpdateForRemovePolicies(sec string, ptype string, rules ...[]string) error {
  320. return w.logRecord(func() error {
  321. w.l.Lock()
  322. defer w.l.Unlock()
  323. return w.pubClient.Publish(
  324. context.Background(),
  325. w.options.Channel,
  326. &MSG{
  327. Method: UpdateForRemovePolicies,
  328. ID: w.options.LocalID,
  329. Sec: sec,
  330. Ptype: ptype,
  331. NewRules: rules,
  332. },
  333. ).Err()
  334. })
  335. }
  336. // UpdateForUpdatePolicy calls the update callback of other instances to synchronize their policy.
  337. // It is called after Enforcer.UpdatePolicy()
  338. func (w *Watcher) UpdateForUpdatePolicy(sec string, ptype string, oldRule, newRule []string) error {
  339. return w.logRecord(func() error {
  340. w.l.Lock()
  341. defer w.l.Unlock()
  342. return w.pubClient.Publish(
  343. context.Background(),
  344. w.options.Channel,
  345. &MSG{
  346. Method: UpdateForUpdatePolicy,
  347. ID: w.options.LocalID,
  348. Sec: sec,
  349. Ptype: ptype,
  350. OldRule: oldRule,
  351. NewRule: newRule,
  352. },
  353. ).Err()
  354. })
  355. }
  356. // UpdateForUpdatePolicies calls the update callback of other instances to synchronize their policy.
  357. // It is called after Enforcer.UpdatePolicies()
  358. func (w *Watcher) UpdateForUpdatePolicies(sec string, ptype string, oldRules, newRules [][]string) error {
  359. return w.logRecord(func() error {
  360. w.l.Lock()
  361. defer w.l.Unlock()
  362. return w.pubClient.Publish(
  363. context.Background(),
  364. w.options.Channel,
  365. &MSG{
  366. Method: UpdateForUpdatePolicies,
  367. ID: w.options.LocalID,
  368. Sec: sec,
  369. Ptype: ptype,
  370. OldRules: oldRules,
  371. NewRules: newRules,
  372. },
  373. ).Err()
  374. })
  375. }
  376. func (w *Watcher) logRecord(f func() error) error {
  377. err := f()
  378. if err != nil {
  379. log.Println(err)
  380. }
  381. return err
  382. }
  383. func (w *Watcher) unsubscribe(psc *redis.PubSub) error {
  384. return psc.Unsubscribe(w.ctx)
  385. }
  386. func (w *Watcher) subscribe() {
  387. w.l.Lock()
  388. sub := w.subClient.Subscribe(w.ctx, w.options.Channel)
  389. w.l.Unlock()
  390. wg := sync.WaitGroup{}
  391. wg.Add(1)
  392. go func() {
  393. defer func() {
  394. err := sub.Close()
  395. if err != nil {
  396. log.Println(err)
  397. }
  398. err = w.pubClient.Close()
  399. if err != nil {
  400. log.Println(err)
  401. }
  402. err = w.subClient.Close()
  403. if err != nil {
  404. log.Println(err)
  405. }
  406. }()
  407. ch := sub.Channel()
  408. wg.Done()
  409. for msg := range ch {
  410. select {
  411. case <-w.close:
  412. return
  413. default:
  414. }
  415. data := msg.Payload
  416. msgStruct := &MSG{}
  417. err := msgStruct.UnmarshalBinary([]byte(data))
  418. if err != nil {
  419. log.Println(fmt.Printf("Failed to parse message: %s with error: %s\n", data, err.Error()))
  420. } else {
  421. isSelf := msgStruct.ID == w.options.LocalID
  422. if !(w.options.IgnoreSelf && isSelf) {
  423. w.callback(data)
  424. }
  425. }
  426. }
  427. }()
  428. wg.Wait()
  429. }
  430. func (w *Watcher) GetWatcherOptions() WatcherOptions {
  431. w.l.Lock()
  432. defer w.l.Unlock()
  433. return w.options
  434. }
  435. func (w *Watcher) Close() {
  436. w.l.Lock()
  437. defer w.l.Unlock()
  438. close(w.close)
  439. w.pubClient.Publish(w.ctx, w.options.Channel, "Close")
  440. }