背景
在与SDK交互过程中,v1.0.0 版本手动处理了 proto生成的 .go 文件,导致去除了 ,omitempty 字段。从而导致全部字段均会下发,浪费网络带宽。考虑到兼容性问题,想在 v2.0.0 版本进行升级。即根据版本可分别下发【全部字段】和【仅下发有效值】。
思考
基于该背景,于是考虑根据版本分别下发不一样的数据。
- 基于版本用反射去重新生成标签【考虑 reflect 的性能问题,放弃】
- 基于自定义标签控制
JSON序列化,比如基础版代码【更多可以找AI编写】
package main
import (
"encoding/json"
"fmt"
"reflect"
"strconv"
)
// 自定义序列化条件标签
type User struct {
Name string `json:"name" condition:"nonzero"` // 非空时才序列化
Age int `json:"age" condition:"nonzero"` // 非0时才序列化
Email string `json:"email" condition:"always"` // 总是序列化
Score int `json:"score" condition:"gt:0"` // 大于0时才序列化
Status string `json:"status" condition:"in:active,pending"` // 在指定值中时才序列化
Optional string `json:"optional" condition:"custom:CheckOptional"` // 自定义检查函数
}
// 自定义检查函数
func (u User) CheckOptional(fieldValue string) bool {
return fieldValue != "" && fieldValue != "null"
}
// 实现自定义 MarshalJSON
func (u User) MarshalJSON() ([]byte, error) {
m := make(map[string]interface{})
v := reflect.ValueOf(u)
t := reflect.TypeOf(u)
for i := 0; i < v.NumField(); i++ {
field := t.Field(i)
fieldValue := v.Field(i)
// 获取 JSON 标签名
jsonTag := field.Tag.Get("json")
if jsonTag == "" || jsonTag == "-" {
continue
}
// 获取条件标签
conditionTag := field.Tag.Get("condition")
if conditionTag == "" {
// 没有条件标签,直接添加
m[jsonTag] = fieldValue.Interface()
continue
}
// 检查是否满足序列化条件
if shouldSerialize(fieldValue, conditionTag, u) {
m[jsonTag] = fieldValue.Interface()
}
}
return json.Marshal(m)
}
// 条件检查函数
func shouldSerialize(fieldValue reflect.Value, condition string, obj interface{}) bool {
switch {
case condition == "always":
return true
case condition == "nonzero":
return !isZeroValue(fieldValue)
case len(condition) > 3 && condition[:3] == "gt:":
threshold, _ := strconv.Atoi(condition[3:])
if fieldValue.Kind() == reflect.Int {
return fieldValue.Int() > int64(threshold)
}
return false
case len(condition) > 3 && condition[:3] == "in:":
allowedValues := condition[3:]
currentValue := fmt.Sprintf("%v", fieldValue.Interface())
// 简单的包含检查,实际应用中需要更严谨的实现
return contains(allowedValues, currentValue)
case len(condition) > 7 && condition[:7] == "custom:":
methodName := condition[7:]
objValue := reflect.ValueOf(obj)
method := objValue.MethodByName(methodName)
if method.IsValid() {
// 调用自定义检查方法
result := method.Call([]reflect.Value{fieldValue})
if len(result) > 0 {
return result[0].Bool()
}
}
return false
}
return true
}
// 辅助函数
func isZeroValue(v reflect.Value) bool {
switch v.Kind() {
case reflect.String:
return v.String() == ""
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Bool:
return !v.Bool()
case reflect.Ptr, reflect.Interface:
return v.IsNil()
case reflect.Slice, reflect.Map:
return v.Len() == 0
default:
return false
}
}
func contains(str string, substr string) bool {
// 简化实现,实际需要按逗号分割等
return len(str) > 0 && len(substr) > 0
}
但是该方法侵入性较大,也不太适合。
- 利用
protoc重新生成一份.go文件,其结构体携带,omitempty字段,只是结构体名称不一样。比如:
type User struct {
Name string `json:"name"`
Age int `json:"age"`
}
变更为
type UserWithOmitempty struct {
Name string `json:"name,omitempty"` // 非空时才序列化
Age int `json:"age,omitempty"` // 非0时才序列化
}
问题来了,不知道 protoc的命令是什么,导致可以重命名,同时如果这样改动的话 需要序列化反序列化总共三次,也就是将一个结构体的内容 通过 序列化操作赋值到 另外一个结构体【听着是不都头大,多次序列化也很耗性能,此处思路也打开了后续的结果】
- 利用
protogen生成自定义插件【一种新姿势】
最终实现
- 创建文件夹
mkdir tool
- 初始化
mod文件
go mod init protoc-gen-go-with
- mod文件内容【本文以go1.18为版本编译】
module protoc-gen-go-with
go 1.18
require google.golang.org/protobuf v1.26.0
- 编写文件内容如下
package main
import (
"fmt"
"strconv"
"strings"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
)
func main() {
protogen.Options{}.Run(func(gen *protogen.Plugin) error {
for _, f := range gen.Files {
if !f.Generate {
continue
}
generateFile(gen, f)
}
return nil
})
}
func generateFile(gen *protogen.Plugin, file *protogen.File) {
filename := file.GeneratedFilenamePrefix + "_omitempty.pb.go"
g := gen.NewGeneratedFile(filename, file.GoImportPath)
// 文件头部
g.P("// Code generated by protoc-gen-go-with. DO NOT EDIT.")
g.P("// versions:")
g.P("// - protoc-gen-go-with v1.0.0")
g.P("// source: ", file.Desc.Path())
g.P()
g.P("package ", file.GoPackageName)
g.P()
//// 导入必要的包
//g.P("import (")
//// 根据需要导入其他依赖包
//g.P(")")
g.P()
// 生成所有消息类型
for _, message := range file.Messages {
generateMessage(g, message)
}
}
func generateMessage(g *protogen.GeneratedFile, message *protogen.Message) {
// 跳过map entry类型的消息
if message.Desc.Options().(*descriptorpb.MessageOptions).GetMapEntry() {
return
}
// 特殊处理:跳过 StrInt 消息的处理
if string(message.GoIdent.GoName) == "StrInt" || string(message.GoIdent.GoName) == "UnionStrInt" {
return
}
// 生成带 WithOmitempty 后缀的结构体
withTypeName := string(message.GoIdent.GoName) + "WithOmitempty"
g.P("type ", withTypeName, " struct {")
for _, field := range message.Fields {
generateField(g, field)
}
g.P("}")
g.P()
// 生成转换函数
generateConvertFunction(g, message)
// 递归处理嵌套消息
for _, nested := range message.Messages {
// 跳过map entry类型的消息
if !nested.Desc.Options().(*descriptorpb.MessageOptions).GetMapEntry() {
generateMessage(g, nested)
}
}
}
func generateField(g *protogen.GeneratedFile, field *protogen.Field) {
// 获取字段类型
fieldType := resolveFieldType(g, field)
// 获取字段标签
tags := getFieldTags(field)
// 写入字段定义
g.P("\t", field.GoName, " ", fieldType, " ", tags)
}
func resolveFieldType(g *protogen.GeneratedFile, field *protogen.Field) string {
switch {
case field.Desc.IsList():
// 切片类型
elemType := resolveBaseType(g, field)
return "[]" + elemType
case field.Desc.IsMap():
// 映射类型
keyField := field.Message.Fields[0]
valField := field.Message.Fields[1]
keyType := resolveBaseType(g, keyField)
valType := resolveBaseType(g, valField)
return fmt.Sprintf("map[%s]%s", keyType, valType)
default:
// 基础类型
return resolveBaseType(g, field)
}
}
func resolveBaseType(g *protogen.GeneratedFile, field *protogen.Field) string {
switch field.Desc.Kind() {
case protoreflect.BoolKind:
return "bool"
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
return "int32"
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
return "int64"
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
return "uint32"
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
return "uint64"
case protoreflect.FloatKind:
return "float32"
case protoreflect.DoubleKind:
return "float64"
case protoreflect.StringKind:
return "string"
case protoreflect.BytesKind:
return "[]byte"
case protoreflect.MessageKind:
// 特殊处理:对于 StrInt 类型,不进行转换,直接使用原始类型
if field.Message != nil && string(field.Message.GoIdent.GoName) == "StrInt" || string(field.Message.GoIdent.GoName) == "UnionStrInt" {
return "*" + g.QualifiedGoIdent(field.Message.GoIdent)
}
// 检查是否是同一包中的消息
if field.Message.GoIdent.GoImportPath == field.Parent.GoIdent.GoImportPath {
// 同一包中的消息,使用带后缀的类型
return "*" + string(field.Message.GoIdent.GoName) + "WithOmitempty"
} else {
// 不同包中的消息,使用原始类型(带包前缀)
return "*" + g.QualifiedGoIdent(field.Message.GoIdent)
}
case protoreflect.EnumKind:
// 枚举类型
if field.Enum != nil {
return g.QualifiedGoIdent(field.Enum.GoIdent)
}
return "int32"
default:
return "interface{}"
}
}
func getFieldTags(field *protogen.Field) string {
var tags []string
// 保留 protobuf 标签
protoTags := []string{
"bytes",
strconv.FormatInt(int64(field.Desc.Number()), 10),
getProtoLabel(field),
"name=" + string(field.Desc.Name()),
}
if field.Desc.Kind() != protoreflect.MessageKind && field.Desc.Kind() != protoreflect.GroupKind {
protoTags = append(protoTags, "proto3")
}
if jsonName := field.Desc.JSONName(); jsonName != "" && jsonName != string(field.Desc.Name()) {
protoTags = append(protoTags, "json="+jsonName)
}
tags = append(tags, fmt.Sprintf("protobuf:%s", strconv.Quote(strings.Join(protoTags, ","))))
// JSON标签中使用下划线命名格式
if jsonName := field.Desc.JSONName(); jsonName != "" {
snakeName := camelToSnake(jsonName)
tags = append(tags, fmt.Sprintf("json:%s", strconv.Quote(snakeName+",omitempty")))
}
return "`" + strings.Join(tags, " ") + "`"
}
// 驼峰命名转下划线命名
func camelToSnake(name string) string {
var result strings.Builder
for i, r := range name {
if i > 0 && r >= 'A' && r <= 'Z' {
result.WriteRune('_')
}
result.WriteRune(r)
}
return strings.ToLower(result.String())
}
func getProtoLabel(field *protogen.Field) string {
if field.Desc.IsList() {
return "rep"
}
if field.Desc.HasOptionalKeyword() {
return "opt"
}
return "req"
}
func generateConvertFunction(g *protogen.GeneratedFile, message *protogen.Message) {
originalType := string(message.GoIdent.GoName)
withType := originalType + "WithOmitempty"
// 函数签名
g.P("func To", withType, "(src *", originalType, ") *", withType, " {")
g.P("\tif src == nil {")
g.P("\t\treturn nil")
g.P("\t}")
g.P()
// 创建新实例
g.P("\tresult := &", withType, "{}")
// 添加标志变量,用于跟踪是否有字段被赋值
g.P("\tassigned := false")
// 复制字段
for _, field := range message.Fields {
copyFieldWithEmptyCheck(g, field)
}
// 如果没有任何字段被赋值,则返回nil
g.P("\tif !assigned {")
g.P("\t\treturn nil")
g.P("\t}")
g.P("\treturn result")
g.P("}")
g.P()
}
func copyFieldWithEmptyCheck(g *protogen.GeneratedFile, field *protogen.Field) {
fieldName := field.GoName
switch {
case field.Desc.IsList():
// 切片字段
if field.Message != nil {
// 特殊处理:对于 StrInt 类型,直接复制而不转换
if string(field.Message.GoIdent.GoName) == "StrInt" || string(field.Message.GoIdent.GoName) == "UnionStrInt" {
g.P("\tif len(src.", fieldName, ") > 0 {")
g.P("\t\tresult.", fieldName, " = src.", fieldName)
g.P("\t\tassigned = true")
g.P("\t}")
} else if field.Message.GoIdent.GoImportPath == field.Parent.GoIdent.GoImportPath {
// 消息切片 - 需要转换每个元素
g.P("\tif len(src.", fieldName, ") > 0 {")
g.P("\t\tresult.", fieldName, " = make([]*", string(field.Message.GoIdent.GoName), "WithOmitempty, len(src.", fieldName, "))")
g.P("\t\tfor i, item := range src.", fieldName, " {")
g.P("\t\t\tresult.", fieldName, "[i] = To", string(field.Message.GoIdent.GoName), "WithOmitempty(item)")
g.P("\t\t}")
g.P("\t\tassigned = true")
g.P("\t}")
} else {
// 不同包中的消息,直接复制切片
g.P("\tif len(src.", fieldName, ") > 0 {")
g.P("\t\tresult.", fieldName, " = src.", fieldName)
g.P("\t\tassigned = true")
g.P("\t}")
}
} else {
// 基础类型切片 - 直接复制
g.P("\tif len(src.", fieldName, ") > 0 {")
g.P("\t\tresult.", fieldName, " = append([]", resolveBaseType(g, field), "(nil), src.", fieldName, "...)")
g.P("\t\tassigned = true")
g.P("\t}")
}
case field.Desc.IsMap():
// 映射字段
g.P("\tif len(src.", fieldName, ") > 0 {")
keyType := resolveBaseType(g, field.Message.Fields[0])
valType := resolveBaseType(g, field.Message.Fields[1])
g.P("\t\tresult.", fieldName, " = make(map[", keyType, "]", valType, ")")
g.P("\t\tfor k, v := range src.", fieldName, " {")
g.P("\t\t\tif v != "" {")
g.P("\t\t\tresult.", fieldName, "[k] = v")
g.P("\t\t\t}")
g.P("\t\t}")
g.P("\t\tassigned = true")
g.P("\t}")
default:
// 单值字段
if field.Message != nil {
if string(field.Message.GoIdent.GoName) == "StrInt" || string(field.Message.GoIdent.GoName) == "UnionStrInt" {
// 对于 StrInt 类型,检查是否为 nil
g.P("\tif src.", fieldName, " != nil {")
g.P("\t\tresult.", fieldName, " = src.", fieldName)
g.P("\t\tassigned = true")
g.P("\t}")
} else if field.Message.GoIdent.GoImportPath == field.Parent.GoIdent.GoImportPath {
// 同一包中的消息,需要转换
g.P("\tif src.", fieldName, " != nil {")
g.P("\t\tconverted := To", string(field.Message.GoIdent.GoName), "WithOmitempty(src.", fieldName, ")")
g.P("\t\tif converted != nil {")
g.P("\t\t\tresult.", fieldName, " = converted")
g.P("\t\t\tassigned = true")
g.P("\t\t}")
g.P("\t}")
} else {
// 不同包中的消息,直接赋值
g.P("\tif src.", fieldName, " != nil {")
g.P("\t\tresult.", fieldName, " = src.", fieldName)
g.P("\t\tassigned = true")
g.P("\t}")
}
} else {
// 基础类型 - 根据类型进行空值检查
switch field.Desc.Kind() {
case protoreflect.StringKind:
g.P("\tif src.", fieldName, " != "" {")
g.P("\t\tresult.", fieldName, " = src.", fieldName)
g.P("\t\tassigned = true")
g.P("\t}")
case protoreflect.BoolKind:
// 布尔类型默认为false,所以总是赋值
g.P("\tresult.", fieldName, " = src.", fieldName)
g.P("\tassigned = true")
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind,
protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind,
protoreflect.Uint32Kind, protoreflect.Fixed32Kind,
protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
g.P("\tif src.", fieldName, " != 0 {")
g.P("\t\tresult.", fieldName, " = src.", fieldName)
g.P("\t\tassigned = true")
g.P("\t}")
case protoreflect.FloatKind, protoreflect.DoubleKind:
g.P("\tif src.", fieldName, " != 0.0 {")
g.P("\t\tresult.", fieldName, " = src.", fieldName)
g.P("\t\tassigned = true")
g.P("\t}")
default:
// 其他类型默认总是赋值
g.P("\tresult.", fieldName, " = src.", fieldName)
g.P("\tassigned = true")
}
}
}
}
func copyField(g *protogen.GeneratedFile, field *protogen.Field) {
fieldName := field.GoName
switch {
case field.Desc.IsList():
// 切片字段
if field.Message != nil {
// 特殊处理:对于 StrInt 类型,直接复制而不转换
if string(field.Message.GoIdent.GoName) == "StrInt" || string(field.Message.GoIdent.GoName) == "UnionStrInt" {
g.P("\tresult.", fieldName, " = src.", fieldName)
} else if field.Message.GoIdent.GoImportPath == field.Parent.GoIdent.GoImportPath {
// 消息切片 - 需要转换每个元素
g.P("\tif src.", fieldName, " != nil {")
g.P("\t\tresult.", fieldName, " = make([]*", string(field.Message.GoIdent.GoName), "WithOmitempty, len(src.", fieldName, "))")
g.P("\t\tfor i, item := range src.", fieldName, " {")
g.P("\t\t\tresult.", fieldName, "[i] = To", string(field.Message.GoIdent.GoName), "WithOmitempty(item)")
g.P("\t\t}")
g.P("\t}")
} else {
// 不同包中的消息,直接复制切片
g.P("\tresult.", fieldName, " = src.", fieldName)
}
} else {
// 基础类型切片 - 直接复制
g.P("\tresult.", fieldName, " = append([]", resolveBaseType(g, field), "(nil), src.", fieldName, "...)")
}
case field.Desc.IsMap():
// 基础类型映射 - 直接复制
keyType := resolveBaseType(g, field.Message.Fields[0])
valType := resolveBaseType(g, field.Message.Fields[1])
g.P("\tif src.", fieldName, " != nil {")
g.P("\t\tresult.", fieldName, " = make(map[", keyType, "]", valType, ")")
g.P("\t\tfor k, v := range src.", fieldName, " {")
g.P("\t\t\tresult.", fieldName, "[k] = v")
g.P("\t\t}")
g.P("\t}")
default:
// 单值字段
if field.Message != nil {
if string(field.Message.GoIdent.GoName) == "StrInt" || string(field.Message.GoIdent.GoName) == "UnionStrInt" {
g.P("\tresult.", fieldName, " = src.", fieldName)
} else if field.Message.GoIdent.GoImportPath == field.Parent.GoIdent.GoImportPath {
// 同一包中的消息,需要转换
g.P("\tif src.", fieldName, " != nil {")
g.P("\t\tresult.", fieldName, " = To", string(field.Message.GoIdent.GoName), "WithOmitempty(src.", fieldName, ")")
g.P("\t}")
} else {
// 不同包中的消息,直接赋值
g.P("\tresult.", fieldName, " = src.", fieldName)
}
} else {
// 基础类型 - 直接赋值
g.P("\tresult.", fieldName, " = src.", fieldName)
}
}
}
- 执行编译语句
go build -o protoc-gen-go-with
- 移动二进制文件到
GOPATH下
mv protoc-gen-go-with $GOPATH/bin/
- 执行
protoc命令
protoc -I=/project/test -I=. --go-with_out=. \
--go-with_opt=paths=source_relative \
test.proto
其中 -I 为引入依赖项,本 case 中为 /project/test
如果生成的有问题,可以根据代码报错问问AI。【本文仅提供思路用于学习】