白话生成式推荐二:MiniOneRec之RQ-VAE

0 阅读14分钟

RQ-VAE:文本向量 → Semantic ID (SID) 全流程分析

一、整体架构概览

RQ-VAE(Residual Quantized Variational AutoEncoder)将高维连续的文本 Embedding 向量压缩为离散的多级码本索引,即 Semantic ID (SID)

核心思想:通过 Encoder 降维 → 残差向量量化(多层 VQ 逐级逼近)→ Decoder 重建,以自监督方式训练,使离散编码尽可能保留原始语义信息。

模块组成

rqvae.py          # 训练入口,参数配置、数据加载、启动训练
models/rqvae.py   # RQVAE 模型:Encoder + RQ + Decoder
models/rq.py      # ResidualVectorQuantizer:多层残差量化
models/vq.py      # VectorQuantizer:单层向量量化(码本查找)
models/layers.py  # MLPLayers、KMeans、Sinkhorn 等基础组件
datasets.py       # EmbDataset:加载 .npy embedding 数据
trainer.py        # 训练循环、评估(碰撞率)、checkpoint 管理

二、文本向量 → SID 转化流程(维度变化)

以默认配置为例:

  • in_dim = D(由 Qwen 模型决定,如 Qwen-1.8B 为 2048,Qwen-7B 为 4096)
  • layers = [2048, 1024, 512, 256, 128, 64]
  • e_dim = 32
  • num_emb_list = [256, 256, 256](3 层 VQ,每层码本大小 256)

完整流程图

输入 Item Embedding
        │
        ▼
┌─────────────────────────────────────────────────────────────────┐
│  Encoder (MLPLayers)                                            │
│                                                                 │
│  (batch, D) ─Linear→ (batch, 2048) ─ReLU→                      │
│              ─Linear→ (batch, 1024) ─ReLU→                      │
│              ─Linear→ (batch,  512) ─ReLU→                      │
│              ─Linear→ (batch,  256) ─ReLU→                      │
│              ─Linear→ (batch,  128) ─ReLU→                      │
│              ─Linear→ (batch,   64) ─ReLU→                      │
│              ─Linear→ (batch,   32)          ← 最后一层无激活函数  │
│                                                                 │
│  输出: x_e, shape = (batch, 32)                                  │
└─────────────────────────────────────────────────────────────────┘
        │
        ▼
┌─────────────────────────────────────────────────────────────────┐
│  Residual Vector Quantization (3 层 VQ 残差量化)                  │
│                                                                 │
│  初始: residual = x_e                    (batch, 32)             │
│        x_q = 0                                                  │
│                                                                 │
│  ┌─── VQ Layer 1 (码本: 256 × 32) ───────────────────────────┐  │
│  │ 输入: residual                        (batch, 32)          │  │
│  │ 计算 L2 距离: d = ||r - c||²          (batch, 256)         │  │
│  │ 查找最近码字:  indices_1 = argmin(d)   (batch,)            │  │
│  │ 码字查表:      x_q_1 = Embed(indices)  (batch, 32)         │  │
│  │ 更新残差:      residual = residual - x_q_1                 │  │
│  │ 累加量化:      x_q = x_q + x_q_1                           │  │
│  └────────────────────────────────────────────────────────────┘  │
│                          │                                       │
│                          ▼                                       │
│  ┌─── VQ Layer 2 (码本: 256 × 32) ───────────────────────────┐  │
│  │ 输入: residual (Layer1 的残差)         (batch, 32)          │  │
│  │ 计算 L2 距离: d = ||r - c||²          (batch, 256)         │  │
│  │ 查找最近码字:  indices_2 = argmin(d)   (batch,)            │  │
│  │ 码字查表:      x_q_2 = Embed(indices)  (batch, 32)         │  │
│  │ 更新残差:      residual = residual - x_q_2                 │  │
│  │ 累加量化:      x_q = x_q + x_q_2                           │  │
│  └────────────────────────────────────────────────────────────┘  │
│                          │                                       │
│                          ▼                                       │
│  ┌─── VQ Layer 3 (码本: 256 × 32) ───────────────────────────┐  │
│  │ 输入: residual (Layer2 的残差)         (batch, 32)          │  │
│  │ 计算 L2 距离: d = ||r - c||²          (batch, 256)         │  │
│  │ 查找最近码字:  indices_3 = argmin(d)   (batch,)            │  │
│  │ 码字查表:      x_q_3 = Embed(indices)  (batch, 32)         │  │
│  │ 更新残差:      residual = residual - x_q_3                 │  │
│  │ 累加量化:      x_q = x_q + x_q_3                           │  │
│  └────────────────────────────────────────────────────────────┘  │
│                                                                  │
│  输出:                                                           │
│    x_q     = x_q_1 + x_q_2 + x_q_3     (batch, 32)             │
│    indices = stack([idx1, idx2, idx3])   (batch, 3)  ← 即 SID   │
│    rq_loss = mean(loss_1, loss_2, loss_3)                        │
└──────────────────────────────────────────────────────────────────┘
        │
        ▼
┌─────────────────────────────────────────────────────────────────┐
│  Decoder (MLPLayers,Encoder 的镜像,输入为x_q)                    │
│                                                                 │
│  (batch, 32) ─Linear→ (batch,   64) ─ReLU→                      │
│              ─Linear→ (batch,  128) ─ReLU→                      │
│              ─Linear→ (batch,  256) ─ReLU→                      │
│              ─Linear→ (batch,  512) ─ReLU→                      │
│              ─Linear→ (batch, 1024) ─ReLU→                      │
│              ─Linear→ (batch, 2048) ─ReLU→                      │
│              ─Linear→ (batch,    D)          ← 重建原始维度       │
│                                                                 │
│  输出: out, shape = (batch, D)                                   │
└─────────────────────────────────────────────────────────────────┘
        │
        ▼
  SID 结果: (batch, 3)
  每个 item 得到一个三元组 (c1, c2, c3),ci ∈ [0, 255]
  例如: item_42 → SID = (17, 203, 89)

维度变化总结表

阶段操作输入维度输出维度
输入从 .npy 加载-(batch, D)
Encoder7 层 MLP 逐步降维(batch, D)(batch, 32)
RQ-VQ×3残差量化,码本查找(batch, 32)x_q: (batch, 32), indices: (batch, 3)
Decoder7 层 MLP 逐步升维(batch, 32)(batch, D)
最终 SID取 indices-(batch, 3),每个值 ∈ [0, 255]

三、单层 VQ 量化详细过程

以 VQ Layer 1 为例(vq.py),输入 x 的形状为 (batch, 32)

Step0: 码本初始化

码本有两种初始化方式,由参数 kmeans_init 控制(默认为 True):

方式一:均匀随机初始化(kmeans_init=False

直接用 uniform_(-1/n_e, 1/n_e) 小范围均匀分布初始化码本权重。

方式二:KMeans 延迟初始化(kmeans_init=True,默认)

构造时先将码本置零,标记 initted = False。在训练阶段的首次前向传播时, 对第一个 batch 的编码向量执行 KMeans 聚类(默认 100 次迭代), 将得到的 256 个聚类中心直接作为码本初始值:

首次 forward 触发:
  latent (batch, 32) ──KMeans(n_clusters=256, max_iter=100)──→ centers (256, 32)
  embedding.weight ← centers
  initted = True (后续 forward 不再触发)

KMeans 初始化的优势:码本从一开始就分布在数据的实际区域中,比随机初始化收敛更快、码本利用率更高。

Step1: 计算 L2 距离

d(i,j) = ||x_i||² + ||c_j||² - 2 * x_i · c_j
  • x: (batch, 32),codebook: (256, 32)
  • 输出 d: (batch, 256),每个样本到 256 个码字的距离

Step2: 查找最近码字

  • 无 Sinkhornsk_epsilon <= 0): indices = argmin(d, dim=-1) → (batch,)
  • 有 Sinkhornsk_epsilon > 0): 先对距离做中心化归一化,再通过 Sinkhorn 迭代求软分配矩阵 Q,最后 indices = argmax(Q, dim=-1) → (batch,)

Sinkhorn 的作用:促进码本利用均匀化,避免码本坍塌(少数码字被频繁使用,其余码字浪费)。

什么是码本坍塌?

码本坍塌是向量量化中的经典问题:训练过程中只有少数码字被反复选中,其余大部分码字永远不会被匹配到。

理想情况(均匀利用):           坍塌情况:
码字 0:  ████ (50个样本)        码字 0:  ████████████ (800个样本)
码字 1:  ████ (48个样本)        码字 1:  ██████████ (600个样本)
码字 2:  ████ (52个样本)        码字 2:   (10个样本)
...                             ...
码字 255: ████ (49个样本)       码字 255: (0个样本)  死码字

恶性循环:被选中的码字持续获得梯度更新变得更好 → 吸引更多样本 → 其他码字永远得不到更新 → 码本容量严重浪费。

Sinkhorn 均衡化算法

Sinkhorn 将样本-码字分配建模为最优传输问题,通过行列交替归一化强制每个码字被近似均匀地分配到。

算法流程

Step1: 距离中心化 → 将距离矩阵 d (B, K) 归一化到 [-1, 1],防止 exp 溢出
Step2: 构造软分配矩阵 Q = exp(-d / ε),距离越小 Q 值越大(越可能分配)
Step3: 交替归一化迭代:
       重复 sk_iters 次:
         ① 行归一化: Q /= Q.sum(dim=1) / B  → 约束每个样本的分配权重之和相等
         ② 列归一化: Q /= Q.sum(dim=0) / K  → 约束每个码字被分配到的权重之和相等
Step4: Q *= B,取 indices = argmax(Q, dim=-1)
Sinkhorn 均衡化实例

场景:B=4 个样本,K=2 个码字

各样本到码字的距离矩阵 d (4, 2)(已中心化到 [-1, 1]):

              code_0    code_1
sample_0:     -0.8       0.8     ← 强烈偏好 code_0
sample_1:     -0.4       0.4     ← 明显偏好 code_0
sample_2:     -0.2       0.2     ← 微弱偏好 code_0(关键样本)
sample_3:      0.6      -0.6     ← 偏好 code_1

无 Sinkhorn(argmin 硬分配)

sample_0 → code_0, sample_1 → code_0, sample_2 → code_0, sample_3 → code_1
结果: code_0 获得 3 个样本, code_1 获得 1 个样本 → 不均衡!

有 Sinkhorn(ε=0.5)

Step1 — 构造 Q = exp(-d / 0.5),距离越小 Q 值越大:

              code_0    code_1
sample_0:     4.953      0.202    ← 强烈倾向 code_0
sample_1:     2.226      0.449
sample_2:     1.492      0.670    ← 对 code_0 的倾向不强
sample_3:     0.301      3.320    ← 倾向 code_1

归一化使总和=1 后,观察列和:code_0 占 65.9%,code_1 仅 34.1% → 严重不均。

Step2 — 交替归一化迭代,观察 sample_2(偏好最弱的样本)的变化趋势:

              sample_2 对 code_0 的权重    sample_2 对 code_1 的权重
初始:              0.690                       0.310
第 1 轮后:         0.554                       0.446
第 2 轮后:         0.487                       0.513   ← 翻转! code_1 超过 code_0
收敛后:            ~0.42                       ~0.58

Step3 — 最终 Q × B,取 argmax:

              code_0    code_1    argmax
sample_0:     ~0.95      ~0.05    → code_0  ✓ (偏好最强,不变)
sample_1:     ~0.75      ~0.25    → code_0  ✓ (偏好较强,不变)
sample_2:     ~0.42      ~0.58    → code_1  ★ 被重新分配!
sample_3:     ~0.05      ~0.95    → code_1  ✓ (不变)

结果: code_0 获得 2 个样本, code_1 获得 2 个样本 → 均衡!

对比总结

           无 Sinkhorn (argmin)        有 Sinkhorn (最优传输)
sample_0:  → code_0                    → code_0  (不变)
sample_1:  → code_0                    → code_0  (不变)
sample_2:  → code_0  ← 跟风去了0       → code_1  ← 被均衡化到1
sample_3:  → code_1                    → code_1  (不变)
─────────────────────────────────────────────────────────────
code_0:    3 个样本 (75%)              2 个样本 (50%)
code_1:    1 个样本 (25%)              2 个样本 (50%)

核心直觉:sample_2 对 code_0 的偏好最弱(距离差最小),是"最不坚定"的样本。 Sinkhorn 通过全局均衡约束,把它"让渡"给了样本不足的 code_1,实现码本均匀利用。 偏好越强的样本(如 sample_0)越不容易被重新分配。

为什么行列交替归一化能实现均衡?

用一个 3×3 的例子展示。3 个样本都偏好 code_0,初始分配极不均衡:

初始 Q:
            code_0    code_1    code_2    行和
sample_0:   0.90      0.05      0.05      1.00
sample_1:   0.80      0.10      0.10      1.00
sample_2:   0.70      0.15      0.15      1.00
            ────      ────      ────
列和:        2.40      0.30      0.30     ← code_0 占 80%,严重不均
目标列和:    1.00      1.00      1.00

第一步:列归一化(每列 ÷ 列和,强制列和相等)

列归一化后:
            code_0        code_1        code_2        行和
sample_0:   0.90/2.40     0.05/0.30     0.05/0.30
          = 0.375         0.167         0.167         0.709  ← 偏小
sample_1:   0.333         0.333         0.333         1.000
sample_2:   0.292         0.500         0.500         1.292  ← 偏大
            ─────         ─────         ─────
列和:        1.000         1.000         1.000        ← ✓ 列均衡了!
                                                       ✗ 但行不等了

第二步:行归一化(每行 ÷ 行和,强制行和相等)

行归一化后:
            code_0        code_1        code_2        行和
sample_0:   0.529         0.235         0.235         1.00   ← ✓
sample_1:   0.333         0.333         0.333         1.00   ← ✓
sample_2:   0.226         0.387         0.387         1.00   ← ✓
            ─────         ─────         ─────
列和:        1.088         0.956         0.956        ← 又不完全等了...

但对比列和的变化

                  code_0    code_1    code_2    最大差距
初始:              2.40      0.30      0.30      2.10
第 1 轮后:         1.088     0.956     0.956     0.132   ← 差距缩小了 16 倍!
第 2 轮后:         ≈1.01     ≈0.99     ≈0.99     ≈0.02
...
收敛:              1.000     1.000     1.000     0.000   ← 完全均衡

收敛原理

┌──────────────────────────────────────────────────────────────────┐
│                                                                  │
│    列归一化                          行归一化                     │
│    "每个码字获得等量权重"             "每个样本分配等量权重"       │
│         │                                │                       │
│         ▼                                ▼                       │
│    ✓ 列均衡了                        ✓ 行均衡了                   │
│    ✗ 行被打乱了                      ✗ 列被打乱了                 │
│         │                                │                       │
│         └────── 但打乱的程度越来越小 ─────┘                       │
│                                                                  │
│    每一轮 "修正一个约束时对另一个约束的破坏" 都在减小               │
│    → 最终收敛到同时满足两个约束的不动点(双随机矩阵)              │
│                                                                  │
└──────────────────────────────────────────────────────────────────┘

数学上,这叫做交替投影到两个凸集的交集(Sinkhorn-Knopp 定理, 1967):

  • 凸集 A = {所有行和相等的矩阵}
  • 凸集 B = {所有列和相等的矩阵}
  • 行归一化 = 投影到 A,列归一化 = 投影到 B
  • 交替投影保证收敛到 A ∩ B(双随机矩阵)

归一化是乘除操作而非赋值,因此保留了同一行内的相对偏好比例, 只有偏好最弱的样本才会被重新分配,偏好最强的样本始终不受影响。

Step3: 码字查表 & 计算损失

x_q = embedding(indices)          # (batch, 32),量化后的向量
codebook_loss = MSE(x_q, x.detach())      # 让码字靠近编码器输出
commitment_loss = MSE(x_q.detach(), x)     # 让编码器输出靠近码字
loss = codebook_loss + β * commitment_loss

两个 loss 的梯度流向不同,通过 .detach() 截断实现:

  • codebook_loss = MSE(x_q, x.detach())x.detach() 截断编码器梯度,只更新码本,让码字向编码器输出靠拢
  • commitment_loss = MSE(x_q.detach(), x)x_q.detach() 截断码本梯度,只更新编码器,让编码器输出向码字靠拢

注意每层 VQ 的输入不同,是上一层的残差而非原始编码:

输入 x(即残差)codebook_losscommitment_loss
VQ 1r₀ = x_eMSE(x_q₁, r₀.detach())MSE(x_q₁.detach(), r₀)
VQ 2r₁ = x_e - x_q₁MSE(x_q₂, r₁.detach())MSE(x_q₂.detach(), r₁)
VQ 3r₂ = r₁ - x_q₂MSE(x_q₃, r₂.detach())MSE(x_q₃.detach(), r₂)

最终 rq_loss = mean(loss_1, loss_2, loss_3),三层取平均值后传入总损失函数。

Step4: Straight-Through Estimator (STE)

x_q = x + (x_q - x).detach()

前向传播使用量化后的 x_q,反向传播时梯度直通到 x(绕过不可导的 argmin 操作)。


四、残差量化(RQ)的核心逻辑

残差量化的关键在于逐层逼近

x_e  = 原始编码向量
x_q_1 = VQ1 对 x_e 的最近码字           → 一级粗粒度近似
r_1   = x_e - x_q_1                    → 第一层残差
x_q_2 = VQ2 对 r_1 的最近码字           → 二级修正
r_2   = r_1 - x_q_2                    → 第二层残差
x_q_3 = VQ3 对 r_2 的最近码字           → 三级精修

最终: x_q = x_q_1 + x_q_2 + x_q_3 ≈ x_e

这形成了一种树状层级结构

  • 第 1 级索引:粗粒度语义类别(256 个簇)
  • 第 2 级索引:在第 1 级基础上的子类别细分
  • 第 3 级索引:最精细的语义区分

SID 的层级性使其天然适合推荐系统中的 beam search 检索。


五、训练过程

5.1 总损失函数

L_total = L_recon + α * L_rq
  • L_recon(重建损失):MSE(Decoder(x_q), x_input),衡量重建质量
  • L_rq(量化损失):3 层 VQ 损失的均值,每层包含:
    • codebook_loss:拉近码字到编码器输出
    • β * commitment_loss:拉近编码器输出到码字
  • α = quant_loss_weight(默认 1.0)
  • β = 0.25

5.2 训练流程

for epoch in range(epochs):

    ┌─── 训练阶段 ──────────────────────────────────────────┐
    │ for batch in DataLoader:                               │
    │   1. x = batch                            (batch, D)   │
    │   2. out, rq_loss, indices = model(x)                  │
    │      - x_e = Encoder(x)                  (batch, 32)   │
    │      - x_q, rq_loss, indices = RQ(x_e)                 │
    │      - out = Decoder(x_q)                (batch, D)    │
    │   3. loss = MSE(out, x) + 1.0 * rq_loss               │
    │   4. loss.backward()                                   │
    │   5. clip_grad_norm_(params, 1.0)                      │
    │   6. optimizer.step() (AdamW, lr=1e-3)                 │
    │   7. scheduler.step() (constant + warmup)              │
    └────────────────────────────────────────────────────────┘

    ┌─── 评估阶段(每 eval_step 个 epoch)──────────────────┐
    │ for batch in DataLoader:                               │
    │   1. indices = model.get_indices(batch)   (batch, 3)   │
    │   2. 将每个 indices 转成字符串 "c1-c2-c3" 加入集合      │
    │ collision_rate = (总样本数 - 唯一SID数) / 总样本数      │
    │                                                        │
    │ 保存 best_loss / best_collision_rate 的 checkpoint      │
    └────────────────────────────────────────────────────────┘

5.3 关键训练技巧

技巧说明
KMeans 初始化首次前向时,用 KMeans 聚类初始化码本(而非随机),加速收敛
Sinkhorn 均衡化通过最优传输算法促进码本利用率均匀,减少码本坍塌(详见第三章 Step2)
STE 梯度直通离散 argmin 不可导,通过 x + (x_q - x).detach() 传递梯度
梯度裁剪clip_grad_norm = 1.0,防止梯度爆炸
Warmup前 50 个 epoch 学习率 warmup
碰撞率评估碰撞率越低,说明不同 item 被分配到不同 SID 的能力越强
5.3.1 STE 梯度直通原理

问题:VQ 中 indices = argmin(d) 是离散操作,输出为整数索引,梯度处处为零或不存在, 导致反向传播到此中断,Encoder 无法获得梯度、无法训练。

前向: Encoder 输出 x ──→ [argmin 查表, 不可导] ──→ x_q ──→ Decoder
反向:               x ◄── 梯度断了! ✗

解法vq.py 第 95 行):

x_q = x + (x_q - x).detach()

前向和反向的行为不同:

前向传播(计算值):
  x_q = x + (x_q - x).detach()
      = x + x_q - x
      = x_q                    ← 使用量化后的码字,前向完全正确

反向传播(计算梯度):
  .detach() 使 (x_q - x) 被视为常数,梯度为 0
  ∂x_q/∂x = ∂[x + 常数]/∂x = 1
  即: ∂L/∂x = ∂L/∂x_q × 1 = ∂L/∂x_q
              梯度直接从 x_q 复制到 x,跳过 argmin

效果:

前向:  x ───→ [argmin+查表] ───→ x_q ───→ Decoder ───→ out
反向:  x ◄════════════════════ x_q ◄─── Decoder ◄─── ∂L/∂out
              梯度直通(STE)

合理性:x_q 是 x 的最近码字,两者足够接近时,Decoder 对 x_q 的梯度 与对 x 的梯度近似相等,直接复制是合理的一阶近似。

5.3.2 Learning Rate Warmup 原理

问题:训练刚开始时模型参数随机或刚经过 KMeans 初始化,此时:

  • 梯度方向不稳定,可能指向不合理的方向
  • 大学习率会把参数大幅更新到糟糕的区域
  • AdamW 的自适应步长初期统计量不准确,容易产生过大更新

解法:前 50 个 epoch 线性升温学习率

lr
 ↑
 │                  ┌────────────────────────── 恒定 lr = 1e-3
 │                 ╱
 │                ╱
 │               ╱  线性升温
 │              ╱
 │             ╱
 │            ╱
 │           ╱
 │  0       ╱
 └──┴──────┴──────────────────────────────────→ epoch
    0      50                                  10000
    ├──────┤
    warmup 阶段
  • warmup 阶段(epoch 0~50):lr 从 ~0 线性增长到 1e-3。 小学习率 → 小步探索 → 让参数稳定在合理区域, 同时 Adam 的一阶/二阶矩估计有足够时间积累准确的统计量。
  • 恒定阶段(epoch 50~10000):lr 保持 1e-3 不变,正常训练。

对 RQ-VAE 的特殊意义:码本用 KMeans 延迟初始化(发生在第一个 batch), warmup 确保初始化后的前几十个 epoch 内,编码器和码本以小步互相适应, 避免初期大梯度把 KMeans 得到的良好初始化"冲毁"。

5.4 评估指标:碰撞率 (Collision Rate)

碰撞率 = 两个不同的 item 被映射到同一个 SID 的比例。

collision_rate = (N_total - N_unique_SID) / N_total
  • 理想情况:每个 item 有唯一 SID,碰撞率 = 0
  • 码本容量上限:256 × 256 × 256 = 16,777,216 个唯一 SID

六、推理阶段:获取 SID

训练完成后,通过 model.get_indices() 获取每个 item 的 SID:

# 推理(无梯度,不使用 Sinkhorn)
indices = model.get_indices(embeddings, use_sk=False)
# indices shape: (N, 3)
# 每行即一个 item 的 Semantic ID,如 [17, 203, 89]

该 SID 可用于推荐系统中的生成式检索(Generative Retrieval),模型通过自回归方式逐级生成 SID 来检索 item。

七. EasDeepRecommand个人推荐系统开源项目介绍

在这里插入图片描述

链接如下:EasyDeepRecommand

一个通俗易懂的开源推荐系统(A user-friendly open-source project for recommendation systems).

本项目将结合:代码、数据流转图、博客、模型发展史 等多个方面通俗易懂地讲解经典推荐模型,让读者通过一个项目了解推荐系统概况!

持续更新中..., 欢迎star🌟, 第一时间获取更新,感谢!!!