模型输出的原始分数,是怎么变成一个个字符的?
完整流程
在字符级 RNN 生成文本时,整个过程分三步:
logits = model(h) # 第一步:输出原始分数
logits = logits / temperature # 第二步:temperature 调整
probs = torch.softmax(logits, dim=0) # 第三步:softmax 转概率
next_idx = torch.multinomial(probs, 1) # 第四步:按概率采样
next_c = dataset.idx_to_char[next_idx] # 第五步:索引转字符
这四行代码是文本生成的核心。下面逐一拆解。
第一步:logits —— 模型的原始输出
模型最后一步是一个线性层 h2o,输出是词表大小的原始分数(logits)。
假设词表有 5 个字符 ['a', 'b', 'c', 'd', 'e'],模型输出可能是:
logits = [3.0, -0.5, 1.8, 0.2, -1.0]
↑ ↑ ↑ ↑ ↑
a b c d e
这些分数本身没有明确含义:有正有负,加起来也不等于 1。不能直接当概率用。
第二步:softmax —— 把分数变成概率
softmax 的作用是把任意分数转成合法的概率分布:
- 每个值都在 (0, 1) 之间
- 所有值加起来 = 1
- 大小关系不变(原来大的还是大,原来小的还是小)
公式
softmax(x_i) = exp(x_i) / sum(exp(x_j))
先对每个元素取 e 的指数,再除以所有指数之和。
logits = [3.0, -0.5, 1.8, 0.2, -1.0]
# 取指数
exp(logits) = [20.1, 0.61, 6.05, 1.22, 0.37]
# 求和
sum = 20.1 + 0.61 + 6.05 + 1.22 + 0.37 = 28.35
# 归一化
probs = [20.1/28.35, 0.61/28.35, ...]
= [0.71, 0.02, 0.21, 0.04, 0.01]
验证:0.71 + 0.02 + 0.21 + 0.04 + 0.01 = 0.99 ≈ 1.0 ✓
现在字符 'a' 的概率是 71%,'c' 是 21%,其他都很低。模型大概率会选 'a'。
为什么要取指数
logits 可能有负数,没法直接当概率。取指数后:
- 负数变成 0~1 之间的小数
- 正数变成大于 1 的数
- 原来大的变得更大,原来小的变得更小
这放大了分数之间的差距,让概率分布更"有区分度"。
第三步:temperature —— 控制概率的"陡峭程度"
softmax 后的概率分布形态决定了生成的随机程度。temperature 在 softmax 之前调整 logits,从而控制这个形态。
temperature = 1.0(默认,不做调整)
logits: [3.0, -0.5, 1.8, 0.2, -1.0]
probs: [0.71, 0.02, 0.21, 0.04, 0.01]
'a' 的概率 71%,'c' 的概率 21%,差距明显。
temperature = 0.8(推荐默认值)
logits: [3.0, -0.5, 1.8, 0.2, -1.0]
÷ 0.8: [3.75, -0.63, 2.25, 0.25, -1.25] ← 差距略微放大
probs: [0.77, 0.01, 0.17, 0.02, 0.004] ← 更确定一些
除以 0.8(小于 1),分数差距被轻微放大。'a' 的概率从 71% 升到 77%,生成文本更通顺,同时保留了合理的随机性。
temperature = 2.0(更随机)
logits: [3.0, -0.5, 1.8, 0.2, -1.0]
÷ 2.0: [1.5, -0.25, 0.9, 0.1, -0.5] ← 差距缩小
probs: [0.45, 0.08, 0.25, 0.11, 0.06] ← 更均匀
除以一个大数,分数差距被缩小。'a' 的概率从 71% 降到 45%,低概率字符的机会增加了。生成文本更有变化,但可能出现不通顺的内容。
temperature = 0.3(极度确定)
logits: [3.0, -0.5, 1.8, 0.2, -1.0]
÷ 0.3: [10.0, -1.67, 6.0, 0.67, -3.33] ← 差距大幅放大
probs: [0.997, 0.0004, 0.002, 0.0004, 0] ← 几乎只选'a'
除以极小的数,分数差距被剧烈放大。'a' 的概率接近 100%,几乎一定会选它。生成文本极度通顺,但会大量重复。
为什么必须在 softmax 之前除
如果先 softmax 再除以 temperature:
probs = softmax([3.0, ...]) # [0.71, 0.02, ...]
probs / 2.0 = [0.36, 0.01, ...] ← 加起来 ≠ 1,不是合法概率!
而在 softmax 之前除,softmax 会自动重新归一化,保证结果合法。
temperature 调参指南
| 范围 | 效果 | 适用场景 |
|---|---|---|
| 0.2 ~ 0.5 | 保守,重复性高 | 需要稳定输出 |
| 0.7 ~ 0.9 | 平衡,合理又有变化 | 文本生成(推荐) |
| 1.0 ~ 1.5 | 有创意,偶尔出错 | 探索性生成 |
| > 2.0 | 太随机,基本乱字 | 不推荐 |
第四步:multinomial —— 按概率采样
有了概率分布,接下来就是"抽奖"——按概率随机选一个字符。
probs = [0.77, 0.01, 0.17, 0.02, 0.004] # temperature = 0.8
# 'a' 'b' 'c' 'd' 'e'
next_idx = torch.multinomial(probs, 1)
# 大概率返回 0(选'a'),偶尔返回 2(选'c'),很少返回其他
采样逻辑
multinomial 就像转盘抽奖:
╭─────────────────────────╮
╱ a 77% ╲
│ │
│ c 17% │ d 2% │
│ │
│ b 1% │ e 0.4% │
│ │
╲ ╱
╰───────────────────────╯
转盘上 'a' 占了绝大部分面积(77%),指针大概率转到 'a'。但偶尔也会转到 'c'(17%),极偶尔转到 'd'(2%)。
如果跑 100 次,大概会是:
- 'a':约 77 次
- 'c':约 17 次
- 'd':约 2 次
- 'b':约 1 次
- 'e':约 0 次
为什么不用 argmax
argmax 永远选概率最高的那个:
torch.argmax(probs) # 永远是 0('a')
生成的文本就是 'aaaaaaa...',完全重复。
multinomial 加入了随机性,让低概率字符也有机会出现,文本才有变化和可读性。
总结:四步协作
原始分数 调整陡峭程度 转概率 按概率随机选
┌───────────────┐ ┌────────────────┐ ┌───────────────┐ ┌────────────────┐
│ logits │─→│ temperature │─→│ softmax │─→│ multinomial │
├───────────────┤ ├────────────────┤ ├───────────────┤ ├────────────────┤
│ [3.0, -0.5, │ │ [3.75, -0.63, │ │ [0.77, 0.01, │ │ 0 ││ 1.8, 0.2, │ │ 2.25, 0.25, │ │ 0.17, 0.02, │ │ ('a') ││ -1.0] │ │ -1.25] │ │ 0.004] │ │ │
└───────────────┘ └────────────────┘ └───────────────┘ └────────────────┘
- logits:模型对每个字符的"喜好程度"打分
- temperature:调整分数差距,控制创造力
- softmax:把分数变成合法概率(0~1,和为 1)
- multinomial:按概率转盘抽奖,选出下一个字符
这四步缺一不可,共同完成了从"模型的猜测"到"生成的字符"的转换。