Time embeding,对时间进行编码?

914 阅读3分钟

上篇文章讲到了lantent diffusion model,有没有人注意到模型去噪过程中需要理解到时间T的信息呢?它就像我们看待墨水在清水中扩散一样,如果你能知道墨水已经扩散到了哪一个时间点,是不是就能更好地去预测之前墨水的扩散程度了呢。让我们一起看看在lantent diffusion model中是怎么运用这一思想的吧!🪂🪂🪂

Time embeding?position embeding!

在认识time embeding之前,必须先介绍一下它的前身:Transformer中的position embeding,相信这将有助于你更好地认识它。

Transformer是一种用于自然语言处理(NLP)和其他序列到序列任务的深度学习模型架构,这意味着它能很好地理解自然语言中每个词在句子中出现的位置,每个词的前后关系,而完成这一效果的板块正是position embeding。如果不添加位置编码,那么无论单词在什么位置,它的注意力分数都是确定的,这不是我们想要的。position embeding为了实现这一效果,为每个输入的词嵌入添加了一个向量,这样使得词的前后关系能更好地表现出来。同时,词与位置的编码是相加而不是拼接,因为两者虽然效率一样,但拼接会使维度上升,影响到模型结构。

Time embeding实现

def pos_encoding(t, channels):
    t = t.unsqueeze(-1).type(torch.float)  # 确保t是浮点类型并增加一个维度
    inv_freq = 1.0 / (
        10000
        ** (torch.arange(0, channels, 2, device=t.device).float() / channels)
    )
    t_expand = t.repeat(1, channels // 2) * inv_freq
    pos_enc_a = torch.sin(t_expand)
    pos_enc_b = torch.cos(t_expand)
    pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
    return pos_enc
​
# 为64个时间步生成位置编码
t1 = pos_encoding(torch.tensor([1] * 1).long(), channels=6)
t2 = pos_encoding(torch.tensor([2] * 1).long(), channels=6)
​
print(t1) #tensor([[0.8415, 0.0464, 0.0022, 0.5403, 0.9989, 1.0000]])
print(t2) #tensor([[ 0.9093,  0.0927,  0.0043, -0.4161,  0.9957,  1.0000]])

这个编码又称为正弦位置编码,常被用于Transformer模型,在这段代码中完成了对时间步1和时间步2生成的时间编码,一起来看看完成编码的步骤吧!

step1:首先将t增加一个维度,使其从一维变为二维(即 [[t], [t], [t]] 的形状

step2:获得一个向量(inv_freq)用于之后的编码计算,生成一个从0开始到 channels - 1 结束,步长为2的序列,用于索引每个维度的频率。这个序列被除以 channels,然后被提升到 1/10000 的幂,最后得到逆频率的向量

step3:将向量(inv_freq)与 t 相乘并分别进行计算正弦以及余弦值

step4:将得到的正弦以及余弦值拼接起来,因为得到的 t_expand 的长度为channels // 2

Time embeding为什么这么做?

有人会问为什么用这个公式进行编码呢?这时候就能体现出数学的优美性了,你会觉得一切都是恰到好处的。下面是我所看到的Time embeding优美之处,一起感受一下吧!

1)不同的T应有不同的embeding以区分时刻

2)邻近的T之间的embeding应相似

上面两个可能很显而易见,继续看下去🎈🎈🎈

3)添加Time embeding的合理性

  • 对于X预留有空间加入时序信息
  • 由于2)中的结论以及得到的编码之间的距离具有对称性
  • 通过线性变化在不同时间进行改变(时间复杂度低)

4)时间信息后续会经过一些激活函数可能会使得无法保持编码的对称性以及线性变换的性质?在扩散模型等生成模型中都会引入残差块,这会保证了一定的这些性质

5)为什么代入了sin,cos?由这两个计算方法可以把编码后的向量元素之间的距离具有对称性

虽然理解后可能觉得很简单,但是没理解的时候真的感觉挺难受的,希望这篇文章能让你更好地理解模型是如何记忆序列的吧