基于Gorm实现数据库迁移的方案

692 阅读4分钟

1、Gorm数据库迁移的痛点

        Gorm作为一个golang生态下的Orm,拥有庞大的用户群体。其提供的AutoMigrate方法,可以实现数据库的自动迁移。写法:db.AutoMigrate(&User{},&Product{},&Order{})。 不过,这种写法每次新增model文件,都需要在迁移方法中添加结构体,在开发过程中容易出现忘记修改或者无法灵活的迁移指定表。因此,本篇文章尝试实现一种可指定表或全部表的迁移方案,供大家参考,如能有所帮助,深感欣慰。

2、方案介绍

        自动迁移的基础在于传给AutoMigrate方法的结构体,要实现根据传入的表的名字实例化结构体,需要使用go的反射功能(不熟悉这块的小伙伴可以自己查找一下相关文章)。可是我们知道反射不能根据字符串实例化结构体(不能通过 包名.字符串  的方式访问包下的结构体),但却提供了创建结构体的方法(reflect.StructOf()),本文正是通过创建一个与model结构体相同的结构体来实现实例化结构体的。

        至于如何解析包内所有的结构体,则需要ast抽象语法树来解决。另外,为解决编译后无法使用ast的情况,又引入了golang1.6后引入embed包,将包文件编译到可执行文件当中。

3、代码实现

        代码文件结构如下:

image.png

main文件主要处理命令行参数和embed models文件夹,内容如下:

package main

import (
	"demo/mysql"
	"embed"
	"flag"
)

//go:embed models
var ModelsFiles embed.FS

func main() {
	// 解析命令
	tableName := flag.String("migrate", "", "迁移的表名,如果是all,则迁移所有表")

	flag.Parse()

	mysql.Migrate(ModelsFiles, *tableName)
}

models存放的是表的结构体,内容如下:

package models

import (
	"time"
)

type AdminGroup struct {
	Groupid   int    `gorm:"primaryKey"`
	Groupname string `gorm:"type:varchar(255);"`
	Privs     string `gorm:"type:text;"`
	CreatedAt time.Time
	UpdatedAt time.Time
}

// 这个方法可以设置表名,必须实现这个方法,在自动迁移过程中,需要到这里解析表名
func (AdminGroup) TableName() string {
	return "dlj_admin_group"
}

mysql文件夹下的MyDatabase.go中处理的是数据库连接,初始化*grom.DB

package mysql

import (
	"fmt"

	"gorm.io/driver/mysql"
	"gorm.io/gorm"
)

var MyDatabase *gorm.DB

func init() {
	host := "127.0.0.1"
	port := 3306
	user := "root"
	password := ""
	database := "demo"

	dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", user, password, host, port, database)

	var err error
	
	MyDatabase, err = gorm.Open(mysql.New(mysql.Config{
		DSN:                       dsn,
		DefaultStringSize:         256,   
		DisableDatetimePrecision:  true,  
		DontSupportRenameIndex:    true,  
		DontSupportRenameColumn:   true,  
		SkipInitializeWithVersion: false, 
	}))

	if err != nil {
		panic(err)
	}
}

接下来就是迁移文件Migrate.go的内容,这里通过读取models文件夹的内容,通过ast语法解析出结构体内容,然后通过反射直接初始化与models中的结构体一样的结构体,解决无法动态初始化某个包的结构体的问题。

package mysql

import (
	"embed"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"reflect"
	"strings"
	"time"
)

func Migrate(embededFiles embed.FS, tableName string) {
        // 读取所有的表文件
	files, _ := embededFiles.ReadDir("models")

	for _, file := range files {
                // 这里判断命令是否为all,如果不为all,则迁移指定表
		if (tableName != "all") && ((tableName + ".go") != file.Name()) {
			continue
		}

		fileInfo, _ := embededFiles.ReadFile("models/" + file.Name())
                // 这里是ast解析语法树的常规操作,大家可以百度一下
		fset := token.NewFileSet()
		astfile, err := parser.ParseFile(fset, "", fileInfo, 0)
		if err != nil {
			panic(err)
		}

		var (
			structField []reflect.StructField// 初始化空的结构体,后面通过解析models文件,来添加结构体成员信息
			tableName string
		)

		for _, v := range astfile.Decls {
                        // 如果是结构体定义的语法,则解析结构体所有字段(包括字段名/字段类型/tag)
			if strc, ok := v.(*ast.GenDecl); ok && strc.Tok == token.TYPE {
				for _, spec := range strc.Specs {
					if tp, ok := spec.(*ast.TypeSpec); ok {
						if stp, ok := tp.Type.(*ast.StructType); ok {
							if !stp.Struct.IsValid() {
								continue
							}

							for _, li := range stp.Fields.List {
								elementTypeReflect := getElementType(li)
								if elementTypeReflect == nil {
									panic(fmt.Sprintf("不支持的数据类型,%v", li))
								}

								if li.Tag != nil {
									tag := strings.ReplaceAll(li.Tag.Value, "`", "")
									// 新增成员
									structField = append(structField, reflect.StructField{
										Name:li.Names[0].Name,
										Type:elementTypeReflect,
										Tag:  reflect.StructTag(tag),// 将string类型强制转换为reflect.StructTag类型
									})
								} else {
									// 新增成员
									structField = append(structField, reflect.StructField{
										Name:li.Names[0].Name,
										Type:elementTypeReflect,
									})
								}
							}
						}
					}
				}
			}
                        // 如果是函数,并且函数名是TableName(见models文件内容,这个函数是gorm获取表名的方法),则可以获取到表名
			if fun, ok := v.(*ast.FuncDecl); ok && fun.Name.Name == "TableName" {
				tableName = fun.Body.List[0].(*ast.ReturnStmt).Results[0].(*ast.BasicLit).Value
				tableName = strings.ReplaceAll(tableName, "\"", "")
			}
		}
                
		if len(structField) == 0 {
			panic("找不到指定的model")
		}
                // 根据结构体内容实例化结构体
		typ := reflect.StructOf(structField)
		v:= reflect.New(typ)
		s1 := v.Interface()
                // 运行迁移命令
		MyDatabase.Table(tableName).AutoMigrate(&s1)
	}
}
// 这里来解析models成员的类型,可以根据实际情况扩展
func getElementType(li *ast.Field) reflect.Type {
	// 判断成员类型,可根据实际情况扩展
	var (
		elementType string
		elementTypeReflect reflect.Type
	)

	if _, ok := li.Type.(*ast.Ident); ok {
		elementType = li.Type.(*ast.Ident).Name
	}
        // 如果是time.Time会被ast识别为ast.SelectorExpr
	if _, ok := li.Type.(*ast.SelectorExpr); ok {
		elementType = li.Type.(*ast.SelectorExpr).X.(*ast.Ident).Name + "." + li.Type.(*ast.SelectorExpr).Sel.Name
	}

	switch elementType {
	case "int":
		elementTypeReflect = reflect.TypeOf(1)
	case "int8":
		elementTypeReflect = reflect.TypeOf(int8(1))
	case "int16":
		elementTypeReflect = reflect.TypeOf(int16(1))
	case "int32":
		elementTypeReflect = reflect.TypeOf(int32(1))
	case "int64":
		elementTypeReflect = reflect.TypeOf(int64(1))
	case "float32":
		elementTypeReflect = reflect.TypeOf(float32(1))
	case "float64":
		elementTypeReflect = reflect.TypeOf(float64(1))
	case "bool":
		elementTypeReflect = reflect.TypeOf(true)
	case "string":
		elementTypeReflect = reflect.TypeOf("")
	case "time.Time":
		elementTypeReflect = reflect.TypeOf(time.Now())
	}

	return elementTypeReflect
}