商品相关性排序模型调用服务

32 阅读7分钟

📋 类概述

类名: XxxSpuRelevanceTritonModelService

功能: 与 Triton 推理服务器交互,进行 SPU(商品)相关性语义模型推理

核心作用: 通过 BERT 模型对搜索查询和商品进行语义相关性评分,用于搜索结果排序

业务场景:

  • 用户搜索 "红油火锅"
  • 系统召回一批商品(SPU)
  • 本服务计算每个商品与查询的相关性分数
  • 排序模块根据相关性分数对商品重新排序

🔧 核心配置与初始化

关键常量

常量说明
MAX_SEQ_LENGTH64BERT 模型输入序列最大长度
dictPathdict/bert_vocab_zh.txt中文词汇表路径
REMOTE_APP_KEY[远程推理服务标识]远程推理服务 AppKey
LOCAL_APP_KEY[本地应用标识]本地应用 AppKey
RELEVANCE_GROUP_NAME_LION_KEY[配置中心键名]配置中心键(推理服务器分组)
groupNameList["a30", "t4"]默认推理服务器分组(GPU 类型)

Soft Prompt 配置

softPrompt = {
    {"[unused3]", "[unused4]"},  // POI Intent 用
    {"[unused1]", "[unused2]"}   // 其他 Intent 用
}

初始化流程 (@PostConstruct init())

1. 加载词汇表
   ├─ 读取 BERT 中文词汇表文件
   ├─ 构建 word -> token_id 映射
   └─ 初始化 FullTokenizer

2. 初始化推理客户端
   ├─ 从配置中心读取分组名称
   ├─ 为每个分组创建 InferenceServerClient
   └─ 存储在 clientMap 中

3. 动态配置监听
   ├─ 监听配置中心变化
   ├─ 新增分组时创建新客户端
   └─ 删除分组时关闭旧客户端

🔄 模型交互流程

整体流程图

输入数据
  ↓
[access] 检查是否可执行
  ↓
[build] 构建推理请求
  ├─ 文本分词
  ├─ 构建 BERT 输入
  └─ 生成 ModelInferRequest
  ↓
[action] 执行推理
  ├─ 调用 Triton 服务
  ├─ 获取模型输出
  └─ 处理响应
  ↓
输出结果

📝 关键方法详解

1. access() - 前置检查

目的: 判断是否可以执行模型推理

检查项:

  • ✅ 模型名称不为空
  • ✅ 推理服务器分组名称不为空
  • ✅ 对应分组的客户端存在
public boolean access(Pair<RankInfoCarrier, List<RankSpuItemProxy>> pair) {
    // 从配置中获取模型名称和分组名称
    String model_name = expConfig.getSemanticXxxRelevanceModelName();
    String groupName = expConfig.getSemanticXxxRelevanceGroupName();

    // 任何一个为空或客户端不存在则返回 false
    if (model_name == null || "".equals(model_name) ||
        groupName == null || "".equals(groupName) ||
        clientMap.get(groupName) == null) {
        return false;
    }
    return true;
}

2. build() - 构建推理请求

目的: 将输入数据转换为 Triton 模型可接受的格式

2.1 输入数据提取

RankInfoCarrier (排序信息载体)
├─ query: 搜索查询词
├─ queryIntentInfo: 查询意图(POI/SPU)
└─ xxxUnionLayerConfig: 模型配置

RankSpuItemProxy[] (商品列表)
├─ spuId: 商品 ID
├─ spuName: 商品名称
└─ poiName: 门店名称

2.2 文本分词与序列构建

关键方法: bpeTokenize()

输入序列结构 (MAX_SEQ_LENGTH = 64):

[CLS] + Query + [SEP] + [Prompt] + [SEP] + POI_Name + [SEP] + SPU_Name + [SEP] + [PAD]
 ↑                                                                                    ↑
 segment_id=0                                                                  segment_id=1

分段详解:

部分Segment ID说明
[CLS]0BERT 特殊标记
Query0搜索查询词(最多 16 个 token)
[SEP]0分隔符
Prompt0Soft Prompt(可选,3 个 token)
[SEP]0分隔符
POI Name1门店名称
[SEP]1分隔符
SPU Name1商品名称
[SEP]1分隔符
Padding0填充到 64 长度

Soft Prompt 选择逻辑:

if (expConfig.getXxxUseSoftPrompt()) {
    if (rankInfoCarrier.getQueryIntentInfo().isPoiIntent()) {
        prompt = softPrompt[0];  // ["[unused3]", "[unused4]"]
    } else {
        prompt = softPrompt[1];  // ["[unused1]", "[unused2]"]
    }
}

2.3 BERT 输入张量构建

三个输入张量:

1. input_ids (INT32)
   └─ Token ID 序列,shape: [batch_size, 64]

2. token_type_ids (INT32)
   └─ Segment ID,shape: [batch_size, 64]
   └─ 0 表示第一句,1 表示第二句

3. attention_mask (INT32)
   └─ 注意力掩码,shape: [batch_size, 64]
   └─ 1 表示有效 token,0 表示 padding

张量准备方法:

private GrpcService.ModelInferRequest.InferInputTensor.Builder prepareTensor(
    List<Integer> data, String dtype, String name, int dim0, int dim1) {
    // 创建张量内容
    GrpcService.InferTensorContents.Builder input_data =
        GrpcService.InferTensorContents.newBuilder();
    input_data.addAllIntContents(data);

    // 创建张量
    GrpcService.ModelInferRequest.InferInputTensor.Builder input =
        GrpcService.ModelInferRequest.InferInputTensor.newBuilder();
    input.setName(name);
    input.setDatatype(dtype);
    input.addShape(dim0);      // batch_size
    input.addShape(dim1);      // sequence_length
    input.setContents(input_data);

    return input;
}

2.4 完整 build() 流程

public GrpcService.ModelInferRequest build(Pair<...> pair) {
    // 1. 提取配置和数据
    String model_name = expConfig.getSemanticXxxRelevanceModelName();
    int batchSize = rankSpuItemProxyList.size();

    // 2. 分词查询
    List<String> queryTokens = tokenizer.tokenize(query);

    // 3. 确定 Soft Prompt
    String[] prompt = null;
    if (expConfig.getXxxUseSoftPrompt()) {
        prompt = rankInfoCarrier.getQueryIntentInfo().isPoiIntent()
            ? softPrompt[0] : softPrompt[1];
    }

    // 4. 为每个商品构建输入序列
    for (RankSpuItemProxy item : rankSpuItemProxyList) {
        List<List<Long>> idsList = bpeTokenize(queryTokens,
                                               item.getPoiName(),
                                               item.getSpuName(),
                                               prompt);
        inputIds.addAll(idsList.get(0));      // token IDs
        inputMasks.addAll(idsList.get(1));    // attention mask
        segmentIds.addAll(idsList.get(2));    // segment IDs
    }

    // 5. 构建三个输入张量
    request.addInputs(0, prepareTensor(inputIdsInt, "INT32", "input_ids", ...));
    request.addInputs(1, prepareTensor(segmentIdsInt, "INT32", "token_type_ids", ...));
    request.addInputs(2, prepareTensor(inputMasksInt, "INT32", "attention_mask", ...));

    // 6. 指定输出张量
    request.addOutputs(0, output0);  // name: "output"

    return request.build();
}

3. action() - 执行推理

目的: 调用 Triton 推理服务并处理响应

3.1 推理调用

public Integer action(Pair<...> pair, GrpcService.ModelInferRequest modelInferRequest) {
    try {
        // 1. 获取对应分组的客户端
        String groupName = expConfig.getSemanticXxxRelevanceGroupName();
        InferenceServerClient client = clientMap.get(groupName);

        // 2. 调用推理(同步 RPC)
        GrpcService.ModelInferResponse response = client.predict(modelInferRequest);

        // 3. 解析输出
        float[] probs = toArray(
            response.getRawOutputContentsList()
                   .get(0)
                   .asReadOnlyByteBuffer()
                   .order(ByteOrder.LITTLE_ENDIAN)
                   .asFloatBuffer()
        );

        // 4. 处理响应
        processResponse(rankInfoCarrier, rankSpuItemProxyList, probs);

        return ExecuteResult.SUCCESS;
    } catch (Exception e) {
        LOGGER.error("FAIL TO CALL MODEL: ", e);
        return ExecuteResult.FAILURE;
    }
}

3.2 输出解析

关键步骤:

  1. 获取原始字节数据

    response.getRawOutputContentsList().get(0)
    
  2. 转换为 ByteBuffer

    asReadOnlyByteBuffer()
    
  3. 设置字节序 (小端序)

    .order(ByteOrder.LITTLE_ENDIAN)
    
  4. 转换为 FloatBuffer

    .asFloatBuffer()
    
  5. 转换为 float 数组

    toArray(FloatBuffer b)
    

4. processResponse() - 处理推理结果

目的: 将模型输出分配给对应的商品

4.1 输出格式

模型输出: float[] probs
长度: batch_size * 3

每个商品对应 3 个分数:
├─ probs[3*i]     → 高相关分数 (High Relevance)
├─ probs[3*i+1]   → 低相关分数 (Low Relevance)
└─ probs[3*i+2]   → 不相关分数 (No Relevance)

4.2 处理逻辑

public void processResponse(RankInfoCarrier rankInfoCarrier,
                           List<RankSpuItemProxy> rankSpuItemProxyList,
                           float[] probs) {
    // 1. 验证输出长度
    if (probs.length != rankSpuItemProxyList.size() * 3) {
        LOGGER.warn("floatValList size not match spu size!");
        return;
    }

    // 2. 遍历商品列表,分配分数
    Map<Long, float[]> semanticProbMap = rankInfoCarrier.getXxxSemanticProbMap();
    for (int i = 0; i < rankSpuItemProxyList.size(); i++) {
        RankSpuItemProxy item = rankSpuItemProxyList.get(i);
        if (Objects.nonNull(item)) {
            int scoreIdx = 3 * i;

            // 设置三个相关性分数
            item.setHighRelevanceScore(probs[scoreIdx]);
            item.setLowRelevanceScore(probs[scoreIdx + 1]);
            item.setNoRelevanceScore(probs[scoreIdx + 2]);

            // 存储到 Map 中
            semanticProbMap.put(item.getSpuId(),
                Arrays.copyOfRange(probs, scoreIdx, scoreIdx + 3));
        }
    }
}

📊 数据流转示例

示例场景

输入:
├─ Query: "红油火锅"
├─ QueryIntent: POI Intent
└─ SPU 列表:
   ├─ SPU1: spuName="牛肉丸", poiName="火锅店A"
   └─ SPU2: spuName="鱼丸", poiName="火锅店B"

分词过程

Query 分词: ["红", "油", "火", "锅"]

SPU1 序列构建:
[CLS] + [红][油][火][锅] + [SEP] + [unused3][unused4] + [SEP]
+ [火][锅][店][A] + [SEP] + [牛][肉][丸] + [SEP] + [PAD]...

Segment IDs:
0      0      0      0      0      0      0      0      0      0
1      1      1      1      1      1      1      1      1      0

Attention Mask:
1      1      1      1      1      1      1      1      1      1
1      1      1      1      1      1      1      1      1      0

推理输出

模型输出 (float[]):
[0.85, 0.10, 0.05,    // SPU1 分数: 高相关 0.85, 低相关 0.10, 不相关 0.05
 0.72, 0.20, 0.08]    // SPU2 分数: 高相关 0.72, 低相关 0.20, 不相关 0.08

结果分配:
SPU1.highRelevanceScore = 0.85
SPU1.lowRelevanceScore = 0.10
SPU1.noRelevanceScore = 0.05

SPU2.highRelevanceScore = 0.72
SPU2.lowRelevanceScore = 0.20
SPU2.noRelevanceScore = 0.08

🔌 Triton 推理服务交互

通信协议

  • 协议: gRPC
  • 消息格式: Protocol Buffers (protobuf)
  • 客户端: InferenceServerClient

请求结构

ModelInferRequest {
    model_name: "xxx_spu_relevance_model"
    inputs: [
        {
            name: "input_ids"
            datatype: "INT32"
            shape: [batch_size, 64]
            contents: { int_contents: [...] }
        },
        {
            name: "token_type_ids"
            datatype: "INT32"
            shape: [batch_size, 64]
            contents: { int_contents: [...] }
        },
        {
            name: "attention_mask"
            datatype: "INT32"
            shape: [batch_size, 64]
            contents: { int_contents: [...] }
        }
    ]
    outputs: [
        { name: "output" }
    ]
}

响应结构

ModelInferResponse {
    outputs: [
        {
            name: "output"
            datatype: "FP32"
            shape: [batch_size, 3]
        }
    ]
    raw_output_contents: [
        <binary float data>
    ]
}

⏱️ 性能监控

监控埋点

埋点名称说明
xxx.relevance.model.tokenize分词耗时
xxx.relevance.model.build.request构建请求耗时
xxx.relevance.model.predict推理耗时
xxx.relevance.model.processResponse处理响应耗时
xxx.relevance.model.score.assign分数分配统计

超时配置

// 配置中心键: "[超时配置键]"
// 默认值: 200ms
int timeout = ConfigUtilAdapter.getInt("[timeout_config_key]", 200);

🛠️ 关键工具方法

toArray() - FloatBuffer 转数组

public static float[] toArray(FloatBuffer b) {
    if (b.hasArray()) {
        if (b.arrayOffset() == 0) return b.array();
        return Arrays.copyOfRange(b.array(), b.arrayOffset(), b.array().length);
    }
    b.rewind();
    float[] tmp = new float[b.remaining()];
    b.get(tmp);
    return tmp;
}

作用: 安全地将 FloatBuffer 转换为 float 数组,处理 offset 和 backing array 的情况

bpeTokenize() - 构建 BERT 输入序列

输入:

  • queryTokens: 查询分词结果
  • poiName: 门店名称
  • spuName: 商品名称
  • prompt: Soft Prompt(可选)

输出:

  • List<List<Long>>: 包含 [inputIds, inputMasks, segmentIds]

🔐 线程安全性

并发设计

// 使用 ConcurrentHashMap 存储客户端
private Map<String, InferenceServerClient> clientMap = Maps.newConcurrentMap();

动态配置更新

// 配置中心监听器中的线程安全操作
ConfigUtilAdapter.getMtConfigClient().addListener(
    CONFIG_KEY,
    (s, s1, s2) -> {
        // 使用迭代器安全删除
        Iterator<String> iterator = clientMap.keySet().iterator();
        while(iterator.hasNext()) {
            String key = iterator.next();
            if(!Arrays.stream(groupNameList).anyMatch(key::equals)) {
                iterator.remove();  // 安全删除
            }
        }
    }
);

📌 关键设计要点

1. Batch 处理

  • 支持批量推理,提高吞吐量
  • 所有商品共享同一个查询分词结果

2. Soft Prompt 机制

  • 根据查询意图选择不同的 prompt
  • 提升模型对不同场景的适应性

3. 多分组支持

  • 支持多个推理服务器分组(如 A30、T4 GPU)
  • 动态配置,无需重启

4. 三分类输出

  • 高相关、低相关、不相关三个分数
  • 为排序提供更细粒度的相关性信息

5. 错误处理

  • 前置检查(access)确保可执行性
  • 异常捕获和日志记录
  • 输出验证(长度检查)

🎯 总结

这个类是一个商品相关性排序模型调用服务,基于 BERT 语义模型,核心流程为:

数据验证 → 文本分词 → 序列构建 → 张量准备 → gRPC 推理 → 结果解析 → 分数分配

业务价值:

  • 提升搜索结果的相关性
  • 将用户查询与商品进行语义匹配
  • 为排序算法提供精准的相关性分数

技术特点:

  • 通过 Triton 推理服务器调用 BERT 模型
  • 支持动态配置、批量处理和多分组部署
  • 采用标准的 RPC 调用模式