WebAssembly 解析 Go(wagon)

278 阅读9分钟

以 Wagon 为例, Golang 解析 wasm

wasm 原理

wasm 指令的解析,其实都是 入栈,出栈的操作, 它是一个基于栈的虚拟机,比如
get_local 0, 它就是获取函数的第一个参数,并把它放到栈里.
i32.const 42 就是把一个 42(int32)放入栈中.
i32.add 就是从栈中取出两个数,相加后再放回栈里。

下面看一个具体的例子

cpp 如下

extern "C" {
  int large(int num) {
    if (num > 10) {
      num = num + 12;
    } else {
      num = num + 100;
    }
    return num;
  }
}

指定 Optimization Level -o3 优化后 编译后的 wast 如下

 (table 0 anyfunc)
 (memory $0 1)
 (export "memory" (memory $0))
 (export "large" (func $large))
 (func $large (; 0 ;) (param $0 i32) (result i32)
  (i32.add  // 8
   (select  // 6
    (i32.const 12) // 1
    (i32.const 100) // 2
    (i32.gt_s // 5
     (get_local $0) // 3
     (i32.const 10) // 4
    )
   ) 
   (get_local $0) // 7
  )
 )
)

指令解析 $0 为函数输入参数

  • (i32.const 12) 将 12 pushstack;
    • stack=>[12]
  • (i32.const 100) 将 100 pushstack;
    • stack=>[100, 12]
  • (get_local $0) 将 $0 (参数) 从local中读取,并pushstack;
    • stack=>[$0, 100, 12]
  • (i32.const 10) 将 10 pushstack;
    • stack=>[10, $0, 100, 12]
  • (i32.gt_s ()())stackpopv1, v2,并比较大小,v1 > v2push 1 到stack, 反之 push 0;
    • stack=>[1, 100, 12]
  • (select ()()() )stackpop 1 => v3,从 stackpop 100 => v2, 从 stack 中pop 12 => v1,if v3 为 1(true), 将 v2 pushstack, 反之 将 v1 pushstack;
    • stack=>[100]
  • (get_local $0) 将 $0 (参数) 从local中读取,并pushstack;
    • stack=>[$0, 100]
  • (i32.add ()()) 从stack 中pop 两个数,相加后push 到stack;
    • stack=>[108]
  • 返回结果 108

做一个 webassembly 的虚拟机主要分两块, compileInterpreter. 我们先看 compile 模块.

Compile

  • 编译主要是对 wasm 结构进行解析, 首先看看 module 对象,这是个核心类, wasm 就是解析到这个对象
type Module struct {
	Version  uint32			// wasm 的版本
	Sections []Section		// wasm 中所有的section 数组, 一个 wasm 文件, 就是由version 和多个 section 组成
	Types    *SectionTypes 	// wasm 中所有的函数描述
	Import   *SectionImports // wasm 中导入的函数
	Function *SectionFunctions // wasm 中声明的函数,每个函数对应一个index 指向 Types内的函数类型
	Table    *SectionTables 
	Memory   *SectionMemories
	Global   *SectionGlobals
	Export   *SectionExports	// wasm 中导出的函数描述
	Start    *SectionStartFunction	// 需要立刻执行的函数
	Elements *SectionElements	// 定义在 table 中的元素
	Code     *SectionCode	// 该 module 的所有函数信息数据
	Data     *SectionData  // 数据区, 比如一些字符串等数据, 会放在Data里, 用 offset 标记
	Customs  []*SectionCustom 

	// The function index space of the module
	FunctionIndexSpace []Function // wasm 中所有的函数包括 SectionImports 和 SectionFunctions,函数中的 type 指向 Types 中的类型
	GlobalIndexSpace   []GlobalEntry

	// function indices into the global function space
	// the limit of each table is its capacity (cap)
	TableIndexSpace        [][]uint32
	LinearMemoryIndexSpace [][]byte // 线性内存, Data 数据会存放在这里

	imports struct {
		Funcs    []uint32  // 导入的函数
		Globals  int 
		Tables   int
		Memories int
	}
}
  • 读取 wasm 文件
// 从本地读取一个 wasm 文件,并返回 module
func ReadModule(r io.Reader, resolvePath ResolveFunc) (*Module, error) {
	// 通过解析 二进制 wasm 文件,将数据解析道对应的 section 中去
	m, err := DecodeModule(r)

	...
	if m.Import != nil && resolvePath != nil {
		if m.Code == nil {
			m.Code = &SectionCode{}
		}
		// 解析 导入 的 module
		err := m.resolveImports(resolvePath)
	}
	for _, fn := range []func() error{
		m.populateGlobals,
		// 将内部函数转化为 Function 对象,并将 导入的函数也 一并添加到 FunctionIndexSpace 中
		m.populateFunctions,
		m.populateTables,
		// 将 m.Data 放到线性内存中
		m.populateLinearMemory,
	} {
		if err := fn(); err != nil {
			return nil, err
		}
	}
	return m, nil
}

func DecodeModule(r io.Reader) (*Module, error) {
	reader := &readpos.ReadPos{
		R:      r,
		CurPos: 0,
	}
	m := &Module{}
	...

	err = newSectionsReader(m).readSections(reader)
	return m, nil
}
	
  • DecodeModule 新建 sectionReader, 并调用 readSections
func (s *sectionsReader) readSections(r *readpos.ReadPos) error {
	for {
		// 循环读取section,知道读完
		done, err := s.readSection(r)
		switch {
		case err != nil:
			return err
		case done:
			return nil
		}
	}
}

// 从reader 中读取一个有效的 section. The first return value is true if and only if
// the module has been completely read.
func (sr *sectionsReader) readSection(r *readpos.ReadPos) (bool, error) {
	m := sr.m

	logger.Println("Reading section ID")
	// 从 reader 中读取一个字节
	id, err := r.ReadByte()
	...

	s := RawSection{ID: SectionID(id)}

	logger.Println("Reading payload length")
	// 读取实际 数据
	payloadDataLen, err := leb128.ReadVarUint32(r)
	if err != nil {
		return false, err
	}

	logger.Printf("Section payload length: %d", payloadDataLen)

	s.Start = r.CurPos

	sectionBytes := new(bytes.Buffer)

	sectionBytes.Grow(int(getInitialCap(payloadDataLen)))
	sectionReader := io.LimitReader(io.TeeReader(r, sectionBytes), int64(payloadDataLen))
	
	// 判断section 的类型,并将该类型空的 struct 赋值给 module 对应的属性
	var sec Section
	switch s.ID {
	case SectionIDCustom:
		logger.Println("section custom")
		cs := &SectionCustom{}
		m.Customs = append(m.Customs, cs)
		sec = cs
	case SectionIDType:
		logger.Println("section type")
		m.Types = &SectionTypes{}
		sec = m.Types
	case SectionIDImport:
		logger.Println("section import")
		m.Import = &SectionImports{}
		sec = m.Import
	case SectionIDFunction:
		logger.Println("section function")
		m.Function = &SectionFunctions{}
		sec = m.Function
	case SectionIDTable:
		logger.Println("section table")
		m.Table = &SectionTables{}
		sec = m.Table
	case SectionIDMemory:
		logger.Println("section memory")
		m.Memory = &SectionMemories{}
		sec = m.Memory
	case SectionIDGlobal:
		logger.Println("section global")
		m.Global = &SectionGlobals{}
		sec = m.Global
	case SectionIDExport:
		logger.Println("section export")
		m.Export = &SectionExports{}
		sec = m.Export
	case SectionIDStart:
		logger.Println("section start")
		m.Start = &SectionStartFunction{}
		sec = m.Start
	case SectionIDElement:
		logger.Println("section element")
		m.Elements = &SectionElements{}
		sec = m.Elements
	case SectionIDCode:
		logger.Println("section code")
		m.Code = &SectionCode{}
		sec = m.Code
	case SectionIDData:
		logger.Println("section data")
		m.Data = &SectionData{}
		sec = m.Data
	default:
		return false, InvalidSectionIDError(s.ID)
	}
	// 从reader 中读取数据,存入 section (对应到 module 的某个变量中)
	err = sec.ReadPayload(sectionReader)
	if err != nil {
		logger.Println(err)
		return false, err
	}
	s.End = r.CurPos
	s.Bytes = sectionBytes.Bytes()
	// 将 raw s 保存到 对应的 xxxSection 中
	*sec.GetRawSection() = s
	
	...
	
	// 保存 section 
	m.Sections = append(m.Sections, sec)
	return false, nil
}
  • 将文件读取到 module 中后,还需要加载 import 的模块
// 解析import 的函数
func (module *Module) resolveImports(resolve ResolveFunc) error {
	if module.Import == nil {
		return nil
	}
	modules := make(map[string]*Module)

	var funcs uint32
	// 遍历 module.Import 下的 ”入口“
	for _, importEntry := range module.Import.Entries {
		importedModule, ok := modules[importEntry.ModuleName]
		if !ok {
			var err error
			// 如果不存在,就调用外部注入的 resolver 函数解析,并返回 module 对象
			importedModule, err = resolve(importEntry.ModuleName)
			if err != nil {
				return err
			}
			// 将导入的 module 保存起来
			modules[importEntry.ModuleName] = importedModule
		}

		if importedModule.Export == nil {
			return ErrNoExportsInImportedModule
		}
		// 判断 导入的module 中是否暴露了 importEntry.FieldName(本module 需要调用的方法)
		exportEntry, ok := importedModule.Export.Entries[importEntry.FieldName]
		if !ok {
			return ExportNotFoundError{importEntry.ModuleName, importEntry.FieldName}
		}
		// 判断 待导入函数类型, 与被导入模块的函数类型 是否一致
		if exportEntry.Kind != importEntry.Type.Kind() {
			return KindMismatchError{
				FieldName:  importEntry.FieldName,
				ModuleName: importEntry.ModuleName,
				Import:     importEntry.Type.Kind(),
				Export:     exportEntry.Kind,
			}
		}

		index := exportEntry.Index
		switch exportEntry.Kind {
		case ExternalFunction:
			// 根据 exportEntry 对应的 functionIndex ,获取对应的 Function 类型
			fn := importedModule.GetFunction(int(index))
			if fn == nil {
				return InvalidFunctionIndexError(index)
			}
			
			importIndex := importEntry.Type.(FuncImport).Type
			// 下面就判断 待带入的function  和 别导入的 function 的类型是否一致
			// 比较参数以及返回值长度
			if len(fn.Sig.ReturnTypes) != len(module.Types.Entries[importIndex].ReturnTypes) || len(fn.Sig.ParamTypes) != len(module.Types.Entries[importIndex].ParamTypes) {
				return InvalidImportError{importEntry.ModuleName, importEntry.FieldName, importIndex}
			}
			// 比较返回值类型
			for i, typ := range fn.Sig.ReturnTypes {
				if typ != module.Types.Entries[importIndex].ReturnTypes[i] {
					return InvalidImportError{importEntry.ModuleName, importEntry.FieldName, importIndex}
				}
			}
			// 比较参数类型
			for i, typ := range fn.Sig.ParamTypes {
				if typ != module.Types.Entries[importIndex].ParamTypes[i] {
					return InvalidImportError{importEntry.ModuleName, importEntry.FieldName, importIndex}
				}
			}
			// 将 Function 对象(被导入的函数),添加到 module 的 FunctionIndexSpace 数组中
			module.FunctionIndexSpace = append(module.FunctionIndexSpace, *fn)
			// 保存 Function 的函数体
			module.Code.Bodies = append(module.Code.Bodies, *fn.Body)
			// 将 Function 对象保存到 module 的 import.Funcs 数组中
			module.imports.Funcs = append(module.imports.Funcs, funcs)
			funcs++
		case ExternalGlobal:
			// todo ...
			glb := importedModule.GetGlobal(int(index))
			if glb == nil {
				return InvalidGlobalIndexError(index)
			}
			if glb.Type.Mutable {
				return ErrImportMutGlobal
			}
			module.GlobalIndexSpace = append(module.GlobalIndexSpace, *glb)
			module.imports.Globals++

			// In both cases below, index should be always 0 (according to the MVP)
			// We check it against the length of the index space anyway.
		case ExternalTable:
			if int(index) >= len(importedModule.TableIndexSpace) {
				return InvalidTableIndexError(index)
			}
			module.TableIndexSpace[0] = importedModule.TableIndexSpace[0]
			module.imports.Tables++
		case ExternalMemory:
			if int(index) >= len(importedModule.LinearMemoryIndexSpace) {
				return InvalidLinearMemoryIndexError(index)
			}
			module.LinearMemoryIndexSpace[0] = importedModule.LinearMemoryIndexSpace[0]
			module.imports.Memories++
		default:
			return InvalidExternalError(exportEntry.Kind)
		}
	}
	return nil
}

  • populateFunctions
// 函数索引空间索引所有导入和内部定义的函数定义
func (m *Module) populateFunctions() error {
	
	...
	
	// 给内部定义的 func 构造 fn
	// Add the functions from the wasm itself to the function list
	numImports := len(m.FunctionIndexSpace)
	for codeIndex, typeIndex := range m.Function.Types {
		if int(typeIndex) >= len(m.Types.Entries) {
			return InvalidFunctionIndexError(typeIndex)
		}
		// Create the main function structure
		fn := Function{
			Sig:  &m.Types.Entries[typeIndex],
			Body: &m.Code.Bodies[codeIndex],
			Name: names[uint32(codeIndex+numImports)], // Add the name string if we have it
		}
		m.FunctionIndexSpace = append(m.FunctionIndexSpace, fn)
	}

	funcs := make([]uint32, 0, len(m.Function.Types)+len(m.imports.Funcs))
	funcs = append(funcs, m.imports.Funcs...)
	funcs = append(funcs, m.Function.Types...)
	m.Function.Types = funcs
	return nil
}
  • 新建一个 VM

先看 vm 类型

// VM is the execution context for executing WebAssembly bytecode.
type VM struct {
	ctx context // 执行上下文
		type context struct {
			stack   []uint64  // 栈深度
			locals  []uint64	  // 局部变量
			code    []byte	// 函数的字节码
			asm     []asmBlock
			pc      int64		// 当前的字节码 index
			curFunc int64		// 当前函数在 funcs 的index
		}

	module  *wasm.Module  
	globals []uint64
	memory  []byte
	funcs   []function // 函数数组 compiledFunction or goFunction

	funcTable [256]func()  // 指令集,对应的解析函数

	// RecoverPanic controls whether the `ExecCode` method
	// recovers from a panic and returns it as an error
	// instead.
	// A panic can occur either when executing an invalid VM
	// or encountering an invalid instruction, e.g. `unreachable`.
	RecoverPanic bool

	abort bool // Flag for host functions to terminate execution

	nativeBackend *nativeCompiler
}

// 通过 module 对象和 options 构造一个 vm
func NewVM(module *wasm.Module, opts ...VMOption) (*VM, error) {
	var vm VM
	var options config
	for _, opt := range opts {
		opt(&options)
	}
	if module.Memory != nil && len(module.Memory.Entries) != 0 {
		if len(module.Memory.Entries) > 1 {
			return nil, ErrMultipleLinearMemories
		}
		vm.memory = make([]byte, uint(module.Memory.Entries[0].Limits.Initial)*wasmPageSize)
		copy(vm.memory, module.LinearMemoryIndexSpace[0])
	}

	vm.funcs = make([]function, len(module.FunctionIndexSpace))
	vm.globals = make([]uint64, len(module.GlobalIndexSpace))
	vm.newFuncTable()
	vm.module = module

	nNatives := 0
	for i, fn := range module.FunctionIndexSpace {
	
		// 如果是 import 的原生 golang 方法,使用 goFunction 处理
		if fn.IsHost() {
			vm.funcs[i] = goFunction{
				typ: fn.Host.Type(),
				val: fn.Host,
			}
			nNatives++
			continue
		}
		// 将function拆卸并封装成新的结构
		disassembly, err := disasm.NewDisassembly(fn, module)
		if err != nil {
			return nil, err
		}

		totalLocalVars := 0
		totalLocalVars += len(fn.Sig.ParamTypes)
		for _, entry := range fn.Body.Locals {
			totalLocalVars += int(entry.Count)
		}
		// 编译 字节码
		code, meta := compile.Compile(disassembly.Code)
		vm.funcs[i] = compiledFunction{
			codeMeta:       meta,
			code:           code,
			branchTables:   meta.BranchTables,
			maxDepth:       disassembly.MaxDepth,
			totalLocalVars: totalLocalVars,
			args:           len(fn.Sig.ParamTypes),
			returns:        len(fn.Sig.ReturnTypes) != 0,
		}
	}

	...
	
	return &vm, nil
}

Interpreter

下面执行代码的过程,即是翻译代码的过程

// fnIndex 函数的index, args 是该函数的参数
func (vm *VM) ExecCode(fnIndex int64, args ...uint64) (rtrn interface{}, err error) {
	...
	
	if int(fnIndex) > len(vm.funcs) {
		return nil, InvalidFunctionIndexError(fnIndex)
	}
	if len(vm.module.GetFunction(int(fnIndex)).Sig.ParamTypes) != len(args) {
		return nil, ErrInvalidArgumentCount
	}
	compiled, ok := vm.funcs[fnIndex].(compiledFunction)
	if !ok {
		panic(fmt.Sprintf("exec: function at index %d is not a compiled function", fnIndex))
	}

	depth := compiled.maxDepth + 1
	// 初始化执行栈
	if cap(vm.ctx.stack) < depth {
		vm.ctx.stack = make([]uint64, 0, depth)
	} else {
		vm.ctx.stack = vm.ctx.stack[:0]
	}

	vm.ctx.locals = make([]uint64, compiled.totalLocalVars)
	vm.ctx.pc = 0
	vm.ctx.code = compiled.code
	vm.ctx.asm = compiled.asm
	vm.ctx.curFunc = fnIndex
	// 给函数的参数赋值
	for i, arg := range args {
		vm.ctx.locals[i] = arg
	}

	res := vm.execCode(compiled)
	if compiled.returns {
		rtrnType := vm.module.GetFunction(int(fnIndex)).Sig.ReturnTypes[0]
		switch rtrnType {
		case wasm.ValueTypeI32:
			rtrn = uint32(res)
		case wasm.ValueTypeI64:
			rtrn = uint64(res)
		case wasm.ValueTypeF32:
			rtrn = math.Float32frombits(uint32(res))
		case wasm.ValueTypeF64:
			rtrn = math.Float64frombits(res)
		default:
			return nil, InvalidReturnTypeError(rtrnType)
		}
	}

	return rtrn, nil
}


func (vm *VM) execCode(compiled compiledFunction) uint64 {
outer:
	for int(vm.ctx.pc) < len(vm.ctx.code) && !vm.abort {
		op := vm.ctx.code[vm.ctx.pc]
		vm.ctx.pc++
		switch op {
		// 解析到 return 指令的时候,退出循环
		case ops.Return:
			break outer
			
		// 省略一些不常用的case
		...
		
		default:
			// 大部分会走这个case
			vm.funcTable[op]()
		}
	}

	if compiled.returns {
		//如果有返回值,从栈中取出返回
		return vm.ctx.stack[len(vm.ctx.stack)-1]
	}
	return 0
}

funcTable [256]func() 的初始化 ,一个指令(Op)对应一个解析方法, 看看 Op 的结构

// Op describes a WASM operator.
type Op struct {
	Code byte   // The single-byte opcode
	Name string // 该操作的名称

	// Whether this operator is polymorphic.
	// A polymorphic operator has a variable arity. call, call_indirect, and
	// drop are examples of polymorphic operators.
	Polymorphic bool // 是否是动态的, true:比如一些逻辑控制语句, 还有 get/setlocal 等
	Args        []wasm.ValueType // 该指令需要的参数类型(数量)(会从栈中pop出来)
	Returns     wasm.ValueType   // 返回的参数类型
}

func (vm *VM) newFuncTable() {
	vm.funcTable[ops.I32Clz] = vm.i32Clz
	vm.funcTable[ops.I32Ctz] = vm.i32Ctz
	vm.funcTable[ops.I32Popcnt] = vm.i32Popcnt
	
	vm.funcTable[ops.I32Add] = vm.i32Add
	vm.funcTable[ops.I32Sub] = vm.i32Sub
	vm.funcTable[ops.I32Mul] = vm.i32Mul

	....
	....

	vm.funcTable[ops.Drop] = vm.drop
	vm.funcTable[ops.Select] = vm.selectOp

	vm.funcTable[ops.GetLocal] = vm.getLocal
	vm.funcTable[ops.SetLocal] = vm.setLocal
	vm.funcTable[ops.TeeLocal] = vm.teeLocal
	vm.funcTable[ops.GetGlobal] = vm.getGlobal
	vm.funcTable[ops.SetGlobal] = vm.setGlobal

	vm.funcTable[ops.Unreachable] = vm.unreachable
	vm.funcTable[ops.Nop] = vm.nop

	vm.funcTable[ops.Call] = vm.call
	vm.funcTable[ops.CallIndirect] = vm.callIndirect
}


例如

// 从栈中pop 两个uint32 出来,相加后在push 到栈中
func (vm *VM) i32Add() {
	vm.pushUint32(vm.popUint32() + vm.popUint32())
}

这里再讲一下 vm.call

func (vm *VM) call() {
	index := vm.fetchUint32()
	// 这里会从 funcs 数组里取出 Function(or goFunction) 对象,调用call
	vm.funcs[index].call(vm, int64(index))
}

// goFunction 利用反射机制,执行函数
func (fn goFunction) call(vm *VM, index int64) {
	// numIn = # of call inputs + vm, as the function expects
	// an additional *VM argument
	numIn := fn.typ.NumIn()
	args := make([]reflect.Value, numIn)
	proc := NewProcess(vm)

	// 第一个参数必须是 *Process 类型
	if reflect.ValueOf(proc).Kind() != fn.typ.In(0).Kind() {
		panic(fmt.Sprintf("exec: the first argument of a host function was %s, expected %s", fn.typ.In(0).Kind(), reflect.ValueOf(vm).Kind()))
	}
	args[0] = reflect.ValueOf(proc)
	// 给函数的参数赋值
	for i := numIn - 1; i >= 1; i-- {
		val := reflect.New(fn.typ.In(i)).Elem()
		raw := vm.popUint64()
		kind := fn.typ.In(i).Kind()

		switch kind {
		case reflect.Float64, reflect.Float32:
			val.SetFloat(math.Float64frombits(raw))
		case reflect.Uint32, reflect.Uint64:
			val.SetUint(raw)
		case reflect.Int32, reflect.Int64:
			val.SetInt(int64(raw))
		default:
			panic(fmt.Sprintf("exec: args %d invalid kind=%v", i, kind))
		}

		args[i] = val
	}
	// 执行函数
	rtrns := fn.val.Call(args)
	// 将返回值 push 到栈中
	for i, out := range rtrns {
		kind := out.Kind()
		switch kind {
		case reflect.Float64, reflect.Float32:
			vm.pushFloat64(out.Float())
		case reflect.Uint32, reflect.Uint64:
			vm.pushUint64(out.Uint())
		case reflect.Int32, reflect.Int64:
			vm.pushInt64(out.Int())
		default:
			panic(fmt.Sprintf("exec: return value %d invalid kind=%v", i, kind))
		}
	}
}

// 非原生函数实现
func (compiled compiledFunction) call(vm *VM, index int64) {

	newStack := make([]uint64, 0, compiled.maxDepth+1)
	locals := make([]uint64, compiled.totalLocalVars)
	// 给参数赋值
	for i := compiled.args - 1; i >= 0; i-- {
		locals[i] = vm.popUint64()
	}

	//保存执行上下文
	prevCtxt := vm.ctx
	// 新建执行上下文
	vm.ctx = context{
		stack:   newStack,
		locals:  locals,
		code:    compiled.code,
		asm:     compiled.asm,
		pc:      0,
		curFunc: index,
	}
	
	rtrn := vm.execCode(compiled)

	//被调用函数执行完了,恢复上下文
	vm.ctx = prevCtxt
	
	if compiled.returns {
		// 把返回值push到栈中
		vm.pushUint64(rtrn)
	}
}