RQ-VAE:文本向量 → Semantic ID (SID) 全流程分析
一、整体架构概览
RQ-VAE(Residual Quantized Variational AutoEncoder)将高维连续的文本 Embedding 向量压缩为离散的多级码本索引,即 Semantic ID (SID)。
核心思想:通过 Encoder 降维 → 残差向量量化(多层 VQ 逐级逼近)→ Decoder 重建,以自监督方式训练,使离散编码尽可能保留原始语义信息。
模块组成
rqvae.py # 训练入口,参数配置、数据加载、启动训练
models/rqvae.py # RQVAE 模型:Encoder + RQ + Decoder
models/rq.py # ResidualVectorQuantizer:多层残差量化
models/vq.py # VectorQuantizer:单层向量量化(码本查找)
models/layers.py # MLPLayers、KMeans、Sinkhorn 等基础组件
datasets.py # EmbDataset:加载 .npy embedding 数据
trainer.py # 训练循环、评估(碰撞率)、checkpoint 管理
二、文本向量 → SID 转化流程(维度变化)
以默认配置为例:
in_dim = D(由 Qwen 模型决定,如 Qwen-1.8B 为 2048,Qwen-7B 为 4096)layers = [2048, 1024, 512, 256, 128, 64]e_dim = 32num_emb_list = [256, 256, 256](3 层 VQ,每层码本大小 256)
完整流程图
输入 Item Embedding
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ Encoder (MLPLayers) │
│ │
│ (batch, D) ─Linear→ (batch, 2048) ─ReLU→ │
│ ─Linear→ (batch, 1024) ─ReLU→ │
│ ─Linear→ (batch, 512) ─ReLU→ │
│ ─Linear→ (batch, 256) ─ReLU→ │
│ ─Linear→ (batch, 128) ─ReLU→ │
│ ─Linear→ (batch, 64) ─ReLU→ │
│ ─Linear→ (batch, 32) ← 最后一层无激活函数 │
│ │
│ 输出: x_e, shape = (batch, 32) │
└─────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ Residual Vector Quantization (3 层 VQ 残差量化) │
│ │
│ 初始: residual = x_e (batch, 32) │
│ x_q = 0 │
│ │
│ ┌─── VQ Layer 1 (码本: 256 × 32) ───────────────────────────┐ │
│ │ 输入: residual (batch, 32) │ │
│ │ 计算 L2 距离: d = ||r - c||² (batch, 256) │ │
│ │ 查找最近码字: indices_1 = argmin(d) (batch,) │ │
│ │ 码字查表: x_q_1 = Embed(indices) (batch, 32) │ │
│ │ 更新残差: residual = residual - x_q_1 │ │
│ │ 累加量化: x_q = x_q + x_q_1 │ │
│ └────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─── VQ Layer 2 (码本: 256 × 32) ───────────────────────────┐ │
│ │ 输入: residual (Layer1 的残差) (batch, 32) │ │
│ │ 计算 L2 距离: d = ||r - c||² (batch, 256) │ │
│ │ 查找最近码字: indices_2 = argmin(d) (batch,) │ │
│ │ 码字查表: x_q_2 = Embed(indices) (batch, 32) │ │
│ │ 更新残差: residual = residual - x_q_2 │ │
│ │ 累加量化: x_q = x_q + x_q_2 │ │
│ └────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─── VQ Layer 3 (码本: 256 × 32) ───────────────────────────┐ │
│ │ 输入: residual (Layer2 的残差) (batch, 32) │ │
│ │ 计算 L2 距离: d = ||r - c||² (batch, 256) │ │
│ │ 查找最近码字: indices_3 = argmin(d) (batch,) │ │
│ │ 码字查表: x_q_3 = Embed(indices) (batch, 32) │ │
│ │ 更新残差: residual = residual - x_q_3 │ │
│ │ 累加量化: x_q = x_q + x_q_3 │ │
│ └────────────────────────────────────────────────────────────┘ │
│ │
│ 输出: │
│ x_q = x_q_1 + x_q_2 + x_q_3 (batch, 32) │
│ indices = stack([idx1, idx2, idx3]) (batch, 3) ← 即 SID │
│ rq_loss = mean(loss_1, loss_2, loss_3) │
└──────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ Decoder (MLPLayers,Encoder 的镜像,输入为x_q) │
│ │
│ (batch, 32) ─Linear→ (batch, 64) ─ReLU→ │
│ ─Linear→ (batch, 128) ─ReLU→ │
│ ─Linear→ (batch, 256) ─ReLU→ │
│ ─Linear→ (batch, 512) ─ReLU→ │
│ ─Linear→ (batch, 1024) ─ReLU→ │
│ ─Linear→ (batch, 2048) ─ReLU→ │
│ ─Linear→ (batch, D) ← 重建原始维度 │
│ │
│ 输出: out, shape = (batch, D) │
└─────────────────────────────────────────────────────────────────┘
│
▼
SID 结果: (batch, 3)
每个 item 得到一个三元组 (c1, c2, c3),ci ∈ [0, 255]
例如: item_42 → SID = (17, 203, 89)
维度变化总结表
| 阶段 | 操作 | 输入维度 | 输出维度 |
|---|---|---|---|
| 输入 | 从 .npy 加载 | - | (batch, D) |
| Encoder | 7 层 MLP 逐步降维 | (batch, D) | (batch, 32) |
| RQ-VQ×3 | 残差量化,码本查找 | (batch, 32) | x_q: (batch, 32), indices: (batch, 3) |
| Decoder | 7 层 MLP 逐步升维 | (batch, 32) | (batch, D) |
| 最终 SID | 取 indices | - | (batch, 3),每个值 ∈ [0, 255] |
三、单层 VQ 量化详细过程
以 VQ Layer 1 为例(vq.py),输入 x 的形状为 (batch, 32):
Step0: 码本初始化
码本有两种初始化方式,由参数 kmeans_init 控制(默认为 True):
方式一:均匀随机初始化(kmeans_init=False)
直接用 uniform_(-1/n_e, 1/n_e) 小范围均匀分布初始化码本权重。
方式二:KMeans 延迟初始化(kmeans_init=True,默认)
构造时先将码本置零,标记 initted = False。在训练阶段的首次前向传播时,
对第一个 batch 的编码向量执行 KMeans 聚类(默认 100 次迭代),
将得到的 256 个聚类中心直接作为码本初始值:
首次 forward 触发:
latent (batch, 32) ──KMeans(n_clusters=256, max_iter=100)──→ centers (256, 32)
embedding.weight ← centers
initted = True (后续 forward 不再触发)
KMeans 初始化的优势:码本从一开始就分布在数据的实际区域中,比随机初始化收敛更快、码本利用率更高。
Step1: 计算 L2 距离
d(i,j) = ||x_i||² + ||c_j||² - 2 * x_i · c_j
x: (batch, 32),codebook: (256, 32)- 输出
d: (batch, 256),每个样本到 256 个码字的距离
Step2: 查找最近码字
- 无 Sinkhorn(
sk_epsilon <= 0):indices = argmin(d, dim=-1)→ (batch,) - 有 Sinkhorn(
sk_epsilon > 0): 先对距离做中心化归一化,再通过 Sinkhorn 迭代求软分配矩阵 Q,最后indices = argmax(Q, dim=-1)→ (batch,)
Sinkhorn 的作用:促进码本利用均匀化,避免码本坍塌(少数码字被频繁使用,其余码字浪费)。
什么是码本坍塌?
码本坍塌是向量量化中的经典问题:训练过程中只有少数码字被反复选中,其余大部分码字永远不会被匹配到。
理想情况(均匀利用): 坍塌情况:
码字 0: ████ (50个样本) 码字 0: ████████████ (800个样本)
码字 1: ████ (48个样本) 码字 1: ██████████ (600个样本)
码字 2: ████ (52个样本) 码字 2: █ (10个样本)
... ...
码字 255: ████ (49个样本) 码字 255: (0个样本) ← 死码字
恶性循环:被选中的码字持续获得梯度更新变得更好 → 吸引更多样本 → 其他码字永远得不到更新 → 码本容量严重浪费。
Sinkhorn 均衡化算法
Sinkhorn 将样本-码字分配建模为最优传输问题,通过行列交替归一化强制每个码字被近似均匀地分配到。
算法流程:
Step1: 距离中心化 → 将距离矩阵 d (B, K) 归一化到 [-1, 1],防止 exp 溢出
Step2: 构造软分配矩阵 Q = exp(-d / ε),距离越小 Q 值越大(越可能分配)
Step3: 交替归一化迭代:
重复 sk_iters 次:
① 行归一化: Q /= Q.sum(dim=1) / B → 约束每个样本的分配权重之和相等
② 列归一化: Q /= Q.sum(dim=0) / K → 约束每个码字被分配到的权重之和相等
Step4: Q *= B,取 indices = argmax(Q, dim=-1)
Sinkhorn 均衡化实例
场景:B=4 个样本,K=2 个码字
各样本到码字的距离矩阵 d (4, 2)(已中心化到 [-1, 1]):
code_0 code_1
sample_0: -0.8 0.8 ← 强烈偏好 code_0
sample_1: -0.4 0.4 ← 明显偏好 code_0
sample_2: -0.2 0.2 ← 微弱偏好 code_0(关键样本)
sample_3: 0.6 -0.6 ← 偏好 code_1
无 Sinkhorn(argmin 硬分配):
sample_0 → code_0, sample_1 → code_0, sample_2 → code_0, sample_3 → code_1
结果: code_0 获得 3 个样本, code_1 获得 1 个样本 → 不均衡!
有 Sinkhorn(ε=0.5):
Step1 — 构造 Q = exp(-d / 0.5),距离越小 Q 值越大:
code_0 code_1
sample_0: 4.953 0.202 ← 强烈倾向 code_0
sample_1: 2.226 0.449
sample_2: 1.492 0.670 ← 对 code_0 的倾向不强
sample_3: 0.301 3.320 ← 倾向 code_1
归一化使总和=1 后,观察列和:code_0 占 65.9%,code_1 仅 34.1% → 严重不均。
Step2 — 交替归一化迭代,观察 sample_2(偏好最弱的样本)的变化趋势:
sample_2 对 code_0 的权重 sample_2 对 code_1 的权重
初始: 0.690 0.310
第 1 轮后: 0.554 0.446
第 2 轮后: 0.487 0.513 ← 翻转! code_1 超过 code_0
收敛后: ~0.42 ~0.58
Step3 — 最终 Q × B,取 argmax:
code_0 code_1 argmax
sample_0: ~0.95 ~0.05 → code_0 ✓ (偏好最强,不变)
sample_1: ~0.75 ~0.25 → code_0 ✓ (偏好较强,不变)
sample_2: ~0.42 ~0.58 → code_1 ★ 被重新分配!
sample_3: ~0.05 ~0.95 → code_1 ✓ (不变)
结果: code_0 获得 2 个样本, code_1 获得 2 个样本 → 均衡!
对比总结:
无 Sinkhorn (argmin) 有 Sinkhorn (最优传输)
sample_0: → code_0 → code_0 (不变)
sample_1: → code_0 → code_0 (不变)
sample_2: → code_0 ← 跟风去了0 → code_1 ← 被均衡化到1
sample_3: → code_1 → code_1 (不变)
─────────────────────────────────────────────────────────────
code_0: 3 个样本 (75%) 2 个样本 (50%)
code_1: 1 个样本 (25%) 2 个样本 (50%)
核心直觉:sample_2 对 code_0 的偏好最弱(距离差最小),是"最不坚定"的样本。 Sinkhorn 通过全局均衡约束,把它"让渡"给了样本不足的 code_1,实现码本均匀利用。 偏好越强的样本(如 sample_0)越不容易被重新分配。
为什么行列交替归一化能实现均衡?
用一个 3×3 的例子展示。3 个样本都偏好 code_0,初始分配极不均衡:
初始 Q:
code_0 code_1 code_2 行和
sample_0: 0.90 0.05 0.05 1.00
sample_1: 0.80 0.10 0.10 1.00
sample_2: 0.70 0.15 0.15 1.00
──── ──── ────
列和: 2.40 0.30 0.30 ← code_0 占 80%,严重不均
目标列和: 1.00 1.00 1.00
第一步:列归一化(每列 ÷ 列和,强制列和相等)
列归一化后:
code_0 code_1 code_2 行和
sample_0: 0.90/2.40 0.05/0.30 0.05/0.30
= 0.375 0.167 0.167 0.709 ← 偏小
sample_1: 0.333 0.333 0.333 1.000
sample_2: 0.292 0.500 0.500 1.292 ← 偏大
───── ───── ─────
列和: 1.000 1.000 1.000 ← ✓ 列均衡了!
✗ 但行不等了
第二步:行归一化(每行 ÷ 行和,强制行和相等)
行归一化后:
code_0 code_1 code_2 行和
sample_0: 0.529 0.235 0.235 1.00 ← ✓
sample_1: 0.333 0.333 0.333 1.00 ← ✓
sample_2: 0.226 0.387 0.387 1.00 ← ✓
───── ───── ─────
列和: 1.088 0.956 0.956 ← 又不完全等了...
但对比列和的变化:
code_0 code_1 code_2 最大差距
初始: 2.40 0.30 0.30 2.10
第 1 轮后: 1.088 0.956 0.956 0.132 ← 差距缩小了 16 倍!
第 2 轮后: ≈1.01 ≈0.99 ≈0.99 ≈0.02
...
收敛: 1.000 1.000 1.000 0.000 ← 完全均衡
收敛原理:
┌──────────────────────────────────────────────────────────────────┐
│ │
│ 列归一化 行归一化 │
│ "每个码字获得等量权重" "每个样本分配等量权重" │
│ │ │ │
│ ▼ ▼ │
│ ✓ 列均衡了 ✓ 行均衡了 │
│ ✗ 行被打乱了 ✗ 列被打乱了 │
│ │ │ │
│ └────── 但打乱的程度越来越小 ─────┘ │
│ │
│ 每一轮 "修正一个约束时对另一个约束的破坏" 都在减小 │
│ → 最终收敛到同时满足两个约束的不动点(双随机矩阵) │
│ │
└──────────────────────────────────────────────────────────────────┘
数学上,这叫做交替投影到两个凸集的交集(Sinkhorn-Knopp 定理, 1967):
- 凸集 A = {所有行和相等的矩阵}
- 凸集 B = {所有列和相等的矩阵}
- 行归一化 = 投影到 A,列归一化 = 投影到 B
- 交替投影保证收敛到 A ∩ B(双随机矩阵)
归一化是乘除操作而非赋值,因此保留了同一行内的相对偏好比例, 只有偏好最弱的样本才会被重新分配,偏好最强的样本始终不受影响。
Step3: 码字查表 & 计算损失
x_q = embedding(indices) # (batch, 32),量化后的向量
codebook_loss = MSE(x_q, x.detach()) # 让码字靠近编码器输出
commitment_loss = MSE(x_q.detach(), x) # 让编码器输出靠近码字
loss = codebook_loss + β * commitment_loss
两个 loss 的梯度流向不同,通过 .detach() 截断实现:
codebook_loss = MSE(x_q, x.detach()):x.detach()截断编码器梯度,只更新码本,让码字向编码器输出靠拢commitment_loss = MSE(x_q.detach(), x):x_q.detach()截断码本梯度,只更新编码器,让编码器输出向码字靠拢
注意每层 VQ 的输入不同,是上一层的残差而非原始编码:
| 层 | 输入 x(即残差) | codebook_loss | commitment_loss |
|---|---|---|---|
| VQ 1 | r₀ = x_e | MSE(x_q₁, r₀.detach()) | MSE(x_q₁.detach(), r₀) |
| VQ 2 | r₁ = x_e - x_q₁ | MSE(x_q₂, r₁.detach()) | MSE(x_q₂.detach(), r₁) |
| VQ 3 | r₂ = r₁ - x_q₂ | MSE(x_q₃, r₂.detach()) | MSE(x_q₃.detach(), r₂) |
最终 rq_loss = mean(loss_1, loss_2, loss_3),三层取平均值后传入总损失函数。
Step4: Straight-Through Estimator (STE)
x_q = x + (x_q - x).detach()
前向传播使用量化后的 x_q,反向传播时梯度直通到 x(绕过不可导的 argmin 操作)。
四、残差量化(RQ)的核心逻辑
残差量化的关键在于逐层逼近:
x_e = 原始编码向量
x_q_1 = VQ1 对 x_e 的最近码字 → 一级粗粒度近似
r_1 = x_e - x_q_1 → 第一层残差
x_q_2 = VQ2 对 r_1 的最近码字 → 二级修正
r_2 = r_1 - x_q_2 → 第二层残差
x_q_3 = VQ3 对 r_2 的最近码字 → 三级精修
最终: x_q = x_q_1 + x_q_2 + x_q_3 ≈ x_e
这形成了一种树状层级结构:
- 第 1 级索引:粗粒度语义类别(256 个簇)
- 第 2 级索引:在第 1 级基础上的子类别细分
- 第 3 级索引:最精细的语义区分
SID 的层级性使其天然适合推荐系统中的 beam search 检索。
五、训练过程
5.1 总损失函数
L_total = L_recon + α * L_rq
- L_recon(重建损失):
MSE(Decoder(x_q), x_input),衡量重建质量 - L_rq(量化损失):3 层 VQ 损失的均值,每层包含:
codebook_loss:拉近码字到编码器输出β * commitment_loss:拉近编码器输出到码字
- α =
quant_loss_weight(默认 1.0) - β = 0.25
5.2 训练流程
for epoch in range(epochs):
┌─── 训练阶段 ──────────────────────────────────────────┐
│ for batch in DataLoader: │
│ 1. x = batch (batch, D) │
│ 2. out, rq_loss, indices = model(x) │
│ - x_e = Encoder(x) (batch, 32) │
│ - x_q, rq_loss, indices = RQ(x_e) │
│ - out = Decoder(x_q) (batch, D) │
│ 3. loss = MSE(out, x) + 1.0 * rq_loss │
│ 4. loss.backward() │
│ 5. clip_grad_norm_(params, 1.0) │
│ 6. optimizer.step() (AdamW, lr=1e-3) │
│ 7. scheduler.step() (constant + warmup) │
└────────────────────────────────────────────────────────┘
┌─── 评估阶段(每 eval_step 个 epoch)──────────────────┐
│ for batch in DataLoader: │
│ 1. indices = model.get_indices(batch) (batch, 3) │
│ 2. 将每个 indices 转成字符串 "c1-c2-c3" 加入集合 │
│ collision_rate = (总样本数 - 唯一SID数) / 总样本数 │
│ │
│ 保存 best_loss / best_collision_rate 的 checkpoint │
└────────────────────────────────────────────────────────┘
5.3 关键训练技巧
| 技巧 | 说明 |
|---|---|
| KMeans 初始化 | 首次前向时,用 KMeans 聚类初始化码本(而非随机),加速收敛 |
| Sinkhorn 均衡化 | 通过最优传输算法促进码本利用率均匀,减少码本坍塌(详见第三章 Step2) |
| STE 梯度直通 | 离散 argmin 不可导,通过 x + (x_q - x).detach() 传递梯度 |
| 梯度裁剪 | clip_grad_norm = 1.0,防止梯度爆炸 |
| Warmup | 前 50 个 epoch 学习率 warmup |
| 碰撞率评估 | 碰撞率越低,说明不同 item 被分配到不同 SID 的能力越强 |
5.3.1 STE 梯度直通原理
问题:VQ 中 indices = argmin(d) 是离散操作,输出为整数索引,梯度处处为零或不存在,
导致反向传播到此中断,Encoder 无法获得梯度、无法训练。
前向: Encoder 输出 x ──→ [argmin 查表, 不可导] ──→ x_q ──→ Decoder
反向: x ◄── 梯度断了! ✗
解法(vq.py 第 95 行):
x_q = x + (x_q - x).detach()
前向和反向的行为不同:
前向传播(计算值):
x_q = x + (x_q - x).detach()
= x + x_q - x
= x_q ← 使用量化后的码字,前向完全正确
反向传播(计算梯度):
.detach() 使 (x_q - x) 被视为常数,梯度为 0
∂x_q/∂x = ∂[x + 常数]/∂x = 1
即: ∂L/∂x = ∂L/∂x_q × 1 = ∂L/∂x_q
梯度直接从 x_q 复制到 x,跳过 argmin
效果:
前向: x ───→ [argmin+查表] ───→ x_q ───→ Decoder ───→ out
反向: x ◄════════════════════ x_q ◄─── Decoder ◄─── ∂L/∂out
梯度直通(STE)
合理性:x_q 是 x 的最近码字,两者足够接近时,Decoder 对 x_q 的梯度 与对 x 的梯度近似相等,直接复制是合理的一阶近似。
5.3.2 Learning Rate Warmup 原理
问题:训练刚开始时模型参数随机或刚经过 KMeans 初始化,此时:
- 梯度方向不稳定,可能指向不合理的方向
- 大学习率会把参数大幅更新到糟糕的区域
- AdamW 的自适应步长初期统计量不准确,容易产生过大更新
解法:前 50 个 epoch 线性升温学习率
lr
↑
│ ┌────────────────────────── 恒定 lr = 1e-3
│ ╱
│ ╱
│ ╱ 线性升温
│ ╱
│ ╱
│ ╱
│ ╱
│ 0 ╱
└──┴──────┴──────────────────────────────────→ epoch
0 50 10000
├──────┤
warmup 阶段
- warmup 阶段(epoch 0~50):lr 从 ~0 线性增长到 1e-3。 小学习率 → 小步探索 → 让参数稳定在合理区域, 同时 Adam 的一阶/二阶矩估计有足够时间积累准确的统计量。
- 恒定阶段(epoch 50~10000):lr 保持 1e-3 不变,正常训练。
对 RQ-VAE 的特殊意义:码本用 KMeans 延迟初始化(发生在第一个 batch), warmup 确保初始化后的前几十个 epoch 内,编码器和码本以小步互相适应, 避免初期大梯度把 KMeans 得到的良好初始化"冲毁"。
5.4 评估指标:碰撞率 (Collision Rate)
碰撞率 = 两个不同的 item 被映射到同一个 SID 的比例。
collision_rate = (N_total - N_unique_SID) / N_total
- 理想情况:每个 item 有唯一 SID,碰撞率 = 0
- 码本容量上限:256 × 256 × 256 = 16,777,216 个唯一 SID
六、推理阶段:获取 SID
训练完成后,通过 model.get_indices() 获取每个 item 的 SID:
# 推理(无梯度,不使用 Sinkhorn)
indices = model.get_indices(embeddings, use_sk=False)
# indices shape: (N, 3)
# 每行即一个 item 的 Semantic ID,如 [17, 203, 89]
该 SID 可用于推荐系统中的生成式检索(Generative Retrieval),模型通过自回归方式逐级生成 SID 来检索 item。
七. EasDeepRecommand个人推荐系统开源项目介绍
链接如下:EasyDeepRecommand
一个通俗易懂的开源推荐系统(A user-friendly open-source project for recommendation systems).
本项目将结合:代码、数据流转图、博客、模型发展史 等多个方面通俗易懂地讲解经典推荐模型,让读者通过一个项目了解推荐系统概况!
持续更新中..., 欢迎star🌟, 第一时间获取更新,感谢!!!