📋 类概述
类名: XxxSpuRelevanceTritonModelService
功能: 与 Triton 推理服务器交互,进行 SPU(商品)相关性语义模型推理
核心作用: 通过 BERT 模型对搜索查询和商品进行语义相关性评分,用于搜索结果排序
业务场景:
- 用户搜索 "红油火锅"
- 系统召回一批商品(SPU)
- 本服务计算每个商品与查询的相关性分数
- 排序模块根据相关性分数对商品重新排序
🔧 核心配置与初始化
关键常量
| 常量 | 值 | 说明 |
|---|---|---|
MAX_SEQ_LENGTH | 64 | BERT 模型输入序列最大长度 |
dictPath | dict/bert_vocab_zh.txt | 中文词汇表路径 |
REMOTE_APP_KEY | [远程推理服务标识] | 远程推理服务 AppKey |
LOCAL_APP_KEY | [本地应用标识] | 本地应用 AppKey |
RELEVANCE_GROUP_NAME_LION_KEY | [配置中心键名] | 配置中心键(推理服务器分组) |
groupNameList | ["a30", "t4"] | 默认推理服务器分组(GPU 类型) |
Soft Prompt 配置
softPrompt = {
{"[unused3]", "[unused4]"}, // POI Intent 用
{"[unused1]", "[unused2]"} // 其他 Intent 用
}
初始化流程 (@PostConstruct init())
1. 加载词汇表
├─ 读取 BERT 中文词汇表文件
├─ 构建 word -> token_id 映射
└─ 初始化 FullTokenizer
2. 初始化推理客户端
├─ 从配置中心读取分组名称
├─ 为每个分组创建 InferenceServerClient
└─ 存储在 clientMap 中
3. 动态配置监听
├─ 监听配置中心变化
├─ 新增分组时创建新客户端
└─ 删除分组时关闭旧客户端
🔄 模型交互流程
整体流程图
输入数据
↓
[access] 检查是否可执行
↓
[build] 构建推理请求
├─ 文本分词
├─ 构建 BERT 输入
└─ 生成 ModelInferRequest
↓
[action] 执行推理
├─ 调用 Triton 服务
├─ 获取模型输出
└─ 处理响应
↓
输出结果
📝 关键方法详解
1. access() - 前置检查
目的: 判断是否可以执行模型推理
检查项:
- ✅ 模型名称不为空
- ✅ 推理服务器分组名称不为空
- ✅ 对应分组的客户端存在
public boolean access(Pair<RankInfoCarrier, List<RankSpuItemProxy>> pair) {
// 从配置中获取模型名称和分组名称
String model_name = expConfig.getSemanticXxxRelevanceModelName();
String groupName = expConfig.getSemanticXxxRelevanceGroupName();
// 任何一个为空或客户端不存在则返回 false
if (model_name == null || "".equals(model_name) ||
groupName == null || "".equals(groupName) ||
clientMap.get(groupName) == null) {
return false;
}
return true;
}
2. build() - 构建推理请求
目的: 将输入数据转换为 Triton 模型可接受的格式
2.1 输入数据提取
RankInfoCarrier (排序信息载体)
├─ query: 搜索查询词
├─ queryIntentInfo: 查询意图(POI/SPU)
└─ xxxUnionLayerConfig: 模型配置
RankSpuItemProxy[] (商品列表)
├─ spuId: 商品 ID
├─ spuName: 商品名称
└─ poiName: 门店名称
2.2 文本分词与序列构建
关键方法: bpeTokenize()
输入序列结构 (MAX_SEQ_LENGTH = 64):
[CLS] + Query + [SEP] + [Prompt] + [SEP] + POI_Name + [SEP] + SPU_Name + [SEP] + [PAD]
↑ ↑
segment_id=0 segment_id=1
分段详解:
| 部分 | Segment ID | 说明 |
|---|---|---|
[CLS] | 0 | BERT 特殊标记 |
| Query | 0 | 搜索查询词(最多 16 个 token) |
[SEP] | 0 | 分隔符 |
| Prompt | 0 | Soft Prompt(可选,3 个 token) |
[SEP] | 0 | 分隔符 |
| POI Name | 1 | 门店名称 |
[SEP] | 1 | 分隔符 |
| SPU Name | 1 | 商品名称 |
[SEP] | 1 | 分隔符 |
| Padding | 0 | 填充到 64 长度 |
Soft Prompt 选择逻辑:
if (expConfig.getXxxUseSoftPrompt()) {
if (rankInfoCarrier.getQueryIntentInfo().isPoiIntent()) {
prompt = softPrompt[0]; // ["[unused3]", "[unused4]"]
} else {
prompt = softPrompt[1]; // ["[unused1]", "[unused2]"]
}
}
2.3 BERT 输入张量构建
三个输入张量:
1. input_ids (INT32)
└─ Token ID 序列,shape: [batch_size, 64]
2. token_type_ids (INT32)
└─ Segment ID,shape: [batch_size, 64]
└─ 0 表示第一句,1 表示第二句
3. attention_mask (INT32)
└─ 注意力掩码,shape: [batch_size, 64]
└─ 1 表示有效 token,0 表示 padding
张量准备方法:
private GrpcService.ModelInferRequest.InferInputTensor.Builder prepareTensor(
List<Integer> data, String dtype, String name, int dim0, int dim1) {
// 创建张量内容
GrpcService.InferTensorContents.Builder input_data =
GrpcService.InferTensorContents.newBuilder();
input_data.addAllIntContents(data);
// 创建张量
GrpcService.ModelInferRequest.InferInputTensor.Builder input =
GrpcService.ModelInferRequest.InferInputTensor.newBuilder();
input.setName(name);
input.setDatatype(dtype);
input.addShape(dim0); // batch_size
input.addShape(dim1); // sequence_length
input.setContents(input_data);
return input;
}
2.4 完整 build() 流程
public GrpcService.ModelInferRequest build(Pair<...> pair) {
// 1. 提取配置和数据
String model_name = expConfig.getSemanticXxxRelevanceModelName();
int batchSize = rankSpuItemProxyList.size();
// 2. 分词查询
List<String> queryTokens = tokenizer.tokenize(query);
// 3. 确定 Soft Prompt
String[] prompt = null;
if (expConfig.getXxxUseSoftPrompt()) {
prompt = rankInfoCarrier.getQueryIntentInfo().isPoiIntent()
? softPrompt[0] : softPrompt[1];
}
// 4. 为每个商品构建输入序列
for (RankSpuItemProxy item : rankSpuItemProxyList) {
List<List<Long>> idsList = bpeTokenize(queryTokens,
item.getPoiName(),
item.getSpuName(),
prompt);
inputIds.addAll(idsList.get(0)); // token IDs
inputMasks.addAll(idsList.get(1)); // attention mask
segmentIds.addAll(idsList.get(2)); // segment IDs
}
// 5. 构建三个输入张量
request.addInputs(0, prepareTensor(inputIdsInt, "INT32", "input_ids", ...));
request.addInputs(1, prepareTensor(segmentIdsInt, "INT32", "token_type_ids", ...));
request.addInputs(2, prepareTensor(inputMasksInt, "INT32", "attention_mask", ...));
// 6. 指定输出张量
request.addOutputs(0, output0); // name: "output"
return request.build();
}
3. action() - 执行推理
目的: 调用 Triton 推理服务并处理响应
3.1 推理调用
public Integer action(Pair<...> pair, GrpcService.ModelInferRequest modelInferRequest) {
try {
// 1. 获取对应分组的客户端
String groupName = expConfig.getSemanticXxxRelevanceGroupName();
InferenceServerClient client = clientMap.get(groupName);
// 2. 调用推理(同步 RPC)
GrpcService.ModelInferResponse response = client.predict(modelInferRequest);
// 3. 解析输出
float[] probs = toArray(
response.getRawOutputContentsList()
.get(0)
.asReadOnlyByteBuffer()
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer()
);
// 4. 处理响应
processResponse(rankInfoCarrier, rankSpuItemProxyList, probs);
return ExecuteResult.SUCCESS;
} catch (Exception e) {
LOGGER.error("FAIL TO CALL MODEL: ", e);
return ExecuteResult.FAILURE;
}
}
3.2 输出解析
关键步骤:
-
获取原始字节数据
response.getRawOutputContentsList().get(0) -
转换为 ByteBuffer
asReadOnlyByteBuffer() -
设置字节序 (小端序)
.order(ByteOrder.LITTLE_ENDIAN) -
转换为 FloatBuffer
.asFloatBuffer() -
转换为 float 数组
toArray(FloatBuffer b)
4. processResponse() - 处理推理结果
目的: 将模型输出分配给对应的商品
4.1 输出格式
模型输出: float[] probs
长度: batch_size * 3
每个商品对应 3 个分数:
├─ probs[3*i] → 高相关分数 (High Relevance)
├─ probs[3*i+1] → 低相关分数 (Low Relevance)
└─ probs[3*i+2] → 不相关分数 (No Relevance)
4.2 处理逻辑
public void processResponse(RankInfoCarrier rankInfoCarrier,
List<RankSpuItemProxy> rankSpuItemProxyList,
float[] probs) {
// 1. 验证输出长度
if (probs.length != rankSpuItemProxyList.size() * 3) {
LOGGER.warn("floatValList size not match spu size!");
return;
}
// 2. 遍历商品列表,分配分数
Map<Long, float[]> semanticProbMap = rankInfoCarrier.getXxxSemanticProbMap();
for (int i = 0; i < rankSpuItemProxyList.size(); i++) {
RankSpuItemProxy item = rankSpuItemProxyList.get(i);
if (Objects.nonNull(item)) {
int scoreIdx = 3 * i;
// 设置三个相关性分数
item.setHighRelevanceScore(probs[scoreIdx]);
item.setLowRelevanceScore(probs[scoreIdx + 1]);
item.setNoRelevanceScore(probs[scoreIdx + 2]);
// 存储到 Map 中
semanticProbMap.put(item.getSpuId(),
Arrays.copyOfRange(probs, scoreIdx, scoreIdx + 3));
}
}
}
📊 数据流转示例
示例场景
输入:
├─ Query: "红油火锅"
├─ QueryIntent: POI Intent
└─ SPU 列表:
├─ SPU1: spuName="牛肉丸", poiName="火锅店A"
└─ SPU2: spuName="鱼丸", poiName="火锅店B"
分词过程
Query 分词: ["红", "油", "火", "锅"]
SPU1 序列构建:
[CLS] + [红][油][火][锅] + [SEP] + [unused3][unused4] + [SEP]
+ [火][锅][店][A] + [SEP] + [牛][肉][丸] + [SEP] + [PAD]...
Segment IDs:
0 0 0 0 0 0 0 0 0 0
1 1 1 1 1 1 1 1 1 0
Attention Mask:
1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 0
推理输出
模型输出 (float[]):
[0.85, 0.10, 0.05, // SPU1 分数: 高相关 0.85, 低相关 0.10, 不相关 0.05
0.72, 0.20, 0.08] // SPU2 分数: 高相关 0.72, 低相关 0.20, 不相关 0.08
结果分配:
SPU1.highRelevanceScore = 0.85
SPU1.lowRelevanceScore = 0.10
SPU1.noRelevanceScore = 0.05
SPU2.highRelevanceScore = 0.72
SPU2.lowRelevanceScore = 0.20
SPU2.noRelevanceScore = 0.08
🔌 Triton 推理服务交互
通信协议
- 协议: gRPC
- 消息格式: Protocol Buffers (protobuf)
- 客户端:
InferenceServerClient
请求结构
ModelInferRequest {
model_name: "xxx_spu_relevance_model"
inputs: [
{
name: "input_ids"
datatype: "INT32"
shape: [batch_size, 64]
contents: { int_contents: [...] }
},
{
name: "token_type_ids"
datatype: "INT32"
shape: [batch_size, 64]
contents: { int_contents: [...] }
},
{
name: "attention_mask"
datatype: "INT32"
shape: [batch_size, 64]
contents: { int_contents: [...] }
}
]
outputs: [
{ name: "output" }
]
}
响应结构
ModelInferResponse {
outputs: [
{
name: "output"
datatype: "FP32"
shape: [batch_size, 3]
}
]
raw_output_contents: [
<binary float data>
]
}
⏱️ 性能监控
监控埋点
| 埋点名称 | 说明 |
|---|---|
xxx.relevance.model.tokenize | 分词耗时 |
xxx.relevance.model.build.request | 构建请求耗时 |
xxx.relevance.model.predict | 推理耗时 |
xxx.relevance.model.processResponse | 处理响应耗时 |
xxx.relevance.model.score.assign | 分数分配统计 |
超时配置
// 配置中心键: "[超时配置键]"
// 默认值: 200ms
int timeout = ConfigUtilAdapter.getInt("[timeout_config_key]", 200);
🛠️ 关键工具方法
toArray() - FloatBuffer 转数组
public static float[] toArray(FloatBuffer b) {
if (b.hasArray()) {
if (b.arrayOffset() == 0) return b.array();
return Arrays.copyOfRange(b.array(), b.arrayOffset(), b.array().length);
}
b.rewind();
float[] tmp = new float[b.remaining()];
b.get(tmp);
return tmp;
}
作用: 安全地将 FloatBuffer 转换为 float 数组,处理 offset 和 backing array 的情况
bpeTokenize() - 构建 BERT 输入序列
输入:
queryTokens: 查询分词结果poiName: 门店名称spuName: 商品名称prompt: Soft Prompt(可选)
输出:
List<List<Long>>: 包含 [inputIds, inputMasks, segmentIds]
🔐 线程安全性
并发设计
// 使用 ConcurrentHashMap 存储客户端
private Map<String, InferenceServerClient> clientMap = Maps.newConcurrentMap();
动态配置更新
// 配置中心监听器中的线程安全操作
ConfigUtilAdapter.getMtConfigClient().addListener(
CONFIG_KEY,
(s, s1, s2) -> {
// 使用迭代器安全删除
Iterator<String> iterator = clientMap.keySet().iterator();
while(iterator.hasNext()) {
String key = iterator.next();
if(!Arrays.stream(groupNameList).anyMatch(key::equals)) {
iterator.remove(); // 安全删除
}
}
}
);
📌 关键设计要点
1. Batch 处理
- 支持批量推理,提高吞吐量
- 所有商品共享同一个查询分词结果
2. Soft Prompt 机制
- 根据查询意图选择不同的 prompt
- 提升模型对不同场景的适应性
3. 多分组支持
- 支持多个推理服务器分组(如 A30、T4 GPU)
- 动态配置,无需重启
4. 三分类输出
- 高相关、低相关、不相关三个分数
- 为排序提供更细粒度的相关性信息
5. 错误处理
- 前置检查(access)确保可执行性
- 异常捕获和日志记录
- 输出验证(长度检查)
🎯 总结
这个类是一个商品相关性排序模型调用服务,基于 BERT 语义模型,核心流程为:
数据验证 → 文本分词 → 序列构建 → 张量准备 → gRPC 推理 → 结果解析 → 分数分配
业务价值:
- 提升搜索结果的相关性
- 将用户查询与商品进行语义匹配
- 为排序算法提供精准的相关性分数
技术特点:
- 通过 Triton 推理服务器调用 BERT 模型
- 支持动态配置、批量处理和多分组部署
- 采用标准的 RPC 调用模式