migrate.go 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. package migration
  2. import (
  3. "IotAdmin/common/database"
  4. "IotAdmin/common/global"
  5. "IotAdmin/migration/models"
  6. "log"
  7. "path/filepath"
  8. "sort"
  9. "sync"
  10. "gorm.io/gorm"
  11. )
  12. var Migrate = &Migration{
  13. version: make(map[string]func(db *gorm.DB, version string) error),
  14. }
  15. type Migration struct {
  16. db *gorm.DB
  17. version map[string]func(db *gorm.DB, version string) error
  18. mutex sync.Mutex
  19. }
  20. func (e *Migration) GetDb() *gorm.DB {
  21. return e.db
  22. }
  23. func (e *Migration) SetDb(db *gorm.DB) {
  24. e.db = db
  25. }
  26. func (e *Migration) SetVersion(k string, f func(db *gorm.DB, version string) error) {
  27. e.mutex.Lock()
  28. defer e.mutex.Unlock()
  29. e.version[k] = f
  30. }
  31. func (e *Migration) Migrate() {
  32. versions := make([]string, 0)
  33. for k := range e.version {
  34. versions = append(versions, k)
  35. }
  36. if !sort.StringsAreSorted(versions) {
  37. sort.Strings(versions)
  38. }
  39. var err error
  40. var count int64
  41. for _, v := range versions {
  42. err = e.db.Table((&models.Migration{}).TableName()).Where("version = ?", v).Count(&count).Error
  43. if err != nil {
  44. log.Fatalln(err)
  45. }
  46. if count > 0 {
  47. log.Println(count)
  48. count = 0
  49. continue
  50. }
  51. err = (e.version[v])(e.db.Debug(), v)
  52. if err != nil {
  53. log.Fatalln(err)
  54. }
  55. }
  56. }
  // GetFilename returns the first 13 bytes of the last element of path s
// (as computed by filepath.Base). If that base name is shorter than 13 bytes,
// the whole base name is returned instead of panicking on an out-of-range slice.
// NOTE(review): 13 presumably matches a millisecond-timestamp version prefix in
// migration file names; confirm against callers.
func GetFilename(s string) string {
	const prefixLen = 13
	base := filepath.Base(s)
	if len(base) < prefixLen {
		return base
	}
	return base[:prefixLen]
}
  61. func InitDbData(db *gorm.DB) (err error) {
  62. filePath := "config/sql/db.sql"
  63. if global.Driver == "postgres" {
  64. filePath := "config/sql/db.sql"
  65. if err = database.ExecSql(db, filePath); err != nil {
  66. return err
  67. }
  68. filePath = "config/sql/pg.sql"
  69. err = database.ExecSql(db, filePath)
  70. } else if global.Driver == "mysql" {
  71. filePath = "config/sql/db-begin-mysql.sql"
  72. if err = database.ExecSql(db, filePath); err != nil {
  73. return err
  74. }
  75. filePath = "config/sql/db.sql"
  76. if err = database.ExecSql(db, filePath); err != nil {
  77. return err
  78. }
  79. filePath = "config/sql/db-end-mysql.sql"
  80. err = database.ExecSql(db, filePath)
  81. } else {
  82. err = database.ExecSql(db, filePath)
  83. }
  84. return err
  85. }