db.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. package db
  2. import (
  3. "MeterService/core/utils"
  4. "database/sql"
  5. "sync"
  6. )
  7. var (
  8. err error
  9. myDbs *map[string]*MyDB
  10. )
  11. const defaultDb = "default"
  12. var once sync.Once
  13. type MyDB struct {
  14. DB *sql.DB
  15. }
  16. // OpenDb 获取数据库连接
  17. func OpenDb(dbFun func() *sql.DB) {
  18. once.Do(func() {
  19. myDbs = &map[string]*MyDB{}
  20. })
  21. myDb := NewDB(dbFun())
  22. (*myDbs)[defaultDb] = myDb
  23. }
  24. func GetDb() *MyDB {
  25. return (*myDbs)[defaultDb]
  26. }
  27. func CloseDb() {
  28. err := (*myDbs)[defaultDb].DB.Close()
  29. if err != nil {
  30. return
  31. } else {
  32. panic(err)
  33. }
  34. }
  35. func OpenManyDb(key string, dbFun func() *sql.DB) {
  36. once.Do(func() {
  37. myDbs = &map[string]*MyDB{}
  38. })
  39. myDb := NewDB(dbFun())
  40. (*myDbs)[key] = myDb
  41. }
  42. func GetDbByKey(key string) *MyDB {
  43. return (*myDbs)[key]
  44. }
  45. func CloseManyDb() {
  46. for _, v := range *myDbs {
  47. err := v.DB.Close()
  48. if err != nil {
  49. return
  50. } else {
  51. panic(err)
  52. }
  53. }
  54. }
  55. // Exec 增、删、改
  56. func (db *MyDB) Exec(SQL string, args ...interface{}) (sql.Result, error) {
  57. //DB := OpenDb().DB()
  58. var ret sql.Result
  59. if args == nil {
  60. ret, err = db.DB.Exec(SQL)
  61. } else {
  62. ret, err = db.DB.Exec(SQL, args...)
  63. }
  64. if err != nil {
  65. return nil, err
  66. }
  67. return ret, nil
  68. }
  69. // Query 查询
  70. func (db *MyDB) Query(SQL string, args ...interface{}) ([]map[string]string, bool) { //通用查询
  71. var rows *sql.Rows
  72. if args == nil {
  73. rows, err = db.DB.Query(SQL)
  74. } else {
  75. rows, err = db.DB.Query(SQL, args...)
  76. }
  77. //rows, err := DB.Query(SQL, args) //执行SQL语句,比如select * from users
  78. if err != nil {
  79. panic(err)
  80. }
  81. columns, _ := rows.Columns() //获取列的信息
  82. count := len(columns) //列的数量
  83. var values = make([]interface{}, count) //创建一个与列的数量相当的空接口
  84. for i, _ := range values {
  85. var ii interface{} //为空接口分配内存
  86. values[i] = &ii //取得这些内存的指针,因后继的Scan函数只接受指针
  87. }
  88. ret := make([]map[string]string, 0) //创建返回值:不定长的map类型切片
  89. for rows.Next() {
  90. err := rows.Scan(values...) //开始读行,Scan函数只接受指针变量
  91. m := make(map[string]string) //用于存放1列的 [键/值] 对
  92. if err != nil {
  93. panic(err)
  94. }
  95. for i, colName := range columns {
  96. var rawValue = *(values[i].(*interface{})) //读出raw数据
  97. b, ok := rawValue.([]byte)
  98. var v string
  99. if ok {
  100. v = string(b) //将raw数据转换成字符串
  101. } else {
  102. v, err = utils.ToString(rawValue) //colName是键,v是值
  103. if err != nil {
  104. panic(err)
  105. }
  106. }
  107. m[colName] = v
  108. }
  109. ret = append(ret, m) //将单行所有列的键值对附加在总的返回值上(以行为单位)
  110. }
  111. defer func(rows *sql.Rows) {
  112. _ = rows.Close()
  113. }(rows)
  114. if len(ret) != 0 {
  115. return ret, true
  116. }
  117. return nil, false
  118. }
  119. func NewDB(db *sql.DB) *MyDB {
  120. return &MyDB{DB: db}
  121. }