longformer的pytorch代码解读(更新中)

1,768 阅读2分钟

longformer

当前transformer由于自注意力机制是对全序列的自注意力,所以导致O(n2)O(n^2)的时间复杂度。因此,transformer的模型对序列的长度限制为512维。 而现实生活中的文档,往往词数会超过512个词,transformer的做法是将其拆成多个有部分重合的512维的样本,输入模型中训练。这显然对上下文信息获取,文档粒度的损失函数计算带来偏差。

为此作者收到CNN的启发提出一种滑动窗口注意力机制longformer,区别于transformer的每个词要注意序列中的512个词,longformer的词仅注意左右各ww窗口的词,即每个词的窗口为2w+12w+1,在窗口内做自注意力。这样就将时间复杂度缩减为O(wn)O(wn)

核心代码逻辑-处理窗口

作者代码地址:github.com/allenai/lon…

本文是基于pytorch的代码,作者称pytorch的函数支持chunk操作,而tensorflow不支持。为此研究作者文中的窗口注意力如何运作一定要知道作者将tensor进行了何种数据结构的处理。

首先先看qk的注意力:作者将[bs,seqlen,dim]的序列特征利用chunk的形式转换为[bs,seqlen//w-1,2w,dim]的形式。所使用的方法为跳动从内存中读取向量的数值,也就是一个地址会读窗口数ww次:

# input;tensor q, tensor k, int w
bsz, seqlen, num_heads, head_dim = q.size()
assert seqlen % (w * 2) == 0
assert q.size() == k.size()

chunks_count = seqlen // w - 1

# group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size w * 2
q = q.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)
k = k.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)

chunk_q = _chunk(q, w)
chunk_k = _chunk(k, w)

# matrix multipication
# bcxd: bsz*num_heads x chunks x 2w x head_dim
# bcyd: bsz*num_heads x chunks x 2w x head_dim
# bcxy: bsz*num_heads x chunks x 2w x 2w
chunk_attn = torch.einsum('bcxd,bcyd->bcxy', (chunk_q, chunk_k))  # multiply

其中chunk方法的工作过逻辑如下图所示: image.png 落到代码层面就是作者完美的想法。

def _chunk(x, w):
    '''convert into overlapping chunkings. Chunk size = 2w, overlap size = w'''

    # non-overlapping chunks of size = 2w
    x = x.view(x.size(0), x.size(1) // (w * 2), w * 2, x.size(2))

    # use `as_strided` to make the chunks overlap with an overlap size = w
    chunk_size = list(x.size())
    chunk_size[1] = chunk_size[1] * 2 - 1

    chunk_stride = list(x.stride())
    chunk_stride[1] = chunk_stride[1] // 2
    return x.as_strided(size=chunk_size, stride=chunk_stride)

为什么从内存中读,而不是利用复制机制呢?因为从底层视角来看梯度更新实际是对底层内存的更新,同一个内存地址调用多次可以收到多次更新。而复制会导致梯度更新错误,作者这一招是真的高明。

在此技术上,处理序列特征就可以冗余出多个窗口。画个图举个例子,其中省略了bs跟dim,以方便理解:

image.png 代码运行:

image.png