longformer
当前transformer由于自注意力机制是对全序列的自注意力,所以导致的时间复杂度。因此,transformer的模型对序列的长度限制为512维。 而现实生活中的文档,往往词数会超过512个词,transformer的做法是将其拆成多个有部分重合的512维的样本,输入模型中训练。这显然对上下文信息获取,文档粒度的损失函数计算带来偏差。
为此作者收到CNN的启发提出一种滑动窗口注意力机制longformer,区别于transformer的每个词要注意序列中的512个词,longformer的词仅注意左右各窗口的词,即每个词的窗口为,在窗口内做自注意力。这样就将时间复杂度缩减为
核心代码逻辑-处理窗口
作者代码地址:github.com/allenai/lon…
本文是基于pytorch的代码,作者称pytorch的函数支持chunk操作,而tensorflow不支持。为此研究作者文中的窗口注意力如何运作一定要知道作者将tensor进行了何种数据结构的处理。
首先先看qk的注意力:作者将[bs,seqlen,dim]的序列特征利用chunk的形式转换为[bs,seqlen//w-1,2w,dim]的形式。所使用的方法为跳动从内存中读取向量的数值,也就是一个地址会读窗口数次:
# input;tensor q, tensor k, int w
bsz, seqlen, num_heads, head_dim = q.size()
assert seqlen % (w * 2) == 0
assert q.size() == k.size()
chunks_count = seqlen // w - 1
# group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size w * 2
q = q.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)
k = k.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)
chunk_q = _chunk(q, w)
chunk_k = _chunk(k, w)
# matrix multipication
# bcxd: bsz*num_heads x chunks x 2w x head_dim
# bcyd: bsz*num_heads x chunks x 2w x head_dim
# bcxy: bsz*num_heads x chunks x 2w x 2w
chunk_attn = torch.einsum('bcxd,bcyd->bcxy', (chunk_q, chunk_k)) # multiply
其中chunk方法的工作过逻辑如下图所示: 落到代码层面就是作者完美的想法。
def _chunk(x, w):
'''convert into overlapping chunkings. Chunk size = 2w, overlap size = w'''
# non-overlapping chunks of size = 2w
x = x.view(x.size(0), x.size(1) // (w * 2), w * 2, x.size(2))
# use `as_strided` to make the chunks overlap with an overlap size = w
chunk_size = list(x.size())
chunk_size[1] = chunk_size[1] * 2 - 1
chunk_stride = list(x.stride())
chunk_stride[1] = chunk_stride[1] // 2
return x.as_strided(size=chunk_size, stride=chunk_stride)
为什么从内存中读,而不是利用复制机制呢?因为从底层视角来看梯度更新实际是对底层内存的更新,同一个内存地址调用多次可以收到多次更新。而复制会导致梯度更新错误,作者这一招是真的高明。
在此技术上,处理序列特征就可以冗余出多个窗口。画个图举个例子,其中省略了bs跟dim,以方便理解:
代码运行: