通过实际应用学习golang反射

824 阅读1分钟

通过实际应用学习golang反射

在使用go-Micro微服务框架,编写api中间层的逻辑时,需要从micro api的request中获取到get或者post的请求参数,然后进行参数的合法性校验。

封装一个函数使用反射,获取到参数,以及校验请求方法,校验参数合法性,报错返回相应的报错信息

package vodutils

import (
	"errors"
	go_api "github.com/micro/go-micro/api/proto"
	microApi "github.com/micro/go-micro/api/proto"
	"pkg.lwbedu.com/gk-micro/common/errs"
	"pkg.lwbedu.com/gk-micro/common/net/api"
	"reflect"
	"strconv"
	"strings"
)

const GET string = "GET"
const POST string = "POST"

var TagKey = "param"
var MethodNotAllow = errors.New("method not allowed")
var NotStruct = errors.New("not struct")
var NotPointer = errors.New("not ptr")

//传入需要校验的方法,需要获取的参数对象指针,以及请求对象
func ParamValid(m string, getParam interface{}, req *microApi.Request) error {
	// 指针的value
	pv := reflect.ValueOf(getParam)
	// 非指针返回
	if pv.Kind() != reflect.Ptr {
		return errs.BadRequest(NotPointer.Error())
	}
	if pv.IsNil() {
		return errs.BadRequest("param is nil")
	}
	//获取到指针对应的value
	pv = pv.Elem()
	//指针的类型
	if pv.Kind() != reflect.Struct {
		return errs.BadRequest(NotStruct.Error())
	}
	//value对应的type
	pt := pv.Type()

	method := strings.ToUpper(m)
	if method != strings.ToUpper(req.Method) {
		return errs.BadRequest(MethodNotAllow.Error())
	}

	var params = map[string]*go_api.Pair{}

	if method == GET {
		params = req.Get
	} else if method == POST {
		params = req.Post
	} else {
		return errs.BadRequest(MethodNotAllow.Error())
	}
	for i := 0; i < pt.NumField(); i++ {
		tyfield := pt.Field(i)
		name := tyfield.Name
		filed := pv.FieldByName(name)
		tag := tyfield.Tag.Get(TagKey)
		//获得该参数
		paramVals := params[tag].Values
		//如果长度为1,则只有一个参数
		if len(paramVals) == 1 {
			paramVal := paramVals[0]
			if filed.CanSet() {
				switch filed.Kind() {
				case reflect.String:
					filed.SetString(paramVal)
				case reflect.Int:
					fallthrough
				case reflect.Int8:
					fallthrough
				case reflect.Int16:
					fallthrough
				case reflect.Int32:
					fallthrough
				case reflect.Uint:
					fallthrough
				case reflect.Uint8:
					fallthrough
				case reflect.Uint16:
					fallthrough
				case reflect.Uint32:
					fallthrough
				case reflect.Uint64:
					fallthrough
				case reflect.Int64:
					val, _ := strconv.ParseInt(paramVal, 10, 64)
					filed.SetInt(val)
				case reflect.Float32:
					fallthrough
				case reflect.Float64:
					val, _ := strconv.ParseFloat(paramVal, 64)
					filed.SetFloat(val)
				}
			}
		} else {
			//非1的话,为列表,需要赋值给相应的列表

		}
	}
	//校验参数
	err := api.Validate(pv.Interface())
	if err != nil {
		return errs.BadRequest(err.Error())
	}
	return nil
}