129 lines
3.3 KiB
Go
129 lines
3.3 KiB
Go
package cool
|
||
|
||
import (
|
||
_ "blazing/contrib/drivers/pgsql"
|
||
"blazing/cool/cooldb"
|
||
"sync"
|
||
|
||
"github.com/gogf/gf/v2/encoding/gjson"
|
||
"github.com/gogf/gf/v2/frame/g"
|
||
"github.com/gogf/gf/v2/os/gres"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
var (
|
||
autoMigrateMu sync.Mutex
|
||
autoMigrateModels []IModel
|
||
)
|
||
|
||
// 初始化数据库连接供gorm使用
|
||
func InitDB(group string) (*gorm.DB, error) {
|
||
// var ctx context.Context
|
||
var db *gorm.DB
|
||
// 如果group为空,则使用默认的group,否则使用group参数
|
||
if group == "" {
|
||
group = "default"
|
||
}
|
||
defer func() {
|
||
if err := recover(); err != nil {
|
||
panic("failed to connect database")
|
||
}
|
||
}()
|
||
config := g.DB(group).GetConfig()
|
||
db, err := cooldb.GetConn(config)
|
||
if err != nil {
|
||
panic(err.Error())
|
||
}
|
||
|
||
GormDBS[group] = db
|
||
return db, nil
|
||
}
|
||
|
||
// 根据entity结构体获取 *gorm.DB
|
||
func getDBbyModel(model IModel) *gorm.DB {
|
||
|
||
group := model.GroupName()
|
||
// 判断是否存在 GormDBS[group] 字段,如果存在,则使用该字段的值作为DB,否则初始化DB
|
||
if _, ok := GormDBS[group]; ok {
|
||
return GormDBS[group]
|
||
} else {
|
||
|
||
db, err := InitDB(group)
|
||
if err != nil {
|
||
panic("failed to connect database")
|
||
}
|
||
// 把重新初始化的GormDBS存入全局变量中
|
||
GormDBS[group] = db
|
||
return db
|
||
}
|
||
}
|
||
|
||
// 根据entity结构体创建表
|
||
func CreateTable(model IModel) error {
|
||
autoMigrateMu.Lock()
|
||
autoMigrateModels = append(autoMigrateModels, model)
|
||
autoMigrateMu.Unlock()
|
||
return nil
|
||
}
|
||
|
||
// RunAutoMigrate 显式执行已注册模型的建表/迁移。
|
||
func RunAutoMigrate() error {
|
||
if !Config.AutoMigrate {
|
||
return nil
|
||
}
|
||
autoMigrateMu.Lock()
|
||
models := append([]IModel(nil), autoMigrateModels...)
|
||
autoMigrateMu.Unlock()
|
||
|
||
seen := make(map[string]struct{}, len(models))
|
||
for _, model := range models {
|
||
key := model.GroupName() + ":" + model.TableName()
|
||
if _, ok := seen[key]; ok {
|
||
continue
|
||
}
|
||
seen[key] = struct{}{}
|
||
|
||
db := getDBbyModel(model)
|
||
if err := db.AutoMigrate(model); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// FillInitData 数据库填充初始数据
|
||
func FillInitData(ctx g.Ctx, moduleName string, model IModel, ismod *bool) (bool, error) {
|
||
mInit := g.DB("default").Model("base_sys_init")
|
||
n, err := mInit.Clone().Where("group", model.GroupName()).Where("table", model.TableName()).Count()
|
||
if err != nil {
|
||
Logger.Error(ctx, "读取表 base_sys_init 失败 ", err.Error())
|
||
return false, err
|
||
}
|
||
if n > 0 {
|
||
Logger.Debug(ctx, "分组", model.GroupName(), "中的表", model.TableName(), "已经初始化过,跳过本次初始化.")
|
||
return false, err
|
||
}
|
||
m := g.DB(model.GroupName()).Model(model.TableName())
|
||
jsonData, _ := gjson.LoadContent(gres.GetContent("modules/" + moduleName + "/resource/initjson/" + model.TableName() + ".json"))
|
||
if jsonData.Var().Clone().IsEmpty() {
|
||
Logger.Debug(ctx, "分组", model.GroupName(), "中的表", model.TableName(), "无可用的初始化数据,跳过本次初始化.")
|
||
return false, err
|
||
}
|
||
_, err = m.Data(jsonData).Insert()
|
||
if err != nil {
|
||
Logger.Error(ctx, err.Error())
|
||
return false, err
|
||
}
|
||
_, err = mInit.Insert(g.Map{"group": model.GroupName(), "table": model.TableName()})
|
||
if err != nil {
|
||
Logger.Error(ctx, err.Error())
|
||
return false, err
|
||
}
|
||
Logger.Info(ctx, "分组", model.GroupName(), "中的表", model.TableName(), "初始化完成.")
|
||
if ismod != nil {
|
||
*ismod = true
|
||
}
|
||
|
||
return true, err
|
||
}
|