10分钟从零到一实现@autowired

125 阅读7分钟

  背景

假设要开发一个MVC的服务,最外层是一个Server结构体,实现了IDL定义的方法;Server作为流量的入口不处理业务逻辑,业务逻辑代码分散在Server中的Service里面,Service实现真正的业务逻辑;Service里持有不同的链接或者外部服务的依赖,依赖他们完成相应的业务逻辑。

以下是上面描述的demo代码:

package main

func main() {
    redisCli, err := NewRedisCli()
    if err != nil {
        panic(err)
    }
    
    DB, err := NewGormDB()
    if err != nil {
        panic(err)
    }
    
    esCli, err := NewEsCli()
    if err != nil {
        panic(err)
    }
    
    larkCli, err := NewLarkCli()
    if err != nil {
        panic(err)
    }
    
    service1 := NewService1(DB, redisCli)
    service2 := NewService2(esCli, larkCli)
    server := NewSerer(service1, service2)
    
    err := server.Run()
    if err != nil {
        panic(err)
    }
}

上面的代码逻辑上没有什么问题,可以保证代码的正确运行,问题就是上面的初始化部分太麻烦了,有没有一种方案可以让上面的初始化交给框架层来做,main函数只用去run就好了;我们先来看看理想态的代码是什么样子。

方案一:反射自动注入

需用用户把用到的对象交给框架,框架最后来组装出来一个结构体

package main

func main() {
    inject(
        NewRedisCli,
        NewGormDB,
        NewEsCli,
        NewLarkCli,
        NewService1,
        service2,
    )
    
    var s server
    err := Wire(&s)
    if err != nil {
        panic(err)
    }
    
    err = server.Run()
    if err != nil {
        panic(err)
    }
}

方案二:代码自动生成

通过ast遍历全部代码,拿到关键信息,再动态生成代码

package main

// @wire
func NewRedisCli() (*RedisCli, err){
    ...
}

// @wire
type Server struct {
    Service1 Service1
    Service2 Service2
}

func main() {
    // NewAPP是代码自动生成出来的函数
    s, err := NewAPP()
    if err != nil {
        panic(err)
    }
    
    err = server.Run()
    if err != nil {
        panic(err)
    }
}

接下来我们就来一步一步拆解来看怎么实现吧。

功能拆解

方案一

字段填充

将结构体里面的每一个字段都正确的填充上

我们把需要的结构都存储到map中,key是对应的类型的PkgPath+StructName,再遍历需要构造结构体的Filed,找到对应的Type就把map中的value赋值上去即可,以下是demo代码

type Server struct {
   A *AA
   B *BB
}

type AA struct {
}

type BB struct {
}

func main() {
   // 构造一个存储对应结构map
   ret := map[string]any{
      reflect.TypeOf(&AA{}).String(): &AA{},
      reflect.TypeOf(&BB{}).String(): &BB{},
   }

   s := &Server{}
   val := reflect.ValueOf(s)
   if val.Kind() == reflect.Ptr {
      val = val.Elem()
   }

   // 遍历要赋值的结构体,将对应Value设置进去
   for i := 0; i < val.NumField(); i++ {
      v := val.Field(i)
      tval, ok := ret[v.Type().String()]
      if ok && v.CanSet() {
         v.Set(reflect.ValueOf(tval))
      }
   }

   fmt.Println(s)
   // &{0x1acf270 0x1acf270}
}

类型构造

将结构体里所需的构造出来

用户将对应的结构体或者初始化方法给我们,我们通过反射进行调用设置到一个字段填充的map中即可,以下是demo代码

type AA struct {
}

func NewAA(b *BB) (*AA, error) {
   return &AA{}, nil
}

type BB struct {
}

func Inject(as ...any) (map[string]any, error) {
   ret := map[string]any{}

   for _, a := range as {
      va := reflect.ValueOf(a)
      switch va.Kind() {
      case reflect.Func:
         ta := va.Type()
         ins := make([]reflect.Value, ta.NumIn())
         for i := 0; i < ta.NumIn(); i++ {
            ins[i] = reflect.ValueOf(ret[ta.In(i).String()])
         }
         outs := va.Call(ins)
         for _, out := range outs {
            if out.Type().Implements(reflect.TypeOf(new(error)).Elem()) {
               if !out.IsNil() {
                  return nil, out.Interface().(error)
               }

               continue
}

            ret[out.Type().String()] = out.Interface()
         }

      case reflect.Struct:
         ret[va.Type().String()] = va.Interface()
      case reflect.Ptr:
         vva := va.Elem()
         if vva.Kind() != reflect.Struct {
            return nil, errors.New( "ptr类型仅支持struct" )
         }

         ret[va.Type().String()] = va.Interface()
      default:
         return nil, errors.New( "只支持Struct、Ptr、Func类型" )
      }
   }

   return ret, nil
}

func main() {
   ret, err := Inject(&BB{}, NewAA)
   if err != nil {
      panic(err)
   }

   fmt.Println(ret)
   // map[*main.BB:0x1acf270 <*main.AA Value>:0x1acf270]
}

方案二

前置知识

在做功能拆解之前,先跟大家简要介绍一下AST,抽象语法树是源代码语法结构的一种抽象表示。它以树状的形式表现编程语言的语法结构,树上的每个节点都表示源代码中的一种结构。之所以说语法是“抽象”的,是因为这里的语法并不会表示出真实语法中出现的每个细节。比如,嵌套括号被隐含在树的结构中,并没有以节点的形式呈现;而类似于 if-condition-then 这样的条件跳转语句,可以使用带有两个分支的节点来表示。

  • Comments 注释, //-style 或是 /*-style
  • Declarations 声明,GenDecl (generic declaration node) 代表 import, constant, type 或 variable declaration. BadDecl 代表有语法错误的 node
  • Statements 常见的语句表达式,return, case, if 等等
  • File 代表一个 go 源码文件
  • Package 代表一组源代码文件
  • Expr 表达式 ArrayExpr, StructExpr, SliceExpr 等等

我们来看官网的一个demo

// src is the input for which we want to print the AST.
src := `
package main
func main() {
        println("Hello, World!")
}
`

// Create the AST by parsing src.
fset := token.NewFileSet() // positions are relative to fset
f, err := parser.ParseFile(fset, "", src, 0)
if err != nil {
    panic(err)
}

// Print the AST.
ast.Print(fset, f)
 0  *ast.File {
     1  .  Package: 2:1
     2  .  Name: *ast.Ident {
     3  .  .  NamePos: 2:9
     4  .  .  Name: "main"
     5  .  }
     6  .  Decls: []ast.Decl (len = 1) {
     7  .  .  0: *ast.FuncDecl {
     8  .  .  .  Name: *ast.Ident {
     9  .  .  .  .  NamePos: 3:6
    10  .  .  .  .  Name: "main"
    11  .  .  .  .  Obj: *ast.Object {
    12  .  .  .  .  .  Kind: func
    13  .  .  .  .  .  Name: "main"
    14  .  .  .  .  .  Decl: *(obj @ 7)
    15  .  .  .  .  }
    16  .  .  .  }
    17  .  .  .  Type: *ast.FuncType {
    18  .  .  .  .  Func: 3:1
    19  .  .  .  .  Params: *ast.FieldList {
    20  .  .  .  .  .  Opening: 3:10
    21  .  .  .  .  .  Closing: 3:11
    22  .  .  .  .  }
    23  .  .  .  }
    24  .  .  .  Body: *ast.BlockStmt {
    25  .  .  .  .  Lbrace: 3:13
    26  .  .  .  .  List: []ast.Stmt (len = 1) {
    27  .  .  .  .  .  0: *ast.ExprStmt {
    28  .  .  .  .  .  .  X: *ast.CallExpr {
    29  .  .  .  .  .  .  .  Fun: *ast.Ident {
    30  .  .  .  .  .  .  .  .  NamePos: 4:2
    31  .  .  .  .  .  .  .  .  Name: "println"
    32  .  .  .  .  .  .  .  }
    33  .  .  .  .  .  .  .  Lparen: 4:9
    34  .  .  .  .  .  .  .  Args: []ast.Expr (len = 1) {
    35  .  .  .  .  .  .  .  .  0: *ast.BasicLit {
    36  .  .  .  .  .  .  .  .  .  ValuePos: 4:10
    37  .  .  .  .  .  .  .  .  .  Kind: STRING
    38  .  .  .  .  .  .  .  .  .  Value: ""Hello, World!""
    39  .  .  .  .  .  .  .  .  }
    40  .  .  .  .  .  .  .  }
    41  .  .  .  .  .  .  .  Ellipsis: -
    42  .  .  .  .  .  .  .  Rparen: 4:25
    43  .  .  .  .  .  .  }
    44  .  .  .  .  .  }
    45  .  .  .  .  }
    46  .  .  .  .  Rbrace: 5:1
    47  .  .  .  }
    48  .  .  }
    49  .  }
    50  .  Scope: *ast.Scope {
    51  .  .  Objects: map[string]*ast.Object (len = 1) {
    52  .  .  .  "main": *(obj @ 11)
    53  .  .  }
    54  .  }
    55  .  Unresolved: []*ast.Ident (len = 1) {
    56  .  .  0: *(obj @ 29)
    57  .  }
    58  }

基于Go的AST能力,我们就可以获取到模块上面的注释,再判断注释上面是否有我们的关键字@wire即可确认该结构体需要被注入,最终构造出一个依赖关系拓扑图,生成对应文件。

类型识别

通过上面介绍的AST,我们可以遍历项目中的全部文件,收集关键信息,以下是demo代码

func GetPkg() (map[string]*ast.Package, error) {
   // 获取当前目录,类似pwd命令
   wd, err := os.Getwd()
   if err != nil {
      return nil, err
   }

   pkgs := make(map[string]*ast.Package)
   
   // 遍历指定目录的所有文件
   err = filepath.Walk(wd, func(path string, info fs.FileInfo, err error) error {
      if info.IsDir() || !strings.HasSuffix(info.Name(), ".go" ) {
         return nil
      }

      t := token.NewFileSet()

      // 使用ast工具进行分析
      src, err := parser.ParseFile(t, path, nil, parser.ParseComments)
      if err != nil {
         return err
      }
      name := src.Name.Name
      pkg, found := pkgs[name]
      if !found {
         pkg = &ast.Package{
            Name:  name,
            Files: make(map[string]*ast.File),
         }
         pkgs[name] = pkg
      }
      pkg.Files[path] = src

      return nil
   })
   if err != nil {
      return nil, err
   }

   return pkgs, nil
}

关键字查找

遍历完ast后,我们需要找到所有注释包含@wire的Func,以下是demo代码

type Generate struct {
   Package       string
   Type          string
   FuncName      string
   InParamsType  []string
   OutParamsType []string
}

func GetComment(pkgs map[string]*ast.Package) []Generate {
   g := []Generate{}

   // 遍历package下的所有文件
   for _, v := range pkgs {
      for _, f := range v.Files {
         // 遍历文件下面的所有声明,这里主要是函数声明
         for _, decl := range f.Decls {
            switch decl.(type) {
            case *ast.FuncDecl:
               fDecl := decl.(*ast.FuncDecl)
               match := false

 for _, comment := range fDecl.Doc.List {
                  if strings.Contains(comment.Text, "@wired" ) {
                     match = true
 break
}
               }
               if match {
                  inParams := []string{}
                  for _, param := range fDecl.Type.Params.List {
                     switch param.Type.(type) {
                     case *ast.StarExpr:
                        starExpr := param.Type.(*ast.StarExpr)
                        ident := starExpr.X.(*ast.Ident)
                        typeName := ident.Name
                        if ident.NamePos-starExpr.Star == 1 {
                           typeName = "*" + typeName
                        }
                        inParams = append(inParams, typeName)
                     }
                  }

                  outParams := []string{}
                  for _, param := range fDecl.Type.Results.List {
                     switch param.Type.(type) {
                     case *ast.StarExpr:
                        starExpr := param.Type.(*ast.StarExpr)
                        ident := starExpr.X.(*ast.Ident)
                        typeName := ident.Name
                        if ident.NamePos-starExpr.Star == 1 {
                           typeName = "*" + typeName
                        }
                        outParams = append(outParams, typeName)
                     }
                  }

                  g = append(g, Generate{
                     Package:       v.Name,
                     Type:          "FUNC" ,
                     FuncName:      fDecl.Name.Name,
                     InParamsType:  inParams,
                     OutParamsType: outParams,
                  })
               }
            }
         }
      }
   }

   return g
}

模板代码输出

func TopologicalSort(g []*Generate) []*Generate {
   newG := make([]*Generate, 0, len(g))
   count := map[*Generate][]string{}
   record := map[string][]*Generate{}
   // 初始化数据
   // count记录一个func依赖多少个入参
   // record记录一个类型参数被多少个func依赖
   // 拓扑整体思路为:
   // 循环找0依赖的Func,并将该Func的返回类型记录下来
   // 因为该Func的返回类型已经拿到,可以把所有依赖该类型的Func移除该类型依赖
   // 循环往复,直到count记录为空
   for _, generate := range g {
      count[generate] = count[generate]
      for _, in := range generate.InParamsType {
         count[generate] = append(count[generate], in)
         record[in] = append(record[in], generate)
      }
   }

   // demo代码,参数不规范会造成死循环
   for len(count) != 0 {
      for k, v := range count {
         if len(v) == 0 {
            newG = append(newG, k)
            // 这里仅处理了单返回Type的情况
            for _, generate := range record[k.OutParamsType[0]] {
               count[generate] = SliceRemove(count[generate], k.OutParamsType[0])
            }
            delete(count, k)
         }
      }
   }

   return g
}

func SliceRemove(ss []string, s string) []string {
   ret := make([]string, 0, len(ss))
   for _, s2 := range ss {
      if s2 != s {
         ret = append(ret, s2)
      }
   }

   return ret
}

func PrintTmpl(g []*Generate) {
   // 进行拓扑排序,找出没有依赖的接口
   g = TopologicalSort(g)
   for _, generate := range g {
      fmt.Println(generate.FuncName)
   }
}

func main() {
   pkgs, _ := GetPkg()
   PrintTmpl(GetComment(pkgs))
   // NewB
   // NewA
}

// NewB @wired
func NewB() *BB {
   return &BB{}
}

// NewA s
// @wired
func NewA(b *BB) *AA {
   return &AA{}
}

总结

该文简要介绍了两种依赖注入的实现方式,也是社区常用的两种方式,以下是社区的一些方案,想深入了解可以参考以下项目

反射自动注入:github.com/uber-go/dig

代码自动生成:github.com/google/wire

本文中出现的代码均为demo展示,旨在勾勒出核心思想,均存在一定问题:

  1. 在方案一中的字段填充的实现里,当package出现同名且类型一致时,会出现覆盖问题,导致生成失败

    1.   例:domain/a/model/xxx.go packahe为model 存在Order结构体,key为:model.Order
    2.   domain/b/model/xxx.go packahe为model 存在Order结构体,key为:model.Order
  2. 在方案二中的类型构造的实现里,代码仅支持了Func初始化的形式,对于携带字段的结构体是没有实现这一逻辑的

    1. type A struct {
      }
      
      type B struc {
          A A
      }
      // 针对B这种情况,上述代码是无法成功运行的
      
  3. 在方案二中的类型构造的实现里,创建顺序也被忽略,如果Inject的入参顺序调整,上面的代码将会报错

ret, err := Inject(NewAA,&BB{})
// 在调用NewAA时,由于无法找到&BB,会出现问题
  1. 在方案二中的类型识别的实现里,目录遍历是当前目录及其子目录,如果当前目录不是用户的根目录则无法变量全部文件
  2. 在方案二中的关键字查找的实现里,没有实现结构体的自动注入功能