Mistral 7B 是小模型的典型例子。相对于 Llama,它使用了 GQA 与 SWA 两种技术改进。
grouped-query attention (GQA):讨论 GQA 之前,需要先讨论它的两种前身 multi-head attention (MHA) 和 multi-query attention(MQA)。嵌入数据矩阵经过 Q、K、V 矩阵乘积计算之后,会进行注意力头分配,对于 MHA,query、key、value 都会分成 个头,然后在各个头内进行注意力计算;对于 MQA,在 query、key、value 分成 个头之后,key 和 value 的头会进行均值池化计算,最后每个 query 头共享 key 和 value 向量,再进行注意力计算。而 GQA 则是这两种方法的折中,试图平衡模型能力与效率,它将 个头分为 ,组,每组具有一个 key 向量、一个 value 向量,以及 个 query 向量,然后组内进行注意力计算。GQA 还有一个好处在于它可以自定义组数量,拥有较大的灵活性。
(但是在这里我突然觉得,query 和 key 配对似乎会好一些)
sliding window attention (SWA):SWA 使用了容量为 的注意力窗口,而不是全局。虽然在attention计算中失去了远处的信息,但是由于层的堆叠,在注意力窗之外的 token,也还是能在若干层之后传到目标位置的 token。
Rolling Buffer Cache 是 SWA 产生的缓存受益。回顾 decoder 层与层之间的数据迁移,上一层不断的输出单个 token,原有的注意力机制,必须保存所有的 token,以便最后一个token的注意力计算;但对于 SWA 来说,由于窗口的大小为 ,对于前 个token,正常进入到缓存,但是第 个,第 个数据是不会再使用了的,因此可以将 个数据放在第 个的位置(keys and values)。
Pre-fill and Chunking 也是 SWA 产生的缓存受益,相比于 Rolling Buffer Cache,它指的是模型开头的处理。对于过长的 sequence,可以按照滑动窗口长度进行切块(chunking),这样一个块的计算,只需要考虑该块和它前面的一个块。
(我觉得这个没有特别大作用,但也许是我没实际训练过模型的原因我体会不到)