借助protogen生成自定义protoc插件-结构体赋值

74 阅读7分钟

背景

在与SDK交互过程中,v1.0.0 版本手动处理了 proto生成的 .go 文件,导致去除了 ,omitempty 字段。从而导致全部字段均会下发,浪费网络带宽。考虑到兼容性问题,想在 v2.0.0 版本进行升级。即根据版本可分别下发【全部字段】和【仅下发有效值】。

思考

基于该背景,于是考虑根据版本分别下发不一样的数据。

  • 基于版本用反射去重新生成标签【考虑 reflect 的性能问题,放弃】
  • 基于自定义标签控制JSON序列化,比如基础版代码【更多可以找AI编写】
package main

import (
    "encoding/json"
    "fmt"
    "reflect"
    "strconv"
)

// 自定义序列化条件标签
type User struct {
    Name     string `json:"name" condition:"nonzero"`        // 非空时才序列化
    Age      int    `json:"age" condition:"nonzero"`         // 非0时才序列化
    Email    string `json:"email" condition:"always"`        // 总是序列化
    Score    int    `json:"score" condition:"gt:0"`          // 大于0时才序列化
    Status   string `json:"status" condition:"in:active,pending"` // 在指定值中时才序列化
    Optional string `json:"optional" condition:"custom:CheckOptional"` // 自定义检查函数
}

// 自定义检查函数
func (u User) CheckOptional(fieldValue string) bool {
    return fieldValue != "" && fieldValue != "null"
}

// 实现自定义 MarshalJSON
func (u User) MarshalJSON() ([]byte, error) {
    m := make(map[string]interface{})
    v := reflect.ValueOf(u)
    t := reflect.TypeOf(u)
    
    for i := 0; i < v.NumField(); i++ {
        field := t.Field(i)
        fieldValue := v.Field(i)
        
        // 获取 JSON 标签名
        jsonTag := field.Tag.Get("json")
        if jsonTag == "" || jsonTag == "-" {
            continue
        }
        
        // 获取条件标签
        conditionTag := field.Tag.Get("condition")
        if conditionTag == "" {
            // 没有条件标签,直接添加
            m[jsonTag] = fieldValue.Interface()
            continue
        }
        
        // 检查是否满足序列化条件
        if shouldSerialize(fieldValue, conditionTag, u) {
            m[jsonTag] = fieldValue.Interface()
        }
    }
    
    return json.Marshal(m)
}

// 条件检查函数
func shouldSerialize(fieldValue reflect.Value, condition string, obj interface{}) bool {
    switch {
    case condition == "always":
        return true
        
    case condition == "nonzero":
        return !isZeroValue(fieldValue)
        
    case len(condition) > 3 && condition[:3] == "gt:":
        threshold, _ := strconv.Atoi(condition[3:])
        if fieldValue.Kind() == reflect.Int {
            return fieldValue.Int() > int64(threshold)
        }
        return false
        
    case len(condition) > 3 && condition[:3] == "in:":
        allowedValues := condition[3:]
        currentValue := fmt.Sprintf("%v", fieldValue.Interface())
        // 简单的包含检查,实际应用中需要更严谨的实现
        return contains(allowedValues, currentValue)
        
    case len(condition) > 7 && condition[:7] == "custom:":
        methodName := condition[7:]
        objValue := reflect.ValueOf(obj)
        method := objValue.MethodByName(methodName)
        if method.IsValid() {
            // 调用自定义检查方法
            result := method.Call([]reflect.Value{fieldValue})
            if len(result) > 0 {
                return result[0].Bool()
            }
        }
        return false
    }
    
    return true
}

// 辅助函数
func isZeroValue(v reflect.Value) bool {
    switch v.Kind() {
    case reflect.String:
        return v.String() == ""
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
        return v.Int() == 0
    case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
        return v.Uint() == 0
    case reflect.Float32, reflect.Float64:
        return v.Float() == 0
    case reflect.Bool:
        return !v.Bool()
    case reflect.Ptr, reflect.Interface:
        return v.IsNil()
    case reflect.Slice, reflect.Map:
        return v.Len() == 0
    default:
        return false
    }
}

func contains(str string, substr string) bool {
    // 简化实现,实际需要按逗号分割等
    return len(str) > 0 && len(substr) > 0
}

但是该方法侵入性较大,也不太适合。

  • 利用protoc 重新生成一份.go文件,其结构体携带 ,omitempty字段,只是结构体名称不一样。比如:
type User struct {
    Name     string `json:"name"`       
    Age      int    `json:"age"`     
}

变更为

type UserWithOmitempty struct {
    Name     string `json:"name,omitempty"`        // 非空时才序列化
    Age      int    `json:"age,omitempty"`         // 非0时才序列化
}

问题来了,不知道 protoc的命令是什么,导致可以重命名,同时如果这样改动的话 需要序列化反序列化总共三次,也就是将一个结构体的内容 通过 序列化操作赋值到 另外一个结构体【听着是不都头大,多次序列化也很耗性能,此处思路也打开了后续的结果】

  • 利用 protogen 生成自定义插件【一种新姿势】

最终实现

  1. 创建文件夹
mkdir tool
  1. 初始化 mod 文件
go mod init protoc-gen-go-with
  1. mod文件内容【本文以go1.18为版本编译】
module protoc-gen-go-with

go 1.18

require google.golang.org/protobuf v1.26.0
  1. 编写文件内容如下
package main

import (
    "fmt"
    "strconv"
    "strings"

    "google.golang.org/protobuf/compiler/protogen"
    "google.golang.org/protobuf/reflect/protoreflect"
    "google.golang.org/protobuf/types/descriptorpb"
)

func main() {
    protogen.Options{}.Run(func(gen *protogen.Plugin) error {
       for _, f := range gen.Files {
          if !f.Generate {
             continue
          }
          generateFile(gen, f)
       }
       return nil
    })
}

func generateFile(gen *protogen.Plugin, file *protogen.File) {
    filename := file.GeneratedFilenamePrefix + "_omitempty.pb.go"
    g := gen.NewGeneratedFile(filename, file.GoImportPath)

    // 文件头部
    g.P("// Code generated by protoc-gen-go-with. DO NOT EDIT.")
    g.P("// versions:")
    g.P("// - protoc-gen-go-with v1.0.0")
    g.P("// source: ", file.Desc.Path())
    g.P()
    g.P("package ", file.GoPackageName)
    g.P()

    //// 导入必要的包
    //g.P("import (")
    //// 根据需要导入其他依赖包
    //g.P(")")
    g.P()

    // 生成所有消息类型
    for _, message := range file.Messages {
       generateMessage(g, message)
    }
}

func generateMessage(g *protogen.GeneratedFile, message *protogen.Message) {
    // 跳过map entry类型的消息
    if message.Desc.Options().(*descriptorpb.MessageOptions).GetMapEntry() {
       return
    }
    // 特殊处理:跳过 StrInt 消息的处理
    if string(message.GoIdent.GoName) == "StrInt" || string(message.GoIdent.GoName) == "UnionStrInt" {
       return
    }

    // 生成带 WithOmitempty 后缀的结构体
    withTypeName := string(message.GoIdent.GoName) + "WithOmitempty"
    g.P("type ", withTypeName, " struct {")
    for _, field := range message.Fields {
       generateField(g, field)
    }
    g.P("}")
    g.P()

    // 生成转换函数
    generateConvertFunction(g, message)

    // 递归处理嵌套消息
    for _, nested := range message.Messages {
       // 跳过map entry类型的消息
       if !nested.Desc.Options().(*descriptorpb.MessageOptions).GetMapEntry() {
          generateMessage(g, nested)
       }
    }
}

func generateField(g *protogen.GeneratedFile, field *protogen.Field) {
    // 获取字段类型
    fieldType := resolveFieldType(g, field)

    // 获取字段标签
    tags := getFieldTags(field)

    // 写入字段定义
    g.P("\t", field.GoName, " ", fieldType, " ", tags)
}

func resolveFieldType(g *protogen.GeneratedFile, field *protogen.Field) string {
    switch {
    case field.Desc.IsList():
       // 切片类型
       elemType := resolveBaseType(g, field)
       return "[]" + elemType
    case field.Desc.IsMap():
       // 映射类型
       keyField := field.Message.Fields[0]
       valField := field.Message.Fields[1]
       keyType := resolveBaseType(g, keyField)
       valType := resolveBaseType(g, valField)
       return fmt.Sprintf("map[%s]%s", keyType, valType)
    default:
       // 基础类型
       return resolveBaseType(g, field)
    }
}

func resolveBaseType(g *protogen.GeneratedFile, field *protogen.Field) string {
    switch field.Desc.Kind() {
    case protoreflect.BoolKind:
       return "bool"
    case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
       return "int32"
    case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
       return "int64"
    case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
       return "uint32"
    case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
       return "uint64"
    case protoreflect.FloatKind:
       return "float32"
    case protoreflect.DoubleKind:
       return "float64"
    case protoreflect.StringKind:
       return "string"
    case protoreflect.BytesKind:
       return "[]byte"
    case protoreflect.MessageKind:
       // 特殊处理:对于 StrInt 类型,不进行转换,直接使用原始类型
       if field.Message != nil && string(field.Message.GoIdent.GoName) == "StrInt" || string(field.Message.GoIdent.GoName) == "UnionStrInt" {
          return "*" + g.QualifiedGoIdent(field.Message.GoIdent)
       }
       // 检查是否是同一包中的消息
       if field.Message.GoIdent.GoImportPath == field.Parent.GoIdent.GoImportPath {
          // 同一包中的消息,使用带后缀的类型
          return "*" + string(field.Message.GoIdent.GoName) + "WithOmitempty"
       } else {
          // 不同包中的消息,使用原始类型(带包前缀)
          return "*" + g.QualifiedGoIdent(field.Message.GoIdent)
       }
    case protoreflect.EnumKind:
       // 枚举类型
       if field.Enum != nil {
          return g.QualifiedGoIdent(field.Enum.GoIdent)
       }
       return "int32"
    default:
       return "interface{}"
    }
}

func getFieldTags(field *protogen.Field) string {
    var tags []string

    // 保留 protobuf 标签
    protoTags := []string{
       "bytes",
       strconv.FormatInt(int64(field.Desc.Number()), 10),
       getProtoLabel(field),
       "name=" + string(field.Desc.Name()),
    }

    if field.Desc.Kind() != protoreflect.MessageKind && field.Desc.Kind() != protoreflect.GroupKind {
       protoTags = append(protoTags, "proto3")
    }

    if jsonName := field.Desc.JSONName(); jsonName != "" && jsonName != string(field.Desc.Name()) {
       protoTags = append(protoTags, "json="+jsonName)
    }

    tags = append(tags, fmt.Sprintf("protobuf:%s", strconv.Quote(strings.Join(protoTags, ","))))

    // JSON标签中使用下划线命名格式
    if jsonName := field.Desc.JSONName(); jsonName != "" {
       snakeName := camelToSnake(jsonName)
       tags = append(tags, fmt.Sprintf("json:%s", strconv.Quote(snakeName+",omitempty")))
    }

    return "`" + strings.Join(tags, " ") + "`"
}

// 驼峰命名转下划线命名
func camelToSnake(name string) string {
    var result strings.Builder
    for i, r := range name {
       if i > 0 && r >= 'A' && r <= 'Z' {
          result.WriteRune('_')
       }
       result.WriteRune(r)
    }
    return strings.ToLower(result.String())
}

func getProtoLabel(field *protogen.Field) string {
    if field.Desc.IsList() {
       return "rep"
    }
    if field.Desc.HasOptionalKeyword() {
       return "opt"
    }
    return "req"
}

func generateConvertFunction(g *protogen.GeneratedFile, message *protogen.Message) {
    originalType := string(message.GoIdent.GoName)
    withType := originalType + "WithOmitempty"

    // 函数签名
    g.P("func To", withType, "(src *", originalType, ") *", withType, " {")
    g.P("\tif src == nil {")
    g.P("\t\treturn nil")
    g.P("\t}")
    g.P()

    // 创建新实例
    g.P("\tresult := &", withType, "{}")

    // 添加标志变量,用于跟踪是否有字段被赋值
    g.P("\tassigned := false")

    // 复制字段
    for _, field := range message.Fields {
       copyFieldWithEmptyCheck(g, field)
    }

    // 如果没有任何字段被赋值,则返回nil
    g.P("\tif !assigned {")
    g.P("\t\treturn nil")
    g.P("\t}")

    g.P("\treturn result")
    g.P("}")
    g.P()
}

func copyFieldWithEmptyCheck(g *protogen.GeneratedFile, field *protogen.Field) {
    fieldName := field.GoName

    switch {
    case field.Desc.IsList():
       // 切片字段
       if field.Message != nil {
          // 特殊处理:对于 StrInt 类型,直接复制而不转换
          if string(field.Message.GoIdent.GoName) == "StrInt" || string(field.Message.GoIdent.GoName) == "UnionStrInt" {
             g.P("\tif len(src.", fieldName, ") > 0 {")
             g.P("\t\tresult.", fieldName, " = src.", fieldName)
             g.P("\t\tassigned = true")
             g.P("\t}")
          } else if field.Message.GoIdent.GoImportPath == field.Parent.GoIdent.GoImportPath {
             // 消息切片 - 需要转换每个元素
             g.P("\tif len(src.", fieldName, ") > 0 {")
             g.P("\t\tresult.", fieldName, " = make([]*", string(field.Message.GoIdent.GoName), "WithOmitempty, len(src.", fieldName, "))")
             g.P("\t\tfor i, item := range src.", fieldName, " {")
             g.P("\t\t\tresult.", fieldName, "[i] = To", string(field.Message.GoIdent.GoName), "WithOmitempty(item)")
             g.P("\t\t}")
             g.P("\t\tassigned = true")
             g.P("\t}")
          } else {
             // 不同包中的消息,直接复制切片
             g.P("\tif len(src.", fieldName, ") > 0 {")
             g.P("\t\tresult.", fieldName, " = src.", fieldName)
             g.P("\t\tassigned = true")
             g.P("\t}")
          }
       } else {
          // 基础类型切片 - 直接复制
          g.P("\tif len(src.", fieldName, ") > 0 {")
          g.P("\t\tresult.", fieldName, " = append([]", resolveBaseType(g, field), "(nil), src.", fieldName, "...)")
          g.P("\t\tassigned = true")
          g.P("\t}")
       }
    case field.Desc.IsMap():
       // 映射字段
       g.P("\tif len(src.", fieldName, ") > 0 {")
       keyType := resolveBaseType(g, field.Message.Fields[0])
       valType := resolveBaseType(g, field.Message.Fields[1])
       g.P("\t\tresult.", fieldName, " = make(map[", keyType, "]", valType, ")")
       g.P("\t\tfor k, v := range src.", fieldName, " {")
       g.P("\t\t\tif v != "" {")
       g.P("\t\t\tresult.", fieldName, "[k] = v")
       g.P("\t\t\t}")
       g.P("\t\t}")
       g.P("\t\tassigned = true")
       g.P("\t}")
    default:
       // 单值字段
       if field.Message != nil {
          if string(field.Message.GoIdent.GoName) == "StrInt" || string(field.Message.GoIdent.GoName) == "UnionStrInt" {
             // 对于 StrInt 类型,检查是否为 nil
             g.P("\tif src.", fieldName, " != nil {")
             g.P("\t\tresult.", fieldName, " = src.", fieldName)
             g.P("\t\tassigned = true")
             g.P("\t}")
          } else if field.Message.GoIdent.GoImportPath == field.Parent.GoIdent.GoImportPath {
             // 同一包中的消息,需要转换
             g.P("\tif src.", fieldName, " != nil {")
             g.P("\t\tconverted := To", string(field.Message.GoIdent.GoName), "WithOmitempty(src.", fieldName, ")")
             g.P("\t\tif converted != nil {")
             g.P("\t\t\tresult.", fieldName, " = converted")
             g.P("\t\t\tassigned = true")
             g.P("\t\t}")
             g.P("\t}")
          } else {
             // 不同包中的消息,直接赋值
             g.P("\tif src.", fieldName, " != nil {")
             g.P("\t\tresult.", fieldName, " = src.", fieldName)
             g.P("\t\tassigned = true")
             g.P("\t}")
          }
       } else {
          // 基础类型 - 根据类型进行空值检查
          switch field.Desc.Kind() {
          case protoreflect.StringKind:
             g.P("\tif src.", fieldName, " != "" {")
             g.P("\t\tresult.", fieldName, " = src.", fieldName)
             g.P("\t\tassigned = true")
             g.P("\t}")
          case protoreflect.BoolKind:
             // 布尔类型默认为false,所以总是赋值
             g.P("\tresult.", fieldName, " = src.", fieldName)
             g.P("\tassigned = true")
          case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind,
             protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind,
             protoreflect.Uint32Kind, protoreflect.Fixed32Kind,
             protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
             g.P("\tif src.", fieldName, " != 0 {")
             g.P("\t\tresult.", fieldName, " = src.", fieldName)
             g.P("\t\tassigned = true")
             g.P("\t}")
          case protoreflect.FloatKind, protoreflect.DoubleKind:
             g.P("\tif src.", fieldName, " != 0.0 {")
             g.P("\t\tresult.", fieldName, " = src.", fieldName)
             g.P("\t\tassigned = true")
             g.P("\t}")
          default:
             // 其他类型默认总是赋值
             g.P("\tresult.", fieldName, " = src.", fieldName)
             g.P("\tassigned = true")
          }
       }
    }
}

func copyField(g *protogen.GeneratedFile, field *protogen.Field) {
    fieldName := field.GoName

    switch {
    case field.Desc.IsList():
       // 切片字段
       if field.Message != nil {
          // 特殊处理:对于 StrInt 类型,直接复制而不转换
          if string(field.Message.GoIdent.GoName) == "StrInt" || string(field.Message.GoIdent.GoName) == "UnionStrInt" {
             g.P("\tresult.", fieldName, " = src.", fieldName)
          } else if field.Message.GoIdent.GoImportPath == field.Parent.GoIdent.GoImportPath {
             // 消息切片 - 需要转换每个元素
             g.P("\tif src.", fieldName, " != nil {")
             g.P("\t\tresult.", fieldName, " = make([]*", string(field.Message.GoIdent.GoName), "WithOmitempty, len(src.", fieldName, "))")
             g.P("\t\tfor i, item := range src.", fieldName, " {")
             g.P("\t\t\tresult.", fieldName, "[i] = To", string(field.Message.GoIdent.GoName), "WithOmitempty(item)")
             g.P("\t\t}")
             g.P("\t}")
          } else {
             // 不同包中的消息,直接复制切片
             g.P("\tresult.", fieldName, " = src.", fieldName)
          }
       } else {
          // 基础类型切片 - 直接复制
          g.P("\tresult.", fieldName, " = append([]", resolveBaseType(g, field), "(nil), src.", fieldName, "...)")
       }
    case field.Desc.IsMap():
       // 基础类型映射 - 直接复制
       keyType := resolveBaseType(g, field.Message.Fields[0])
       valType := resolveBaseType(g, field.Message.Fields[1])
       g.P("\tif src.", fieldName, " != nil {")
       g.P("\t\tresult.", fieldName, " = make(map[", keyType, "]", valType, ")")
       g.P("\t\tfor k, v := range src.", fieldName, " {")
       g.P("\t\t\tresult.", fieldName, "[k] = v")
       g.P("\t\t}")
       g.P("\t}")
    default:
       // 单值字段
       if field.Message != nil {
          if string(field.Message.GoIdent.GoName) == "StrInt" || string(field.Message.GoIdent.GoName) == "UnionStrInt" {
             g.P("\tresult.", fieldName, " = src.", fieldName)
          } else if field.Message.GoIdent.GoImportPath == field.Parent.GoIdent.GoImportPath {
             // 同一包中的消息,需要转换
             g.P("\tif src.", fieldName, " != nil {")
             g.P("\t\tresult.", fieldName, " = To", string(field.Message.GoIdent.GoName), "WithOmitempty(src.", fieldName, ")")
             g.P("\t}")
          } else {
             // 不同包中的消息,直接赋值
             g.P("\tresult.", fieldName, " = src.", fieldName)
          }
       } else {
          // 基础类型 - 直接赋值
          g.P("\tresult.", fieldName, " = src.", fieldName)
       }
    }
}
  1. 执行编译语句
go build -o protoc-gen-go-with 
  1. 移动二进制文件到 GOPATH
 mv protoc-gen-go-with $GOPATH/bin/
  1. 执行protoc命令
protoc -I=/project/test -I=. --go-with_out=. \
  --go-with_opt=paths=source_relative \
  test.proto

其中 -I 为引入依赖项,本 case 中为 /project/test

image.png

如果生成的有问题,可以根据代码报错问问AI。【本文仅提供思路用于学习】