go-zero 中使用 gorm 时丢失链路追踪信息的解决方案

61 阅读1分钟

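思路:gorm 的数据库调用默认不会出现在链路追踪里,因此用下面的 WithTracing 把每次数据库调用包在一个 span 中,并通过 WithContext 把带有该 span 的 ctx 传给 gorm。直接上代码: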

package common

import (
    "context"
    "database/sql"
    "errors"

    ztrace "github.com/zeromicro/go-zero/core/trace"
    "go.opentelemetry.io/otel"
    "go.opentelemetry.io/otel/attribute"
    "go.opentelemetry.io/otel/baggage"
    "go.opentelemetry.io/otel/codes"
    "go.opentelemetry.io/otel/trace"
    oteltrace "go.opentelemetry.io/otel/trace"
    "google.golang.org/grpc/metadata"
)

// Option customizes a span immediately after StartSpan creates it,
// for example to attach attributes.
type Option func(span oteltrace.Span)

// StartSpan starts a span named name (via startSpan), applies every opt to the
// new span in the order given, and returns the derived context together with
// the span.
//
// The span is not ended here; the caller is responsible for ending it
// (for example with EndSpan).
func StartSpan(ctx context.Context, name string, opts ...Option) (context.Context, oteltrace.Span) {
	spanCtx, span := startSpan(ctx, name)
	for i := range opts {
		opts[i](span)
	}
	return spanCtx, span
}

// startSpan starts a client-kind span named method for an outbound operation
// (SQL query, HTTP call, ...) and returns the context carrying it.
//
// Parent selection:
//   - If ctx already holds a valid span context (for example the span installed by a
//     tracing interceptor/handler for the current request), the new span is its child.
//   - Otherwise the trace context and baggage are extracted from the incoming gRPC
//     metadata, if any, so the span still continues the caller's trace.
//
// Why not always extract from incoming metadata: inside a traced gRPC handler that
// metadata still carries the *client's* traceparent, so forcing it onto ctx as the
// remote parent would make this span a sibling of the server span instead of its child.
// The kind is Client (not Server) because the traced operation is an outbound call.
func startSpan(ctx context.Context, method string) (context.Context, trace.Span) {
	tr := otel.Tracer(ztrace.TraceName)
	opts := []trace.SpanStartOption{trace.WithSpanKind(trace.SpanKindClient)}

	// Common case: nest under the span that is already active in ctx.
	if trace.SpanContextFromContext(ctx).IsValid() {
		return tr.Start(ctx, method, opts...)
	}

	// Fallback: no local span, so try to continue a trace propagated via incoming gRPC metadata.
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		md = metadata.MD{}
	}
	bags, spanCtx := ztrace.Extract(ctx, otel.GetTextMapPropagator(), &md)
	ctx = baggage.ContextWithBaggage(ctx, bags)
	if spanCtx.IsValid() {
		ctx = trace.ContextWithRemoteSpanContext(ctx, spanCtx)
	}

	return tr.Start(ctx, method, opts...)
}

// EndSpan records the outcome of the operation traced by span and then ends it.
//
// A nil err, or an error matching sql.ErrNoRows (checked with errors.Is), is
// reported as codes.Ok: "no rows" is treated as a normal result, not a failure.
// Any other error sets codes.Error with err's message and is attached to the
// span with RecordError. The span is ended in every case.
//
// NOTE(review): gorm's First/Take/Last report "not found" as gorm.ErrRecordNotFound,
// which presumably does not match sql.ErrNoRows — confirm whether that should also
// be treated as success.
func EndSpan(span oteltrace.Span, err error) {
	defer span.End()

	switch {
	case err == nil, errors.Is(err, sql.ErrNoRows):
		span.SetStatus(codes.Ok, "")
	default:
		span.SetStatus(codes.Error, err.Error())
		span.RecordError(err)
	}
}

// WithTracing 通用追踪包装函数,用于不同的数据库操作
func WithTracing(ctx context.Context, operationName string, dbOp func(ctx context.Context) error) error {
    c, span := StartSpan(ctx, "sql", func(r oteltrace.Span) {
       r.SetAttributes(attribute.Key("sql.method").String(operationName))
    })
    var err error
    defer func() {
       EndSpan(span, err)
    }()

    // 执行数据库操作并记录错误
    err = dbOp(c)

    return err
}

func WithHttpsTracing(ctx context.Context, operationName string, httpOp func(ctx context.Context) error) error {
    c, span := StartSpan(ctx, "http", func(r oteltrace.Span) {
       r.SetAttributes(
          attribute.Key("http.operation").String(operationName),
       )
    })
    var err error
    defer func() {
       defer span.End()

       if err == nil {
          span.SetStatus(codes.Ok, "")
          return
       }

       span.SetStatus(codes.Error, err.Error())
       span.RecordError(err)
    }()

    // 执行HTTP操作并记录错误
    err = httpOp(c)

    return err
}

// BranchHeaders is a plain string-to-string header map whose Get, Set and Keys
// methods match the OpenTelemetry propagation.TextMapCarrier interface, so it can
// presumably be handed to a propagator to inject/extract trace context.
//
// Writing to a nil BranchHeaders panics; obtain a usable value from
// NewBranchHeaders (or make).
type BranchHeaders map[string]string

// NewBranchHeaders returns an empty, non-nil BranchHeaders that is ready for Set.
func NewBranchHeaders() BranchHeaders {
	return BranchHeaders{}
}

// Get implements the TextMapCarrier Get method: it returns the value stored
// under key, or "" when key is absent.
func (h BranchHeaders) Get(key string) string {
	if v, ok := h[key]; ok {
		return v
	}
	return ""
}

// Set implements the TextMapCarrier Set method: it stores value under key,
// replacing any previous value.
func (h BranchHeaders) Set(key, value string) {
	h[key] = value
}

// Keys implements the TextMapCarrier Keys method: it returns every key in the
// map, in Go's unspecified map iteration order.
func (h BranchHeaders) Keys() []string {
	out := make([]string, len(h))
	i := 0
	for k := range h {
		out[i] = k
		i++
	}
	return out
}

使用示例:

// GetGroupByIds loads the UserGroup rows whose id is in ids and converts them
// with convertGroupByIds2Rpc.
//
// The query runs inside common.WithTracing under the operation name
// "GetGroupByIds", and the traced context is passed to gorm through
// WithContext so the query executes under that span.
func (d *DAO) GetGroupByIds(ctx context.Context, ids []int64) ([]*rpc.Group, error) {
	var rows []*UserGroup
	query := func(tracedCtx context.Context) error {
		return d.DB.WithContext(tracedCtx).
			Model(&UserGroup{}).
			Where("id IN (?)", ids).
			Find(&rows).Error
	}

	if err := common.WithTracing(ctx, "GetGroupByIds", query); err != nil {
		return nil, err
	}
	return convertGroupByIds2Rpc(rows), nil
}