从分数到字符:文本生成中的 softmax、temperature 和 multinomial

6 阅读4分钟

模型输出的原始分数,是怎么变成一个个字符的?

完整流程

在字符级 RNN 生成文本时,整个过程分三步:

logits = model(h)                        # 第一步:输出原始分数
logits = logits / temperature            # 第二步:temperature 调整
probs = torch.softmax(logits, dim=0)     # 第三步:softmax 转概率
next_idx = torch.multinomial(probs, 1)   # 第四步:按概率采样
next_c = dataset.idx_to_char[next_idx]   # 第五步:索引转字符

这四行代码是文本生成的核心。下面逐一拆解。

第一步:logits —— 模型的原始输出

模型最后一步是一个线性层 h2o,输出是词表大小的原始分数(logits)。

假设词表有 5 个字符 ['a', 'b', 'c', 'd', 'e'],模型输出可能是:

logits = [3.0, -0.5, 1.8, 0.2, -1.0]
          ↑    ↑     ↑    ↑    ↑
          a    b     c    d    e

这些分数本身没有明确含义:有正有负,加起来也不等于 1。不能直接当概率用。

第二步:softmax —— 把分数变成概率

softmax 的作用是把任意分数转成合法的概率分布:

  1. 每个值都在 (0, 1) 之间
  2. 所有值加起来 = 1
  3. 大小关系不变(原来大的还是大,原来小的还是小)

公式

softmax(x_i) = exp(x_i) / sum(exp(x_j))

先对每个元素取 e 的指数,再除以所有指数之和。

logits = [3.0, -0.5, 1.8, 0.2, -1.0]

# 取指数
exp(logits) = [20.1, 0.61, 6.05, 1.22, 0.37]

# 求和
sum = 20.1 + 0.61 + 6.05 + 1.22 + 0.37 = 28.35

# 归一化
probs = [20.1/28.35, 0.61/28.35, ...]
      = [0.71, 0.02, 0.21, 0.04, 0.01]

验证:0.71 + 0.02 + 0.21 + 0.04 + 0.01 = 0.99 ≈ 1.0 ✓

现在字符 'a' 的概率是 71%,'c' 是 21%,其他都很低。模型大概率会选 'a'。

为什么要取指数

logits 可能有负数,没法直接当概率。取指数后:

  • 负数变成 0~1 之间的小数
  • 正数变成大于 1 的数
  • 原来大的变得更大,原来小的变得更小

这放大了分数之间的差距,让概率分布更"有区分度"。

第三步:temperature —— 控制概率的"陡峭程度"

softmax 后的概率分布形态决定了生成的随机程度。temperature 在 softmax 之前调整 logits,从而控制这个形态。

temperature = 1.0(默认,不做调整)

logits:     [3.0, -0.5, 1.8, 0.2, -1.0]
probs:      [0.71, 0.02, 0.21, 0.04, 0.01]

'a' 的概率 71%,'c' 的概率 21%,差距明显。

temperature = 0.8(推荐默认值)

logits:     [3.0, -0.5, 1.8, 0.2, -1.0]
÷ 0.8:      [3.75, -0.63, 2.25, 0.25, -1.25]  ← 差距略微放大
probs:      [0.77, 0.01, 0.17, 0.02, 0.004]   ← 更确定一些

除以 0.8(小于 1),分数差距被轻微放大。'a' 的概率从 71% 升到 77%,生成文本更通顺,同时保留了合理的随机性。

temperature = 2.0(更随机)

logits:     [3.0, -0.5, 1.8, 0.2, -1.0]
÷ 2.0:      [1.5, -0.25, 0.9, 0.1, -0.5]   ← 差距缩小
probs:      [0.45, 0.08, 0.25, 0.11, 0.06]  ← 更均匀

除以一个大数,分数差距被缩小。'a' 的概率从 71% 降到 45%,低概率字符的机会增加了。生成文本更有变化,但可能出现不通顺的内容。

temperature = 0.3(极度确定)

logits:     [3.0, -0.5, 1.8, 0.2, -1.0]
÷ 0.3:      [10.0, -1.67, 6.0, 0.67, -3.33]  ← 差距大幅放大
probs:      [0.997, 0.0004, 0.002, 0.0004, 0]  ← 几乎只选'a'

除以极小的数,分数差距被剧烈放大。'a' 的概率接近 100%,几乎一定会选它。生成文本极度通顺,但会大量重复。

为什么必须在 softmax 之前除

如果先 softmax 再除以 temperature:

probs = softmax([3.0, ...])  # [0.71, 0.02, ...]
probs / 2.0 = [0.36, 0.01, ...]  ← 加起来 ≠ 1,不是合法概率!

而在 softmax 之前除,softmax 会自动重新归一化,保证结果合法。

temperature 调参指南

范围效果适用场景
0.2 ~ 0.5保守,重复性高需要稳定输出
0.7 ~ 0.9平衡,合理又有变化文本生成(推荐)
1.0 ~ 1.5有创意,偶尔出错探索性生成
> 2.0太随机,基本乱字不推荐

第四步:multinomial —— 按概率采样

有了概率分布,接下来就是"抽奖"——按概率随机选一个字符。

probs = [0.77, 0.01, 0.17, 0.02, 0.004]  # temperature = 0.8
#        'a'   'b'   'c'   'd'   'e'

next_idx = torch.multinomial(probs, 1)
# 大概率返回 0(选'a'),偶尔返回 2(选'c'),很少返回其他

采样逻辑

multinomial 就像转盘抽奖:

         ╭─────────────────────────╮
        ╱        a  77%            ╲
       │                           │
       │    c  17%    │   d  2%    │
       │                           │
       │  b  1%   │   e  0.4%      │
       │                           │
        ╲                         ╱
         ╰───────────────────────╯

转盘上 'a' 占了绝大部分面积(77%),指针大概率转到 'a'。但偶尔也会转到 'c'(17%),极偶尔转到 'd'(2%)。

如果跑 100 次,大概会是:

  • 'a':约 77 次
  • 'c':约 17 次
  • 'd':约 2 次
  • 'b':约 1 次
  • 'e':约 0 次

为什么不用 argmax

argmax 永远选概率最高的那个:

torch.argmax(probs)  # 永远是 0('a')

生成的文本就是 'aaaaaaa...',完全重复。

multinomial 加入了随机性,让低概率字符也有机会出现,文本才有变化和可读性。

总结:四步协作

     原始分数          调整陡峭程度            转概率             按概率随机选
┌───────────────┐  ┌────────────────┐  ┌───────────────┐  ┌────────────────┐
│    logits     │─→│  temperature   │─→│    softmax    │─→│   multinomial  │
├───────────────┤  ├────────────────┤  ├───────────────┤  ├────────────────┤
│ [3.0, -0.5,   │  │ [3.75, -0.63,  │  │ [0.77, 0.01,  │  │       0        ││  1.8, 0.2,    │  │  2.25, 0.25,   │  │  0.17, 0.02,  │  │     ('a')      ││  -1.0]        │  │  -1.25]        │  │  0.004]       │  │                │
└───────────────┘  └────────────────┘  └───────────────┘  └────────────────┘
  • logits:模型对每个字符的"喜好程度"打分
  • temperature:调整分数差距,控制创造力
  • softmax:把分数变成合法概率(0~1,和为 1)
  • multinomial:按概率转盘抽奖,选出下一个字符

这四步缺一不可,共同完成了从"模型的猜测"到"生成的字符"的转换。