optionRedis.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. package config
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "crypto/x509"
  6. "fmt"
  7. "os"
  8. "github.com/redis/go-redis/v9"
  9. )
  10. var _redis *redis.Client
  11. type RedisConnectOptions struct {
  12. Network string `yaml:"network" json:"network"`
  13. Addr string `yaml:"addr" json:"addr"`
  14. Username string `yaml:"username" json:"username"`
  15. Password string `yaml:"password" json:"password"`
  16. DB int `yaml:"db" json:"db"`
  17. PoolSize int `yaml:"pool_size" json:"pool_size"`
  18. Tls *Tls `yaml:"tls" json:"tls"`
  19. MaxRetries int `yaml:"max_retries" json:"max_retries"`
  20. }
  21. type Tls struct {
  22. Cert string `yaml:"cert" json:"cert"`
  23. Key string `yaml:"key" json:"key"`
  24. Ca string `yaml:"ca" json:"ca"`
  25. }
  26. // GetRedisClient 获取redis客户端
  27. func GetRedisClient() *redis.Client {
  28. return _redis
  29. }
  30. // SetRedisClient 设置redis客户端
  31. func SetRedisClient(c *redis.Client) {
  32. if _redis != nil && _redis != c {
  33. _redis.Shutdown(context.TODO())
  34. }
  35. _redis = c
  36. }
  37. func (e RedisConnectOptions) GetRedisOptions() (*redis.Options, error) {
  38. r := &redis.Options{
  39. Network: e.Network,
  40. Addr: e.Addr,
  41. Username: e.Username,
  42. Password: e.Password,
  43. DB: e.DB,
  44. MaxRetries: e.MaxRetries,
  45. PoolSize: e.PoolSize,
  46. }
  47. var err error
  48. r.TLSConfig, err = getTLS(e.Tls)
  49. return r, err
  50. }
  51. func getTLS(c *Tls) (*tls.Config, error) {
  52. if c != nil && c.Cert != "" {
  53. // 从证书相关文件中读取和解析信息,得到证书公钥、密钥对
  54. cert, err := tls.LoadX509KeyPair(c.Cert, c.Key)
  55. if err != nil {
  56. fmt.Printf("tls.LoadX509KeyPair err: %v\n", err)
  57. return nil, err
  58. }
  59. // 创建一个新的、空的 CertPool,并尝试解析 PEM 编码的证书,解析成功会将其加到 CertPool 中
  60. certPool := x509.NewCertPool()
  61. ca, err := os.ReadFile(c.Ca)
  62. if err != nil {
  63. fmt.Printf("ioutil.ReadFile err: %v\n", err)
  64. return nil, err
  65. }
  66. if ok := certPool.AppendCertsFromPEM(ca); !ok {
  67. fmt.Println("certPool.AppendCertsFromPEM err")
  68. return nil, err
  69. }
  70. return &tls.Config{
  71. // 设置证书链,允许包含一个或多个
  72. Certificates: []tls.Certificate{cert},
  73. // 要求必须校验客户端的证书
  74. ClientAuth: tls.RequireAndVerifyClientCert,
  75. // 设置根证书的集合,校验方式使用 ClientAuth 中设定的模式
  76. ClientCAs: certPool,
  77. }, nil
  78. }
  79. return nil, nil
  80. }