给proto添加参数校验

10,424 阅读4分钟

应用场景

对外提供的proto接口,需要进行参数校验。以往的做法是在代码逻辑中添加校验规则,就会出现下面示例的写法:

// Hand-written validation (the pattern this article replaces): every required
// field is folded into one condition, so the caller only learns that some
// required parameter is missing, not which one.
if req.Caller == nil || equip == nil || req.Action == coll_common.Action_OP_UNKNOWN || len(equip.Name) == 0 || equip.CategoryId == 0 ||
   len(equip.LogoList) == 0 || len(equip.Model) == 0 || len(equip.TechnicalIndex) == 0 || len(equip.Description) == 0 ||
   equip.RentCount == 0 || equip.MinimumRental == 0 || equip.Rent == 0 || equip.ContactPerson == nil || equip.Location == nil {
   logrus.Errorf("[AddEquipment] required parameter can not be null")
   rsp.CommonRsp.Code = common.CodeInvalidParam
   rsp.CommonRsp.Msg = "required parameter can not be null"
   return nil
}

这样写的代码可读性较差、对每个错误无法进行详细的处理,返回统一的错误结果难以进行排查,若对每个错误信息进行处理,比较耗时,而且重复的工作较多,也比较容易出bug。由此想到,能否在定义proto的时候就对数据类型做限制,通过编译,将校验的方法绑定在请求信息上,这样可以提升代码的可读性,也能减少很多重复性的工作。

工具对比 go-proto-validators和protoc-gen-validate

| Name | contributors | stars | lang | status | documentation completeness |
| :--- | :---: | ---: | :--- | :--- | :--- |
| go-proto-validators | 14 | 605 | go | stable | low |
| protoc-gen-validate | 56 | 982 | go、java、python、c++ | alpha | high |

protoc-gen-validate:支持多语言,社区活跃度比 go-proto-validators 略高,文档也比较全,但仍处于 alpha(不稳定)版本,后续的 API 很可能会发生变化。 go-proto-validators:在支持的功能、用法和原理上与 protoc-gen-validate 相差不大,但已处于比较稳定的版本,所以暂时采用这个方案。

使用

1、将GOPATH下的bin目录(go安装可执行文件的位置)添加到PATH环境变量中

# `go env GOPATH` still works when the GOPATH variable is unset (it then
# prints the default location, $HOME/go on Unix), unlike a bare ${GOPATH}.
export PATH="${PATH}:$(go env GOPATH)/bin"

2、安装

# Installs the protoc-gen-govalidators plugin into $GOPATH/bin (GOPATH mode).
# In GOPATH mode this also places validator.proto under $GOPATH/src, which the
# --proto_path=${GOPATH}/src flag of the compile step relies on.
# NOTE(review): since Go 1.18, `go get` no longer builds/installs binaries; use
#   go install github.com/mwitkow/go-proto-validators/protoc-gen-govalidators@latest
# and make validator.proto reachable on --proto_path separately, because module
# mode does not populate $GOPATH/src.
go get github.com/mwitkow/go-proto-validators/protoc-gen-govalidators

3、编写protobuf

syntax = "proto3";
package test_bid;

import "github.com/mwitkow/go-proto-validators/validator.proto";
import "git.code.oa.com/cloud_industry/proto/common/common.proto";
import "coll_common/common.proto";


option go_package = "git.code.oa.com/cloud_industry/colla_proto/test_bid";

// Bid management service.
service BidManagement {
    // Submit a bid.
    rpc DoBid (DoBidReq) returns (DoBidRsp) {
    }
}

// Action is defined in this file (not imported) so that the is_in_enum check
// generated for DoBidReq.Action can resolve the Action_name map.
enum Action {
    OP_UNKNOWN = 0; // unknown
    OP_DRAFT = 1; // save as draft
    OP_COMMIT = 2; // submit for review
    OP_REVOKE = 3; // withdraw from review
    OP_PASS = 4; // approve
    OP_DENY = 5; // reject
    OP_RELEASE = 6; // publish
    OP_OFF_RELEASE = 7; // unpublish
    OP_RECOMMEND = 8; // recommend
    OP_CANCEL_RECOMMEND = 9; // cancel recommendation
    OP_MOVE_UP = 10; // move up
    OP_MOVE_DOWN = 11; // move down
    OP_CLOSE = 12; // close
    OP_SUCC = 13; // transaction succeeded
    OP_FAIL = 14; // transaction failed
}

message DoBidReq {
    common.Caller Caller = 1 [(validator.field) = {msg_exists: true}];
    // NOTE: int_lt: 0 generates the check `DemandId < 0`, so only negative ids
    // pass. For ordinary positive ids use {int_gt: 0} and regenerate.
    int64 DemandId = 2 [(validator.field) = {int_lt: 0}]; // demand id
    int64 ServicePrice = 3; // quoted delivery price, in fen (cents)
    int32 ServicePeriod = 4; // delivery period, in days
    common.AppendixInfo Attachment = 5 [(validator.field) = {msg_exists: true}]; // attachment
    string Connect = 6 [(validator.field) = {string_not_empty: true}]; // contact person
    // string (not bytes) so validation errors print the number itself; note that
    // length_gt/length_lt are compared against the UTF-8 byte length (Go len()).
    string ConnectPhone = 7 [(validator.field) = {length_gt: 0, length_lt: 255}]; // contact mobile number
    string SMSVerifyCode = 8; // SMS verification code
    Action Action = 9 [(validator.field) = {is_in_enum: true}]; // operation
}

message DoBidRsp {
    common.CommonRsp CommonRsp = 1;
}

相应的语法支持可以在validator.proto文件中查看

// Copyright 2016 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms.

// Protocol Buffers extensions for defining auto-generateable validators for messages.
// (Reference copy of validator.proto from github.com/mwitkow/go-proto-validators,
// the file imported by test_bid.proto. Rules are attached to fields as
// [(validator.field) = {...}].)

// TODO(mwitkow): Add example.


syntax = "proto2";
package validator;

import "google/protobuf/descriptor.proto";

option go_package = "github.com/mwitkow/go-proto-validators;validator";

// TODO(mwitkow): Email protobuf-global-extension-registry@google.com to get an extension ID.

extend google.protobuf.FieldOptions {
  optional FieldValidator field = 65020;
}

extend google.protobuf.OneofOptions {
  optional OneofValidator oneof = 65021;
}

message FieldValidator {
  // Uses a Golang RE2-syntax regex to match the field contents.
  optional string regex = 1;
  // Field value of integer strictly greater than this value.
  optional int64 int_gt = 2;
  // Field value of integer strictly smaller than this value.
  optional int64 int_lt = 3;
  // Used for nested message types, requires that the message type exists.
  optional bool msg_exists = 4;
  // Human error specifies a user-customizable error that is visible to the user.
  optional string human_error = 5;
  // Field value of double strictly greater than this value.
  // Note that this value can only take on a valid floating point
  // value. Use together with float_epsilon if you need something more specific.
  optional double float_gt = 6;
  // Field value of double strictly smaller than this value.
  // Note that this value can only take on a valid floating point
  // value. Use together with float_epsilon if you need something more specific.
  optional double float_lt = 7;
  // Field value of double describing the epsilon within which
  // any comparison should be considered to be true. For example,
  // when using float_gt = 0.35, using a float_epsilon of 0.05
  // would mean that any value above 0.30 is acceptable. It can be
  // thought of as a {float_value_condition} +- {float_epsilon}.
  // If unset, no correction for floating point inaccuracies in
  // comparisons will be attempted.
  optional double float_epsilon = 8;
  // Floating-point value compared to which the field content should be greater or equal.
  optional double float_gte = 9;
  // Floating-point value compared to which the field content should be smaller or equal.
  optional double float_lte = 10;
  // Used for string fields, requires the string to be not empty (i.e different from "").
  optional bool string_not_empty = 11;
  // Repeated field with at least this number of elements.
  optional int64 repeated_count_min = 12;
  // Repeated field with at most this number of elements.
  optional int64 repeated_count_max = 13;
  // Field value of length greater than this value.
  // (Go output compares len(field): the byte length for string/bytes fields.)
  optional int64 length_gt = 14;
  // Field value of length smaller than this value.
  // (Go output compares len(field): the byte length for string/bytes fields.)
  optional int64 length_lt = 15;
  // Field value of length strictly equal to this value.
  optional int64 length_eq = 16;
  // Requires that the value is in the enum.
  // (Go output looks the value up in the enum's generated <Enum>_name map.)
  optional bool is_in_enum = 17;
  // Ensures that a string value is in UUID format.
  // uuid_ver specifies the valid UUID versions. Valid values are: 0-5.
  // If uuid_ver is 0 all UUID versions are accepted.
  optional int32 uuid_ver = 18;
}

message OneofValidator {
  // Require that one of the oneof fields is set.
  optional bool required = 1;
}

4、编译

在protoc命令中增加一个--govalidators_out参数(protoc会据此调用PATH中的protoc-gen-govalidators插件,这也是第1步需要把GOPATH/bin加入PATH的原因)

# protoc runs the protoc-gen-go / protoc-gen-govalidators plugins found on PATH
# for --go_out / --govalidators_out. `go env GOPATH` is used instead of
# ${GOPATH} so the include paths are correct even when GOPATH is unset.
protoc \
  --proto_path="$(go env GOPATH)/src" \
  --proto_path="$(go env GOPATH)/src/github.com/google/protobuf/src" \
  --proto_path=. \
  --go_out=. \
  --govalidators_out=. \
  *.proto

执行完后会生成一个test_bid.validator.pb.go的文件

// Code generated by protoc-gen-gogo. DO NOT EDIT.
// source: test_bid/test_bid.proto

package test_bid

import (
   fmt "fmt"
   math "math"
   proto "github.com/golang/protobuf/proto"
   _ "github.com/mwitkow/go-proto-validators"
   _ "git.code.oa.com/cloud_industry/proto/common"
   _ "git.code.oa.com/cloud_industry/colla_proto/coll_common"
   github_com_mwitkow_go_proto_validators "github.com/mwitkow/go-proto-validators"
)

// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf

// Validate checks DoBidReq against the (validator.field) options declared in
// test_bid.proto, in field order. It returns the first violation wrapped by
// FieldError together with the field name, or nil when every rule passes.
func (this *DoBidReq) Validate() error {
   // Caller: msg_exists -> the message must be set.
   if nil == this.Caller {
      return github_com_mwitkow_go_proto_validators.FieldError("Caller", fmt.Errorf("message must exist"))
   }
   // Nested check. Judging by its name, CallValidatorIfExists runs Caller's own
   // Validate() when that type defines one (TODO confirm in go-proto-validators).
   if this.Caller != nil {
      if err := github_com_mwitkow_go_proto_validators.CallValidatorIfExists(this.Caller); err != nil {
         return github_com_mwitkow_go_proto_validators.FieldError("Caller", err)
      }
   }
   // DemandId: int_lt: 0 -> only values strictly below 0 are accepted.
   if !(this.DemandId < 0) {
      return github_com_mwitkow_go_proto_validators.FieldError("DemandId", fmt.Errorf(`value '%v' must be less than '0'`, this.DemandId))
   }
   // Attachment: msg_exists plus nested validation, same pattern as Caller.
   if nil == this.Attachment {
      return github_com_mwitkow_go_proto_validators.FieldError("Attachment", fmt.Errorf("message must exist"))
   }
   if this.Attachment != nil {
      if err := github_com_mwitkow_go_proto_validators.CallValidatorIfExists(this.Attachment); err != nil {
         return github_com_mwitkow_go_proto_validators.FieldError("Attachment", err)
      }
   }
   // Connect: string_not_empty.
   if this.Connect == "" {
      return github_com_mwitkow_go_proto_validators.FieldError("Connect", fmt.Errorf(`value '%v' must not be an empty string`, this.Connect))
   }
   // ConnectPhone: length_gt: 0 and length_lt: 255, both measured with len().
   if !(len(this.ConnectPhone) > 0) {
      return github_com_mwitkow_go_proto_validators.FieldError("ConnectPhone", fmt.Errorf(`value '%v' must have a length greater than '0'`, this.ConnectPhone))
   }
   if !(len(this.ConnectPhone) < 255) {
      return github_com_mwitkow_go_proto_validators.FieldError("ConnectPhone", fmt.Errorf(`value '%v' must have a length smaller than '255'`, this.ConnectPhone))
   }
   // Action: is_in_enum -> the numeric value must be a key of Action_name.
   if _, ok := Action_name[int32(this.Action)]; !ok {
      return github_com_mwitkow_go_proto_validators.FieldError("Action", fmt.Errorf(`value '%v' must be a valid Action field`, this.Action))
   }
   return nil
}
// Validate declares no rules of its own for DoBidRsp; it only runs nested
// validation on CommonRsp when that field is set.
func (this *DoBidRsp) Validate() error {
   if this.CommonRsp != nil {
      if err := github_com_mwitkow_go_proto_validators.CallValidatorIfExists(this.CommonRsp); err != nil {
         return github_com_mwitkow_go_proto_validators.FieldError("CommonRsp", err)
      }
   }
   return nil
}

该文件为DoBidReq生成了Validate方法,将校验逻辑绑定在请求消息上,并且每个字段的校验都有各自的校验逻辑和错误提示。

在服务中使用

// DoBid handles a bid submission. The request is first checked with the
// generated Validate method. On failure, the validation error text is returned
// to the client in CommonRsp with code 400; otherwise CommonRsp reports
// success with code 200. Results are reported through rsp, and the returned
// error is always nil.
func (c *CollaborativeAgent) DoBid(ctx context.Context, req *test_bid.DoBidReq, rsp *test_bid.DoBidRsp) error {
	if err := req.Validate(); err != nil {
		// logrus formatters end every entry with a newline, so the format
		// string needs no trailing "\n".
		logrus.Errorf("[DoBid] invalid request: %v", err)
		// DoBidRsp.CommonRsp is declared as common.CommonRsp in test_bid.proto,
		// so the literal must use that type rather than coll_common.CommonRsp.
		rsp.CommonRsp = &common.CommonRsp{
			Code: 400,
			Msg:  err.Error(),
		}
		return nil
	}

	rsp.CommonRsp = &common.CommonRsp{
		Code: 200,
		Msg:  "success",
	}
	return nil
}

存在的问题

1、is_in_enum 使用枚举类型进行校验的时候,若引用其他包的枚举类型,验证代码生成会有问题

coll_common.Action Action = 9 [(validator.field) = {is_in_enum: true}]; // operation

编译后的代码为

// Generated is_in_enum check: Action_name is referenced without a package
// qualifier, so it is resolved in package test_bid, although the map is
// declared in package coll_common when the enum is imported from there.
if _, ok := Action_name[int32(this.Action)]; !ok {
   return github_com_mwitkow_go_proto_validators.FieldError("Action", fmt.Errorf(`value '%v' must be a valid Action field`, this.Action))
}

而Action_name这个map定义在coll_common包生成的pb.go文件中:

// Excerpt of the generated pb.go for package coll_common (Action enum only).
package coll_common

// Action is the Go type generated for the coll_common.Action proto enum.
type Action int32

const (
   Action_OP_UNKNOWN          Action = 0
   Action_OP_DRAFT            Action = 1
   Action_OP_COMMIT           Action = 2
   Action_OP_REVOKE           Action = 3
   Action_OP_PASS             Action = 4
   Action_OP_DENY             Action = 5
   Action_OP_RELEASE          Action = 6
   Action_OP_OFF_RELEASE      Action = 7
   Action_OP_RECOMMEND        Action = 8
   Action_OP_CANCEL_RECOMMEND Action = 9
   Action_OP_MOVE_UP          Action = 10
   Action_OP_MOVE_DOWN        Action = 11
   Action_OP_CLOSE            Action = 12
   Action_OP_SUCC             Action = 13
   Action_OP_FAIL             Action = 14
)

// Action_name maps each Action number to its enum value name. This is the map
// the generated is_in_enum check indexes, and it lives in package coll_common.
var Action_name = map[int32]string{
   0:  "OP_UNKNOWN",
   1:  "OP_DRAFT",
   2:  "OP_COMMIT",
   3:  "OP_REVOKE",
   4:  "OP_PASS",
   5:  "OP_DENY",
   6:  "OP_RELEASE",
   7:  "OP_OFF_RELEASE",
   8:  "OP_RECOMMEND",
   9:  "OP_CANCEL_RECOMMEND",
   10: "OP_MOVE_UP",
   11: "OP_MOVE_DOWN",
   12: "OP_CLOSE",
   13: "OP_SUCC",
   14: "OP_FAIL",
}

// Action_value is the reverse mapping, from enum value name to number.
var Action_value = map[string]int32{
   "OP_UNKNOWN":          0,
   "OP_DRAFT":            1,
   "OP_COMMIT":           2,
   "OP_REVOKE":           3,
   "OP_PASS":             4,
   "OP_DENY":             5,
   "OP_RELEASE":          6,
   "OP_OFF_RELEASE":      7,
   "OP_RECOMMEND":        8,
   "OP_CANCEL_RECOMMEND": 9,
   "OP_MOVE_UP":          10,
   "OP_MOVE_DOWN":        11,
   "OP_CLOSE":            12,
   "OP_SUCC":             13,
   "OP_FAIL":             14,
}

这里就出现了无法引用的问题:生成的校验代码直接使用未加包名的Action_name,而该变量定义在coll_common包中,test_bid包内并不存在,因此生成的代码无法编译。所以建议把需要做is_in_enum校验的枚举定义在同一个proto文件中。

2、如何正确判断字符串长度

// length_gt / length_lt are compared against the byte length (Go len()).
string ConnectPhone = 7 [(validator.field) = {length_gt: 0, length_lt: 255}];

生成的校验代码为

// len() on a Go string returns its byte count, so a multi-byte UTF-8
// character contributes several units toward these limits.
if !(len(this.ConnectPhone) > 0) {
   return github_com_mwitkow_go_proto_validators.FieldError("ConnectPhone", fmt.Errorf(`value '%v' must have a length greater than '0'`, this.ConnectPhone))
}
if !(len(this.ConnectPhone) < 255) {
   return github_com_mwitkow_go_proto_validators.FieldError("ConnectPhone", fmt.Errorf(`value '%v' must have a length smaller than '255'`, this.ConnectPhone))
}

对于中文和特殊字符,此方法无法正确判断真实的字符串长度:Go字符串采用UTF-8编码,len()返回的是字节数,而在UTF-8中常用中文字符占3个字节,因此一个中文字符用len()计算出的长度是3,所以在设计proto的时候需要注意这点。在业务代码中,可以通过

// Counts Unicode code points (runes), not bytes; utf8.RuneCountInString(str) gives the same result.
len([]rune(str))

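统计字符串中的字符(Unicode码点)个数:[]rune(str)会按UTF-8解码字符串,每个rune(int32的别名)对应一个Unicode码点;而len(str)统计的是底层byte(uint8的别名)的个数。标准库中的utf8.RuneCountInString(str)结果相同。注意,码点数并不总等于用户看到的字符数,例如组合字符或带修饰符的emoji由多个码点组成。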