How does the A2A protocol implement an authentication plugin? Some odd ideas


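A2A protocol overview

A2A (Agent-to-Agent) is an open communication protocol that lets AI agents built with different frameworks and by different vendors interoperate. According to the protocol documentation, its core components are:

  • AgentCard: a metadata file describing an agent's capabilities
  • A2A Server: an HTTP endpoint that implements the A2A protocol methods
  • A2A Client: an application, or another agent, that consumes A2A services
  • Task: the central unit of work, containing messages and status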
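To make the discovery step concrete, here is a minimal sketch of a client reading an AgentCard. It assumes the card is served at /.well-known/agent.json (the path used by the server code later in this post); the `agentCard` struct, `fetchAgentCard`, and `newDemoServer` are hypothetical names for illustration, and the demo server only publishes a hard-coded card.

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// agentCard holds only the fields this sketch reads; a real card carries more.
type agentCard struct {
	Name string `json:"name"`
	URL  string `json:"url"`
}

// fetchAgentCard downloads and decodes the card published under baseURL.
func fetchAgentCard(baseURL string) (*agentCard, error) {
	resp, err := http.Get(baseURL + "/.well-known/agent.json")
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status: %s", resp.Status)
	}
	var card agentCard
	if err := json.NewDecoder(resp.Body).Decode(&card); err != nil {
		return nil, err
	}
	return &card, nil
}

// newDemoServer starts a throwaway local server that publishes a hard-coded card.
func newDemoServer() *httptest.Server {
	mux := http.NewServeMux()
	mux.HandleFunc("/.well-known/agent.json", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprint(w, `{"name":"demo-agent","url":"http://localhost/a2a"}`)
	})
	return httptest.NewServer(mux)
}

func main() {
	srv := newDemoServer()
	defer srv.Close()

	card, err := fetchAgentCard(srv.URL)
	if err != nil {
		fmt.Println("fetch failed:", err)
		return
	}
	fmt.Println(card.Name, card.URL)
}
```

A client would normally read the card first and only then decide which authentication scheme to use for the actual task calls.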

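Implementing an authentication plugin based on A2A authentication

The A2A protocol includes an authentication system, implemented mainly through the AgentAuthentication and AuthenticationInfo classes defined in the common types. These classes define the structure of authentication schemes and credentials and are used throughout the protocol. With MCP as the underlying framework for A2A, the authentication integration is implemented in the ADK Host Manager, which provides:

    1. API key authentication
    2. Vertex AI authentication
    3. API key updates

Session management in A2A is tightly integrated with the authentication system. The ADK Host Manager initializes and manages sessions, and the session structure includes authentication information from initialization onward. When agents exchange messages, that authentication information travels with the communication, and the Host Agent implementation uses it when delegating tasks to remote agents.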
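As a rough illustration of authentication information travelling with a delegated task, the sketch below builds a tasks/send JSON-RPC request and attaches a credential as a Bearer token. The Bearer scheme, the `newDelegatedRequest` helper, and the example URL and token are assumptions for this sketch, not something the A2A documentation prescribes here.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// newDelegatedRequest builds a tasks/send JSON-RPC call and attaches the
// session's credential as a Bearer token (an assumed scheme for this sketch).
func newDelegatedRequest(agentURL, token, taskID, text string) (*http.Request, error) {
	payload := map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      "1",
		"method":  "tasks/send",
		"params": map[string]interface{}{
			"id": taskID,
			"message": map[string]interface{}{
				"role":  "user",
				"parts": []map[string]string{{"type": "text", "text": text}},
			},
		},
	}
	body, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequest(http.MethodPost, agentURL, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+token)
	return req, nil
}

func main() {
	req, err := newDelegatedRequest("http://localhost:8080/", "demo-token", "task-1", "hello")
	if err != nil {
		fmt.Println("build failed:", err)
		return
	}
	fmt.Println(req.Method, req.Header.Get("Authorization"))
}
```

The request is only constructed, not sent, so the sketch runs without a remote agent.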

Below is a code skeleton for an A2A authentication MCP plugin written in Go:

1. Project structure

mcp-a2a-plugin/
├── main.go
├── server/
│   └── server.go
├── client/
│   └── client.go
├── types/
│   └── types.go
└── auth/
    └── auth.go

2. Type definitions (types/types.go)

package types

// AgentCard describes an agent's capabilities and metadata.
type AgentCard struct {
    Name               string          `json:"name"`
    Description        string          `json:"description"`
    URL                string          `json:"url"`
    Provider           Provider        `json:"provider"`
    Version            string          `json:"version"`
    Capabilities       Capabilities    `json:"capabilities"`
    Authentication     *Authentication `json:"authentication"`
    DefaultInputModes  []string        `json:"defaultInputModes"`
    DefaultOutputModes []string        `json:"defaultOutputModes"`
    Skills             []Skill         `json:"skills"`
}

type Provider struct {
    Organization string `json:"organization"`
}

type Capabilities struct {
    Streaming bool `json:"streaming"`
    PushNotifications bool `json:"pushNotifications"`
    StateTransitionHistory bool `json:"stateTransitionHistory"`
}

type Authentication struct {
    Schemes []string `json:"schemes"`
    // Other authentication-related fields
}

type Skill struct {
    ID          string   `json:"id"`
    Name        string   `json:"name"`
    Description string   `json:"description"`
    Tags        []string `json:"tags"`
    Examples    []string `json:"examples"`
}

// Task represents a task in the A2A protocol.
type Task struct {
    ID        string     `json:"id"`
    SessionID string     `json:"sessionId,omitempty"`
    Status    TaskStatus `json:"status"`
    Artifacts []Artifact `json:"artifacts"`
    Metadata  map[string]interface{} `json:"metadata,omitempty"`
}

// TaskStatus represents the state of a task.
type TaskStatus struct {
    State     string    `json:"state"` // submitted, working, input-required, completed, failed, canceled
    Timestamp string    `json:"timestamp"`
    Message   *Message  `json:"message,omitempty"`
}

// Message represents one turn of communication.
type Message struct {
    Role     string    `json:"role"` // user or agent
    Parts    []Part    `json:"parts"`
    Metadata map[string]interface{} `json:"metadata,omitempty"`
}

// Part is the basic unit of content in a message or artifact.
type Part struct {
    Type string `json:"type"` // text, file, data
    Text string `json:"text,omitempty"`
    File *File  `json:"file,omitempty"`
    Data interface{} `json:"data,omitempty"`
}

// File represents a file part.
type File struct {
    Name  string `json:"name"`
    Bytes string `json:"bytes,omitempty"` // base64-encoded
    URI   string `json:"uri,omitempty"`
}

// Artifact represents output produced by an agent.
type Artifact struct {
    Parts []Part `json:"parts"`
}

// JSON-RPC request and response structures
type JSONRPCRequest struct {
    JSONRPC string      `json:"jsonrpc"`
    Method  string      `json:"method"`
    Params  interface{} `json:"params"`
    ID      string      `json:"id"`
}

type JSONRPCResponse struct {
    JSONRPC string      `json:"jsonrpc"`
    Result  interface{} `json:"result,omitempty"`
    Error   *JSONRPCError `json:"error,omitempty"`
    ID      string      `json:"id"`
}

type JSONRPCError struct {
    Code    int         `json:"code"`
    Message string      `json:"message"`
    Data    interface{} `json:"data,omitempty"`
}

// TaskSendParams holds the parameters of a tasks/send call.
type TaskSendParams struct {
    ID                string   `json:"id"`
    SessionID         string   `json:"sessionId,omitempty"`
    Message           Message  `json:"message"`
    AcceptedOutputModes []string `json:"acceptedOutputModes,omitempty"`
    PushNotification  *PushNotification `json:"pushNotification,omitempty"`
    Metadata          map[string]interface{} `json:"metadata,omitempty"`
}

type PushNotification struct {
    URL            string         `json:"url"`
    Authentication *Authentication `json:"authentication,omitempty"`
}

3. Authentication module (auth/auth.go)

package auth

import (
    "errors"
    "net/http"
    "strings"
)

// Authenticator defines the authentication interface.
type Authenticator interface {
    Authenticate(r *http.Request) (bool, error)
}

// BearerAuthenticator implements Bearer-token authentication.
type BearerAuthenticator struct {
    Tokens map[string]bool // set of valid tokens
}

// NewBearerAuthenticator creates a Bearer authenticator from a list of valid tokens.
func NewBearerAuthenticator(tokens []string) *BearerAuthenticator {
    tokenMap := make(map[string]bool)
    for _, token := range tokens {
        tokenMap[token] = true
    }
    return &BearerAuthenticator{
        Tokens: tokenMap,
    }
}

// Authenticate validates the Bearer token carried by the request.
func (ba *BearerAuthenticator) Authenticate(r *http.Request) (bool, error) {
    authHeader := r.Header.Get("Authorization")
    if authHeader == "" {
        return false, errors.New("missing authorization header")
    }

    parts := strings.Split(authHeader, " ")
    if len(parts) != 2 || parts[0] != "Bearer" {
        return false, errors.New("invalid authorization format")
    }

    token := parts[1]
    if _, valid := ba.Tokens[token]; valid {
        return true, nil
    }

    return false, errors.New("invalid token")
}

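// AuthMiddleware wraps an HTTP handler with authentication.
func AuthMiddleware(auth Authenticator) func(http.Handler) http.Handler {
    return func(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            // The AgentCard is public discovery metadata, so skip authentication for it.
            if r.URL.Path == "/.well-known/agent.json" && r.Method == http.MethodGet {
                next.ServeHTTP(w, r)
                return
            }

            valid, err := auth.Authenticate(r)
            if err != nil {
                http.Error(w, "Unauthorized: "+err.Error(), http.StatusUnauthorized)
                return
            }
            // An Authenticator may return (false, nil), so handle that case
            // separately instead of dereferencing a nil error.
            if !valid {
                http.Error(w, "Unauthorized", http.StatusUnauthorized)
                return
            }

            // Authentication succeeded; continue handling the request.
            next.ServeHTTP(w, r)
        })
    }
}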
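The behaviour of this middleware can be checked without opening a port. The following is a self-contained sketch using net/http/httptest; `bearerOnly`, `statusFor`, and `newDemoHandler` are hypothetical stand-ins that mirror the logic above in compact form, and "secret-token" is a made-up credential.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

// bearerOnly is a compact stand-in for the middleware above: it lets the
// AgentCard through and requires a known Bearer token for everything else.
func bearerOnly(validTokens map[string]bool, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/.well-known/agent.json" && r.Method == http.MethodGet {
			next.ServeHTTP(w, r)
			return
		}
		header := r.Header.Get("Authorization")
		if !strings.HasPrefix(header, "Bearer ") || !validTokens[strings.TrimPrefix(header, "Bearer ")] {
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// statusFor sends one in-memory request through h and returns the status code.
func statusFor(h http.Handler, method, path, authHeader string) int {
	req := httptest.NewRequest(method, path, nil)
	if authHeader != "" {
		req.Header.Set("Authorization", authHeader)
	}
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	return rec.Code
}

// newDemoHandler wires the middleware around a handler that always answers 200.
func newDemoHandler() http.Handler {
	ok := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	return bearerOnly(map[string]bool{"secret-token": true}, ok)
}

func main() {
	h := newDemoHandler()
	fmt.Println(statusFor(h, http.MethodGet, "/.well-known/agent.json", "")) // 200
	fmt.Println(statusFor(h, http.MethodPost, "/", ""))                      // 401
	fmt.Println(statusFor(h, http.MethodPost, "/", "Bearer secret-token"))   // 200
}
```

The three calls cover the public AgentCard path, a request without credentials, and a request with a valid token.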

4. A2A server implementation (server/server.go)

package server

import (
    "context"
    "encoding/json"
    "fmt"
    "log"
    "net/http"
    "sync"
    "time"

    "mcp-a2a-plugin/auth"
    "mcp-a2a-plugin/types"
)

// TaskHandler defines the task-handling interface.
type TaskHandler interface {
    HandleTask(ctx context.Context, task *types.Task, message *types.Message) (*types.Task, error)
}

// TaskStore defines the task storage interface.
type TaskStore interface {
    Save(task *types.Task) error
    Load(id string) (*types.Task, error)
    List() ([]*types.Task, error)
    Delete(id string) error
}

// InMemoryTaskStore implements an in-memory task store.
type InMemoryTaskStore struct {
    tasks map[string]*types.Task
    mu    sync.RWMutex
}

func NewInMemoryTaskStore() *InMemoryTaskStore {
    return &InMemoryTaskStore{
        tasks: make(map[string]*types.Task),
    }
}

func (s *InMemoryTaskStore) Save(task *types.Task) error {
    s.mu.Lock()
    defer s.mu.Unlock()
    s.tasks[task.ID] = task
    return nil
}

func (s *InMemoryTaskStore) Load(id string) (*types.Task, error) {
    s.mu.RLock()
    defer s.mu.RUnlock()
    task, ok := s.tasks[id]
    if !ok {
        return nil, fmt.Errorf("task not found: %s", id)
    }
    return task, nil
}

func (s *InMemoryTaskStore) List() ([]*types.Task, error) {
    s.mu.RLock()
    defer s.mu.RUnlock()
    tasks := make([]*types.Task, 0, len(s.tasks))
    for _, task := range s.tasks {
        tasks = append(tasks, task)
    }
    return tasks, nil
}

func (s *InMemoryTaskStore) Delete(id string) error {
    s.mu.Lock()
    defer s.mu.Unlock()
    delete(s.tasks, id)
    return nil
}

// A2AServer implements an A2A protocol server.
type A2AServer struct {
    handler    TaskHandler
    taskStore  TaskStore
    agentCard  *types.AgentCard
    auth       auth.Authenticator
}

// NewA2AServer creates a new A2A server.
func NewA2AServer(handler TaskHandler, agentCard *types.AgentCard, authenticator auth.Authenticator) *A2AServer {
    return &A2AServer{
        handler:    handler,
        taskStore:  NewInMemoryTaskStore(),
        agentCard:  agentCard,
        auth:       authenticator,
    }
}

// Start starts the server on the given port.
func (s *A2AServer) Start(port int) error {
    mux := http.NewServeMux()

    // Register the AgentCard endpoint.
    mux.HandleFunc("/.well-known/agent.json", s.handleAgentCard)

    // Register the A2A protocol endpoint.
    mux.HandleFunc("/", s.handleRequest)

    // Apply the authentication middleware.
    handler := auth.AuthMiddleware(s.auth)(mux)

    server := &http.Server{
        Addr:    fmt.Sprintf(":%d", port),
        Handler: handler,
    }

    log.Printf("A2A Server starting on port %d", port)
    return server.ListenAndServe()
}

// handleAgentCard serves the AgentCard.
func (s *A2AServer) handleAgentCard(w http.ResponseWriter, r *http.Request) {
    if r.Method != http.MethodGet {
        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
        return
    }

    w.Header().Set("Content-Type", "application/json")
    json.NewEncoder(w).Encode(s.agentCard)
}

// handleRequest handles A2A protocol requests.
func (s *A2AServer) handleRequest(w http.ResponseWriter, r *http.Request) {
    if r.Method != http.MethodPost {
        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
        return
    }

    var request types.JSONRPCRequest
    if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
        sendJSONRPCError(w, -32700, "Parse error", nil, "")
        return
    }

    // Route the request by method.
    switch request.Method {
    case "tasks/send":
        s.handleTaskSend(w, r, request)
    case "tasks/get":
        s.handleTaskGet(w, request)
    case "tasks/cancel":
        s.handleTaskCancel(w, request)
    default:
        sendJSONRPCError(w, -32601, "Method not found", nil, request.ID)
    }
}

// handleTaskSend handles a tasks/send request. It receives the *http.Request
// so the request context can be passed on to the task handler.
func (s *A2AServer) handleTaskSend(w http.ResponseWriter, r *http.Request, request types.JSONRPCRequest) {
    // Re-marshal the generic params so they can be decoded into TaskSendParams.
    paramsBytes, err := json.Marshal(request.Params)
    if err != nil {
        sendJSONRPCError(w, -32602, "Invalid params", nil, request.ID)
        return
    }

    var params types.TaskSendParams
    if err := json.Unmarshal(paramsBytes, &params); err != nil {
        sendJSONRPCError(w, -32602, "Invalid params", nil, request.ID)
        return
    }

    // Create the task or load an existing one.
    task, err := s.loadOrCreateTask(params.ID, params.SessionID)
    if err != nil {
        sendJSONRPCError(w, -32603, "Internal error", err.Error(), request.ID)
        return
    }

    // Mark the task as working.
    task.Status.State = "working"
    task.Status.Timestamp = time.Now().Format(time.RFC3339)
    if err := s.taskStore.Save(task); err != nil {
        sendJSONRPCError(w, -32603, "Internal error", err.Error(), request.ID)
        return
    }

    // Run the task handler.
    updatedTask, err := s.handler.HandleTask(r.Context(), task, &params.Message)
    if err != nil {
        task.Status.State = "failed"
        task.Status.Timestamp = time.Now().Format(time.RFC3339)
        // Best effort: the handler error is what gets reported to the caller.
        _ = s.taskStore.Save(task)
        sendJSONRPCError(w, -32603, "Internal error", err.Error(), request.ID)
        return
    }

    // Save the updated task.
    if err := s.taskStore.Save(updatedTask); err != nil {
        sendJSONRPCError(w, -32603, "Internal error", err.Error(), request.ID)
        return
    }

    // Send the response.
    response := types.JSONRPCResponse{
        JSONRPC: "2.0",
        Result:  updatedTask,
        ID:

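Authorization on top of the A2A protocol (enhanced version)

Features: forced user logout (kick-out), remember-me mode, annotation-based authorization, automatic token renewal, SSE authentication, and CORS support.

System architecture design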

├── cmd/
│   └── server/
│       └── main.go            # application entry point
├── config/
│   └── config.go              # configuration management
├── internal/
│   ├── auth/
│   │   ├── annotation.go      # annotation-based authorization
│   │   ├── jwt.go             # JWT handling
│   │   ├── middleware.go      # authentication middleware
│   │   └── session.go         # session management
│   ├── cache/
│   │   └── redis.go           # Redis client
│   ├── handler/
│   │   ├── agent.go           # A2A agent handlers
│   │   ├── auth.go            # authentication handlers
│   │   ├── sse.go             # SSE handlers
│   │   └── task.go            # task handlers
│   ├── model/
│   │   ├── agent.go           # A2A agent model
│   │   ├── task.go            # task model
│   │   └── user.go            # user model
│   └── service/
│       ├── agent.go           # A2A agent service
│       ├── auth.go            # authentication service
│       └── task.go            # task service
└── pkg/
    ├── cors/
    │   └── cors.go            # CORS support
    ├── sse/
    │   └── sse.go             # SSE implementation
    └── utils/
        └── utils.go           # utility functions

Core feature implementation

1. Redis session store

// internal/cache/redis.go
package cache

import (
    "context"
    "encoding/json"
    "time"

    "github.com/go-redis/redis/v8"
)

type RedisClient struct {
    client *redis.Client
}

func NewRedisClient(addr, password string, db int) *RedisClient {
    client := redis.NewClient(&redis.Options{
        Addr:     addr,
        Password: password,
        DB:       db,
    })
    return &RedisClient{client: client}
}

// Set stores a JSON-encoded value under key with an optional expiration.
func (r *RedisClient) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
    data, err := json.Marshal(value)
    if err != nil {
        return err
    }
    return r.client.Set(ctx, key, data, expiration).Err()
}

// Get loads the value stored under key and JSON-decodes it into dest.
func (r *RedisClient) Get(ctx context.Context, key string, dest interface{}) error {
    data, err := r.client.Get(ctx, key).Bytes()
    if err != nil {
        return err
    }
    return json.Unmarshal(data, dest)
}

// Del deletes a key.
func (r *RedisClient) Del(ctx context.Context, key string) error {
    return r.client.Del(ctx, key).Err()
}

// Expire sets the expiration of a key.
func (r *RedisClient) Expire(ctx context.Context, key string, expiration time.Duration) error {
    return r.client.Expire(ctx, key, expiration).Err()
}

// Exists reports whether a key exists.
func (r *RedisClient) Exists(ctx context.Context, key string) (bool, error) {
    result, err := r.client.Exists(ctx, key).Result()
    return result > 0, err
}

// Publish sends a JSON-encoded message to a channel.
func (r *RedisClient) Publish(ctx context.Context, channel string, message interface{}) error {
    data, err := json.Marshal(message)
    if err != nil {
        return err
    }
    return r.client.Publish(ctx, channel, data).Err()
}

// Subscribe subscribes to a channel.
func (r *RedisClient) Subscribe(ctx context.Context, channel string) *redis.PubSub {
    return r.client.Subscribe(ctx, channel)
}

2. JWT authentication and automatic renewal

// internal/auth/jwt.go
package auth

import (
    "errors"
    "time"

    "github.com/golang-jwt/jwt/v4"
)

type TokenType string

const (
    AccessToken  TokenType = "access"
    RefreshToken TokenType = "refresh"
)

type TokenClaims struct {
    UserID    string    `json:"user_id"`
    Username  string    `json:"username"`
    TokenType TokenType `json:"token_type"`
    RememberMe bool     `json:"remember_me"`
    jwt.RegisteredClaims
}

type TokenService struct {
    accessSecret  string
    refreshSecret string
    accessExpiry  time.Duration
    refreshExpiry time.Duration
    rememberExpiry time.Duration
}

func NewTokenService(accessSecret, refreshSecret string, 
                     accessExpiry, refreshExpiry, rememberExpiry time.Duration) *TokenService {
    return &TokenService{
        accessSecret:  accessSecret,
        refreshSecret: refreshSecret,
        accessExpiry:  accessExpiry,
        refreshExpiry: refreshExpiry,
        rememberExpiry: rememberExpiry,
    }
}

// GenerateAccessToken issues an access token; remember-me extends its lifetime.
func (s *TokenService) GenerateAccessToken(userID, username string, rememberMe bool) (string, time.Time, error) {
    expiry := s.accessExpiry
    if rememberMe {
        expiry = s.rememberExpiry
    }
    
    expiryTime := time.Now().Add(expiry)
    claims := TokenClaims{
        UserID:    userID,
        Username:  username,
        TokenType: AccessToken,
        RememberMe: rememberMe,
        RegisteredClaims: jwt.RegisteredClaims{
            ExpiresAt: jwt.NewNumericDate(expiryTime),
            IssuedAt:  jwt.NewNumericDate(time.Now()),
        },
    }

    token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
    tokenString, err := token.SignedString([]byte(s.accessSecret))
    return tokenString, expiryTime, err
}

// GenerateRefreshToken issues a refresh token; remember-me extends its lifetime.
func (s *TokenService) GenerateRefreshToken(userID, username string, rememberMe bool) (string, time.Time, error) {
    expiry := s.refreshExpiry
    if rememberMe {
        expiry = s.rememberExpiry
    }
    
    expiryTime := time.Now().Add(expiry)
    claims := TokenClaims{
        UserID:    userID,
        Username:  username,
        TokenType: RefreshToken,
        RememberMe: rememberMe,
        RegisteredClaims: jwt.RegisteredClaims{
            ExpiresAt: jwt.NewNumericDate(expiryTime),
            IssuedAt:  jwt.NewNumericDate(time.Now()),
        },
    }

    token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
    tokenString, err := token.SignedString([]byte(s.refreshSecret))
    return tokenString, expiryTime, err
}

// ValidateAccessToken validates an access token.
func (s *TokenService) ValidateAccessToken(tokenString string) (*TokenClaims, error) {
    return s.validateToken(tokenString, s.accessSecret, AccessToken)
}

// ValidateRefreshToken validates a refresh token.
func (s *TokenService) ValidateRefreshToken(tokenString string) (*TokenClaims, error) {
    return s.validateToken(tokenString, s.refreshSecret, RefreshToken)
}

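// validateToken parses the token, verifies its signature, and checks its type.
func (s *TokenService) validateToken(tokenString, secret string, tokenType TokenType) (*TokenClaims, error) {
    token, err := jwt.ParseWithClaims(tokenString, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) {
        // Accept only HMAC-signed tokens so that a token signed with an
        // unexpected algorithm is rejected.
        if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
            return nil, errors.New("unexpected signing method")
        }
        return []byte(secret), nil
    })

    if err != nil {
        return nil, err
    }

    if claims, ok := token.Claims.(*TokenClaims); ok && token.Valid {
        if claims.TokenType != tokenType {
            return nil, errors.New("invalid token type")
        }
        return claims, nil
    }

    return nil, errors.New("invalid token")
}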

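// NeedsRenewal reports whether the token expires within a quarter of the
// access-token lifetime from now, in which case it should be renewed.
func (s *TokenService) NeedsRenewal(claims *TokenClaims) bool {
    if claims == nil || claims.ExpiresAt == nil {
        return false
    }
    threshold := time.Now().Add(s.accessExpiry / 4)
    return claims.ExpiresAt.Time.Before(threshold)
}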

3. Session management and forced logout

// internal/auth/session.go
package auth

import (
    "context"
    "errors"
    "fmt"
    "time"

    "example.com/a2a-auth/internal/cache"
    "example.com/a2a-auth/internal/model"
)

const (
    // Redis key prefixes
    userSessionPrefix = "user:session:"
    userTokenPrefix   = "user:token:"
    sessionUserPrefix = "session:user:"
    
    // Pub/sub channel used to kick users offline
    kickoutChannel = "user:kickout"
)

type SessionManager struct {
    redis       *cache.RedisClient
    tokenService *TokenService
}

func NewSessionManager(redis *cache.RedisClient, tokenService *TokenService) *SessionManager {
    return &SessionManager{
        redis:       redis,
        tokenService: tokenService,
    }
}

// CreateSession issues tokens for the user and stores a new session in Redis.
// generateSessionID, getUserAgent, and getClientIP are helpers not shown in this listing.
func (m *SessionManager) CreateSession(ctx context.Context, user *model.User, rememberMe bool) (*model.Session, error) {
    // Generate the access and refresh tokens.
    accessToken, accessExpiry, err := m.tokenService.GenerateAccessToken(user.ID, user.Username, rememberMe)
    if err != nil {
        return nil, err
    }

    refreshToken, refreshExpiry, err := m.tokenService.GenerateRefreshToken(user.ID, user.Username, rememberMe)
    if err != nil {
        return nil, err
    }

    // Build the session. Username is stored so RenewToken can reissue tokens with it.
    session := &model.Session{
        ID:            generateSessionID(),
        UserID:        user.ID,
        Username:      user.Username,
        AccessToken:   accessToken,
        RefreshToken:  refreshToken,
        AccessExpiry:  accessExpiry,
        RefreshExpiry: refreshExpiry,
        RememberMe:    rememberMe,
        UserAgent:     getUserAgent(ctx),
        IP:            getClientIP(ctx),
        CreatedAt:     time.Now(),
        LastActivity:  time.Now(),
    }

    // Store the session in Redis until the refresh token expires.
    sessionKey := fmt.Sprintf("%s%s", sessionUserPrefix, session.ID)
    if err := m.redis.Set(ctx, sessionKey, session, time.Until(refreshExpiry)); err != nil {
        return nil, err
    }

    // Append the session to the user's session list.
    userSessionKey := fmt.Sprintf("%s%s", userSessionPrefix, user.ID)
    var sessions []string
    if err := m.redis.Get(ctx, userSessionKey, &sessions); err == nil {
        sessions = append(sessions, session.ID)
    } else {
        sessions = []string{session.ID}
    }
    if err := m.redis.Set(ctx, userSessionKey, sessions, 0); err != nil {
        return nil, err
    }

    // Map the access token to the session ID.
    tokenKey := fmt.Sprintf("%s%s", userTokenPrefix, accessToken)
    if err := m.redis.Set(ctx, tokenKey, session.ID, time.Until(accessExpiry)); err != nil {
        return nil, err
    }

    return session, nil
}

// GetSession loads a session by ID.
func (m *SessionManager) GetSession(ctx context.Context, sessionID string) (*model.Session, error) {
    sessionKey := fmt.Sprintf("%s%s", sessionUserPrefix, sessionID)
    var session model.Session
    if err := m.redis.Get(ctx, sessionKey, &session); err != nil {
        return nil, err
    }
    return &session, nil
}

// GetSessionByToken validates the access token and returns the session it belongs to.
func (m *SessionManager) GetSessionByToken(ctx context.Context, token string) (*model.Session, error) {
    // Validate the token; the claims themselves are not needed here.
    if _, err := m.tokenService.ValidateAccessToken(token); err != nil {
        return nil, err
    }

    // Look up the session ID mapped to this token.
    tokenKey := fmt.Sprintf("%s%s", userTokenPrefix, token)
    var sessionID string
    if err := m.redis.Get(ctx, tokenKey, &sessionID); err != nil {
        return nil, errors.New("session not found")
    }

    // Load the session.
    return m.GetSession(ctx, sessionID)
}

// UpdateSessionActivity refreshes the session's last-activity timestamp.
func (m *SessionManager) UpdateSessionActivity(ctx context.Context, sessionID string) error {
    session, err := m.GetSession(ctx, sessionID)
    if err != nil {
        return err
    }

    session.LastActivity = time.Now()
    sessionKey := fmt.Sprintf("%s%s", sessionUserPrefix, sessionID)
    return m.redis.Set(ctx, sessionKey, session, time.Until(session.RefreshExpiry))
}

// RenewToken issues a new access token for the session and replaces the old one.
func (m *SessionManager) RenewToken(ctx context.Context, session *model.Session) (*model.Session, error) {
    // Generate the new access token.
    accessToken, accessExpiry, err := m.tokenService.GenerateAccessToken(
        session.UserID, session.Username, session.RememberMe)
    if err != nil {
        return nil, err
    }

    // Remove the old token mapping (best effort; it expires on its own anyway).
    oldTokenKey := fmt.Sprintf("%s%s", userTokenPrefix, session.AccessToken)
    _ = m.redis.Del(ctx, oldTokenKey)

    // Update the session.
    session.AccessToken = accessToken
    session.AccessExpiry = accessExpiry
    session.LastActivity = time.Now()

    // Save the updated session.
    sessionKey := fmt.Sprintf("%s%s", sessionUserPrefix, session.ID)
    if err := m.redis.Set(ctx, sessionKey, session, time.Until(session.RefreshExpiry)); err != nil {
        return nil, err
    }

    // Map the new token to the session ID.
    tokenKey := fmt.Sprintf("%s%s", userTokenPrefix, accessToken)

1. User and session models (internal/model/user.go and session.go)

// internal/model/user.go
package model

import (
    "time"
)

// User represents a system user.
type User struct {
    ID        string    `json:"id"`
    Username  string    `json:"username"`
    Email     string    `json:"email"`
    Password  string    `json:"-"` // never serialize the password
    Roles     []string  `json:"roles"`
    Active    bool      `json:"active"`
    CreatedAt time.Time `json:"created_at"`
    UpdatedAt time.Time `json:"updated_at"`
}

// HasRole reports whether the user has the given role.
func (u *User) HasRole(role string) bool {
    for _, r := range u.Roles {
        if r == role {
            return true
        }
    }
    return false
}

// internal/model/session.go
package model

import (
    "time"
)

// Session represents a user session.
type Session struct {
    ID            string    `json:"id"`
    UserID        string    `json:"user_id"`
    Username      string    `json:"username"`
    AccessToken   string    `json:"access_token"`
    RefreshToken  string    `json:"refresh_token"`
    AccessExpiry  time.Time `json:"access_expiry"`
    RefreshExpiry time.Time `json:"refresh_expiry"`
    RememberMe    bool      `json:"remember_me"`
    UserAgent     string    `json:"user_agent"`
    IP            string    `json:"ip"`
    CreatedAt     time.Time `json:"created_at"`
    LastActivity  time.Time `json:"last_activity"`
}

2. A2A agent and task models (internal/model/agent.go and task.go)

// internal/model/agent.go
package model

// Agent represents an A2A agent.
type Agent struct {
    ID            string   `json:"id"`
    Name          string   `json:"name"`
    Description   string   `json:"description"`
    URL           string   `json:"url"`
    Capabilities  []string `json:"capabilities"`
    Skills        []string `json:"skills"`
    Authentication struct {
        Schemes     []string `json:"schemes"`
        Credentials string   `json:"credentials,omitempty"`
    } `json:"authentication"`
}

// internal/model/task.go
package model

import (
    "time"
)

// Task represents an A2A task.
type Task struct {
    ID        string                 `json:"id"`
    AgentID   string                 `json:"agent_id"`
    UserID    string                 `json:"user_id"`
    SessionID string                 `json:"session_id"`
    Status    string                 `json:"status"` // submitted, working, completed, failed
    Message   string                 `json:"message"`
    Result    interface{}            `json:"result,omitempty"`
    Metadata  map[string]interface{} `json:"metadata,omitempty"`
    CreatedAt time.Time              `json:"created_at"`
    UpdatedAt time.Time              `json:"updated_at"`
}

3. Annotation-based authorization (internal/auth/annotation.go)

// internal/auth/annotation.go
package auth

import (
    "context"
    "net/http"
    "reflect"
    "strings"
)

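// RequireRoles is the struct tag key that marks required roles, for example
// `auth_roles:"admin,editor"`. A tag key cannot contain ':' because
// reflect.StructTag.Lookup treats the first colon as the end of the key, so
// the original "auth:roles" key could never match.
const RequireRoles = "auth_roles"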

// RoleChecker checks whether the current user has a required role.
type RoleChecker interface {
    HasRole(ctx context.Context, role string) bool
}

// AnnotationMiddleware builds middleware that authorizes requests based on
// role tags declared on the fields of the wrapped handler's struct type.
func AnnotationMiddleware(checker RoleChecker) func(http.Handler) http.Handler {
    return func(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            // Inspect the handler's type. Handlers are usually passed as
            // pointers, so unwrap to the underlying struct first.
            handlerType := reflect.TypeOf(next)
            if handlerType.Kind() == reflect.Ptr {
                handlerType = handlerType.Elem()
            }

            // Only struct handlers can carry field tags.
            if handlerType.Kind() == reflect.Struct {
                // Walk the struct fields.
                for i := 0; i < handlerType.NumField(); i++ {
                    field := handlerType.Field(i)

                    // Look for the RequireRoles tag.
                    if rolesTag, ok := field.Tag.Lookup(RequireRoles); ok {
                        roles := strings.Split(rolesTag, ",")

                        // The user needs at least one of the listed roles.
                        hasRequiredRole := false
                        for _, role := range roles {
                            if checker.HasRole(r.Context(), strings.TrimSpace(role)) {
                                hasRequiredRole = true
                                break
                            }
                        }

                        if !hasRequiredRole {
                            http.Error(w, "Forbidden: insufficient permissions", http.StatusForbidden)
                            return
                        }
                    }
                }
            }

            next.ServeHTTP(w, r)
        })
    }
}

// WithRoles builds middleware that requires the user to hold at least one of the given roles.
func WithRoles(roles ...string) func(http.Handler) http.Handler {
    return func(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            // Read the user from the context; an upstream middleware must have stored it under "user".
            user, ok := r.Context().Value("user").(interface{ HasRole(string) bool })
            if !ok {
                http.Error(w, "Unauthorized", http.StatusUnauthorized)
                return
            }
            
            // Check whether the user has one of the required roles.
            hasRequiredRole := false
            for _, role := range roles {
                if user.HasRole(role) {
                    hasRequiredRole = true
                    break
                }
            }
            
            if !hasRequiredRole {
                http.Error(w, "Forbidden: insufficient permissions", http.StatusForbidden)
                return
            }
            
            next.ServeHTTP(w, r)
        })
    }
}
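A role-checking middleware of this kind can be exercised as follows. This is a self-contained sketch, not the listing above verbatim: `requireRoles`, `demoUser`, `statusWithUser`, and the typed context key `userKey` are hypothetical names, and the typed key is a deliberate variation on the plain string key "user" used above, since a private key type avoids collisions between packages.

```go
package main

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// ctxKey is a private key type for context values.
type ctxKey string

const userKey ctxKey = "user"

// demoUser is a stand-in user type that satisfies the HasRole contract.
type demoUser struct{ roles []string }

func (u demoUser) HasRole(role string) bool {
	for _, r := range u.roles {
		if r == role {
			return true
		}
	}
	return false
}

// requireRoles answers 401 without a user, 403 without a matching role,
// and otherwise passes the request on.
func requireRoles(roles ...string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			user, ok := r.Context().Value(userKey).(interface{ HasRole(string) bool })
			if !ok {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			for _, role := range roles {
				if user.HasRole(role) {
					next.ServeHTTP(w, r)
					return
				}
			}
			http.Error(w, "Forbidden", http.StatusForbidden)
		})
	}
}

// statusWithUser runs one request through an admin-only route; a nil user
// simulates a request that never passed authentication.
func statusWithUser(u *demoUser) int {
	h := requireRoles("admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	req := httptest.NewRequest(http.MethodGet, "/admin", nil)
	if u != nil {
		req = req.WithContext(context.WithValue(req.Context(), userKey, *u))
	}
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	return rec.Code
}

func main() {
	fmt.Println(statusWithUser(&demoUser{roles: []string{"admin"}}))  // 200
	fmt.Println(statusWithUser(&demoUser{roles: []string{"viewer"}})) // 403
	fmt.Println(statusWithUser(nil))                                  // 401
}
```

The 401/403 split matters to clients: 401 means "authenticate first", while 403 means the identity is known but lacks the role.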

4. Authentication middleware (internal/auth/middleware.go)

// internal/auth/middleware.go
package auth

import (
    "context"
    "net/http"
    "strings"

    "example.com/a2a-auth/internal/model"
)

// AuthMiddleware builds the session-based authentication middleware.
func AuthMiddleware(sessionManager *SessionManager) func(http.Handler) http.Handler {
    return func(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            // Extract the token from the request.
            token := extractToken(r)
            if token == "" {
                http.Error(w, "Unauthorized: missing token", http.StatusUnauthorized)
                return
            }

            // Validate the token and load its session.
            session, err := sessionManager.GetSessionByToken(r.Context(), token)
            if err != nil {
                http.Error(w, "Unauthorized: invalid token", http.StatusUnauthorized)
                return
            }

            // Renew the token if it is close to expiry. The token was already
            // validated above, so an error here simply skips renewal.
            if claims, err := sessionManager.tokenService.ValidateAccessToken(token); err == nil &&
                sessionManager.tokenService.NeedsRenewal(claims) {
                if newSession, err := sessionManager.RenewToken(r.Context(), session); err == nil {
                    session = newSession
                    // Return the new token to the client in a response header.
                    w.Header().Set("X-New-Token", session.AccessToken)
                }
            }

            // Refresh the session's last-activity time (best effort).
            _ = sessionManager.UpdateSessionActivity(r.Context(), session.ID)

            // Attach the session to the request context.
            ctx := context.WithValue(r.Context(), "session", session)

            // Continue handling the request.
            next.ServeHTTP(w, r.WithContext(ctx))
        })
    }
}

// extractToken reads the token from the Authorization header, then the
// access_token cookie, then the token query parameter.
func extractToken(r *http.Request) string {
    // Try the Authorization header first.
    authHeader := r.Header.Get("Authorization")
    if authHeader != "" && strings.HasPrefix(authHeader, "Bearer ") {
        return strings.TrimPrefix(authHeader, "Bearer ")
    }

    // Then the cookie.
    cookie, err := r.Cookie("access_token")
    if err == nil {
        return cookie.Value
    }

    // Finally the query parameter.
    return r.URL.Query().Get("token")
}

// KickoutMiddleware builds middleware that rejects sessions that have been kicked offline.
func KickoutMiddleware(sessionManager *SessionManager) func(http.Handler) http.Handler {
    return func(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            // Read the session from the context.
            session, ok := r.Context().Value("session").(*model.Session)
            if !ok {
                http.Error(w, "Unauthorized", http.StatusUnauthorized)
                return
            }
            
            // Check whether the session has been kicked offline.
            kicked, err := sessionManager.IsKickedOut(r.Context(), session.ID)
            if err != nil || kicked {
                http.Error(w, "Session terminated", http.StatusUnauthorized)
                return
            }
            
            next.ServeHTTP(w, r)
        })
    }
}

5. SSE implementation (pkg/sse/sse.go)

// pkg/sse/sse.go
package sse

import (
    "encoding/json"
    "fmt"
    "net/http"
    "time"
)

// Client represents one SSE client connection.
type Client struct {
    ID     string
    Events chan interface{}
}

// NewClient creates a new SSE client.
func NewClient(id string) *Client {
    return &Client{
        ID:     id,
        Events: make(chan interface{}, 10),
    }
}

// ServeHTTP streams events to the client over an SSE connection.
func (c *Client) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    // SSE needs incremental flushing; give up early if the writer cannot flush.
    flusher, ok := w.(http.Flusher)
    if !ok {
        http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
        return
    }

    // Set the SSE headers.
    w.Header().Set("Content-Type", "text/event-stream")
    w.Header().Set("Cache-Control", "no-cache")
    w.Header().Set("Connection", "keep-alive")
    w.Header().Set("X-Accel-Buffering", "no") // disable Nginx buffering

    // Send the initial "connected" message; marshaling keeps the ID properly escaped.
    hello, _ := json.Marshal(map[string]string{"type": "connected", "id": c.ID})
    fmt.Fprintf(w, "data: %s\n\n", hello)
    flusher.Flush()

    // Ticker for keep-alive pings.
    ticker := time.NewTicker(30 * time.Second)
    defer ticker.Stop()

    // Event loop.
    for {
        select {
        case <-r.Context().Done():
            // The client disconnected.
            return

        case <-ticker.C:
            // An SSE comment line serves as the keep-alive ping.
            fmt.Fprintf(w, ": ping\n\n")
            flusher.Flush()

        case event := <-c.Events:
            // Serialize the event; skip it if it cannot be encoded.
            data, err := json.Marshal(event)
            if err != nil {
                continue
            }

            // Send the event.
            fmt.Fprintf(w, "data: %s\n\n", data)
            flusher.Flush()
        }
    }
}

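// Broker manages multiple SSE clients
type Broker struct {
    clients    map[string]*Client
    register   chan *Client
    unregister chan *Client
    events     chan Event
}

// Event represents an event to broadcast
type Event struct {
    ClientID string
    Data     interface{}
}

// NewBroker creates a new SSE broker
func NewBroker() *Broker {
    return &Broker{
        clients:    make(map[string]*Client),
        register:   make(chan *Client),
        // The source listing breaks off after register; the two channels
        // below are assumed to mirror it and match the Broker struct above.
        unregister: make(chan *Client),
        events:     make(chan Event),
    }
}
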
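The broker's event loop is not shown in this section. One way it might look is sketched below: a single goroutine owns the `clients` map and serialises registration, removal, and delivery through the channels. The `Run` method, the `deliver` helper, and the rule that an empty `ClientID` means "broadcast to everyone" are all assumptions made for illustration, not something the article states.

```go
package main

// Client and Event mirror the types from the SSE listing.
type Client struct {
	ID     string
	Events chan interface{}
}

type Event struct {
	ClientID string // assumption: empty means "send to every client"
	Data     interface{}
}

type Broker struct {
	clients    map[string]*Client
	register   chan *Client
	unregister chan *Client
	events     chan Event
}

func NewBroker() *Broker {
	return &Broker{
		clients:    make(map[string]*Client),
		register:   make(chan *Client),
		unregister: make(chan *Client),
		events:     make(chan Event),
	}
}

// Run is a hypothetical event loop. It must run in its own goroutine and is
// the only code that touches b.clients, so no mutex is needed.
// It returns when done is closed.
func (b *Broker) Run(done <-chan struct{}) {
	for {
		select {
		case <-done:
			return
		case c := <-b.register:
			b.clients[c.ID] = c
		case c := <-b.unregister:
			delete(b.clients, c.ID)
		case ev := <-b.events:
			if ev.ClientID != "" {
				// Targeted event: deliver to one client if it is connected.
				if c, ok := b.clients[ev.ClientID]; ok {
					deliver(c, ev.Data)
				}
				continue
			}
			// Broadcast to every registered client.
			for _, c := range b.clients {
				deliver(c, ev.Data)
			}
		}
	}
}

// deliver drops the event instead of blocking when the client's buffer is
// full, so one slow client cannot stall the whole broker.
func deliver(c *Client, data interface{}) {
	select {
	case c.Events <- data:
	default:
	}
}
```

Dropping events for slow clients is a deliberate trade-off in this sketch; a system that cannot lose events would need per-client queues or disconnect-on-overflow instead.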

Utils implementation

The utils package holds general-purpose helper functions used across the A2A authentication system. Here is an implementation of pkg/utils/utils.go:

// pkg/utils/utils.go
package utils

import (
	"crypto/rand"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"net"
	"net/http"
	"strings"
	"time"
)

// GenerateRandomBytes returns length cryptographically secure random bytes
func GenerateRandomBytes(length int) ([]byte, error) {
	bytes := make([]byte, length)
	_, err := rand.Read(bytes)
	if err != nil {
		return nil, err
	}
	return bytes, nil
}

// GenerateRandomString returns length random bytes hex-encoded, so the string is 2*length characters long
func GenerateRandomString(length int) (string, error) {
	bytes, err := GenerateRandomBytes(length)
	if err != nil {
		return "", err
	}
	return hex.EncodeToString(bytes), nil
}

// GenerateSessionID returns a session ID (16 random bytes, 32 hex characters)
func GenerateSessionID() (string, error) {
	return GenerateRandomString(16)
}

// GenerateToken returns a URL-safe base64 token built from 32 random bytes
func GenerateToken() (string, error) {
	bytes, err := GenerateRandomBytes(32)
	if err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(bytes), nil
}

// GetUserAgent returns the User-Agent of the request
func GetUserAgent(r *http.Request) string {
	return r.UserAgent()
}

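// GetClientIP returns the client IP for the request.
// X-Forwarded-For and X-Real-IP can be set by any client, so they should only
// be trusted when the service sits behind a trusted reverse proxy.
func GetClientIP(r *http.Request) string {
	// Try the X-Forwarded-For header first
	ip := r.Header.Get("X-Forwarded-For")
	if ip != "" {
		// The header may list several IPs; take the first one
		parts := strings.Split(ip, ",")
		ip = strings.TrimSpace(parts[0])
		return ip
	}

	// Then try the X-Real-IP header
	ip = r.Header.Get("X-Real-IP")
	if ip != "" {
		return ip
	}

	// Fall back to RemoteAddr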
	ip, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return ip
}

// FormatTime formats a time as an RFC 3339 (ISO 8601) string
func FormatTime(t time.Time) string {
	return t.Format(time.RFC3339)
}

// ParseTime parses an RFC 3339 (ISO 8601) time string
func ParseTime(s string) (time.Time, error) {
	return time.Parse(time.RFC3339, s)
}

// IsExpired reports whether the given time is in the past
func IsExpired(t time.Time) bool {
	return t.Before(time.Now())
}

// GetExpiryDuration returns the time remaining until t (negative if t has passed)
func GetExpiryDuration(t time.Time) time.Duration {
	return time.Until(t)
}

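// SetCORSHeaders sets the CORS response headers.
// Browsers reject credentialed responses whose Access-Control-Allow-Origin
// is "*", so credentials are only allowed when a specific origin is given.
func SetCORSHeaders(w http.ResponseWriter, origin string) {
	if origin == "" {
		w.Header().Set("Access-Control-Allow-Origin", "*")
	} else {
		w.Header().Set("Access-Control-Allow-Origin", origin)
		w.Header().Set("Access-Control-Allow-Credentials", "true")
		w.Header().Add("Vary", "Origin")
	}
	w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
	w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
	w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours
}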

// HandleOptionsRequest answers OPTIONS (preflight) requests and reports whether it handled the request
func HandleOptionsRequest(w http.ResponseWriter, r *http.Request) bool {
	if r.Method == http.MethodOptions {
		w.WriteHeader(http.StatusOK)
		return true
	}
	return false
}

// SetSSEHeaders sets the response headers for an SSE stream
func SetSSEHeaders(w http.ResponseWriter) {
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no") // disable Nginx buffering
}

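// SendSSEEvent writes one SSE event and flushes it.
// data is formatted with %v, so callers should pass a single-line string
// (for example, pre-serialised JSON).
func SendSSEEvent(w http.ResponseWriter, event string, data interface{}) error {
	if event != "" {
		if _, err := fmt.Fprintf(w, "event: %s\n", event); err != nil {
			return err
		}
	}
	if _, err := fmt.Fprintf(w, "data: %v\n\n", data); err != nil {
		return err
	}
	if f, ok := w.(http.Flusher); ok {
		f.Flush()
	}
	return nil
}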

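How the helpers are used

  • GenerateSessionID is used in SessionManager.CreateSession to generate session IDs
  • GetUserAgent and GetClientIP are used to record session information
  • SetCORSHeaders supports cross-origin requests
  • SetSSEHeaders and SendSSEEvent are used by the SSE implementation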
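
To show how the CORS helpers might be wired into a server, here is a minimal middleware sketch. `CORSMiddleware`, its allow-list parameter, the `preflight` demo helper, and the example origin are all assumptions made for illustration. The sketch sets the headers inline rather than calling the utils package so that it stays self-contained.

```go
package main

import (
	"net/http"
	"net/http/httptest"
)

// CORSMiddleware is a hypothetical wrapper, not part of the article's code.
// allowed is an assumed allow-list of origins that may send credentials.
func CORSMiddleware(allowed map[string]bool) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			origin := r.Header.Get("Origin")
			if origin != "" && allowed[origin] {
				h := w.Header()
				h.Set("Access-Control-Allow-Origin", origin)
				h.Set("Access-Control-Allow-Credentials", "true")
				h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
				h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
				h.Add("Vary", "Origin")
			}
			// Preflight requests end here and never reach the wrapped handler.
			if r.Method == http.MethodOptions {
				w.WriteHeader(http.StatusNoContent)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// preflight sends an OPTIONS request through the middleware and reports the
// status code and the Access-Control-Allow-Origin header that came back.
func preflight(origin string) (int, string) {
	handler := CORSMiddleware(map[string]bool{"https://app.example.com": true})(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
		}))
	req := httptest.NewRequest(http.MethodOptions, "/tasks", nil)
	req.Header.Set("Origin", origin)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)
	return rec.Code, rec.Header().Get("Access-Control-Allow-Origin")
}
```

Echoing only allow-listed origins avoids combining a wildcard origin with credentials, which browsers refuse to honour.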

How is the "remember me" feature implemented?

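  1. Implementation in the JWT token service
    The TokenService struct defines a dedicated rememberExpiry field that stores the token lifetime used in "remember me" mode:
type TokenService struct {
    accessSecret   string
    refreshSecret  string
    accessExpiry   time.Duration
    refreshExpiry  time.Duration
    rememberExpiry time.Duration // token lifetime in "remember me" mode
}
  2. Choosing the lifetime when tokens are generated
    When access and refresh tokens are generated, the rememberMe parameter decides whether the normal lifetime or the extended one is used:
func (s *TokenService) GenerateAccessToken(userID, username string, rememberMe bool) (string, time.Time, error) {
    expiry := s.accessExpiry
    if rememberMe {
        expiry = s.rememberExpiry // "remember me" selected: use the longer lifetime
    }

    // ...token generation logic
}
  3. Storing the "remember me" state in the token claims
    The TokenClaims struct includes a RememberMe field, so the user's choice is carried inside the token:
type TokenClaims struct {
    UserID     string    `json:"user_id"`
    Username   string    `json:"username"`
    TokenType  TokenType `json:"token_type"`
    RememberMe bool      `json:"remember_me"` // "remember me" state
    jwt.RegisteredClaims
}
  4. Handling "remember me" when a session is created
    The CreateSession method passes the "remember me" state to the token generation functions and stores it on the session object:
func (m *SessionManager) CreateSession(ctx context.Context, user *model.User, rememberMe bool) (*model.Session, error) {
    // Pass rememberMe when generating the access and refresh tokens
    accessToken, accessExpiry, err := m.tokenService.GenerateAccessToken(user.ID, user.Username, rememberMe)
    // ...

    // Store the "remember me" state on the session object
    session := &model.Session{
        // ...other fields
        RememberMe: rememberMe,
        // ...
    }

    // ...
}
  5. Preserving "remember me" when tokens are renewed
    The RenewToken method keeps the original "remember me" state when it renews a token:
func (m *SessionManager) RenewToken(ctx context.Context, session *model.Session) (*model.Session, error) {
    // Generate a new access token, passing along the existing "remember me" state
    accessToken, accessExpiry, err := m.tokenService.GenerateAccessToken(
        session.UserID, session.Username, session.RememberMe)
    // ...
}

This approach ensures that users who select "remember me" get a longer session lifetime, and that the choice is preserved when tokens are renewed automatically.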