Mistral 7B

269 阅读2分钟

Mistral 7B 是小模型的典型例子。相对于 Llama,它使用了 GQA 与 SWA 两种技术改进。

grouped-query attention (GQA):讨论 GQA 之前,需要先讨论它的两种前身 multi-head attention (MHA) 和 multi-query attention(MQA)。嵌入数据矩阵经过 Q、K、V 矩阵乘积计算之后,会进行注意力头分配,对于 MHA,query、key、value 都会分成 hh 个头,然后在各个头内进行注意力计算;对于 MQA,在 query、key、value 分成 hh 个头之后,key 和 value 的头会进行均值池化计算,最后每个 query 头共享 key 和 value 向量,再进行注意力计算。而 GQA 则是这两种方法的折中,试图平衡模型能力与效率,它将 hh 个头分为 gg,组,每组具有一个 key 向量、一个 value 向量,以及hg\frac{h}{g} 个 query 向量,然后组内进行注意力计算。GQA 还有一个好处在于它可以自定义组数量,拥有较大的灵活性。
(但是在这里我突然觉得,query 和 key 配对似乎会好一些)

Screenshot 2024-02-02 at 18.37.07.png

sliding window attention (SWA):SWA 使用了容量为 WW 的注意力窗口,而不是全局。虽然在attention计算中失去了远处的信息,但是由于层的堆叠,在注意力窗之外的 token,也还是能在若干层之后传到目标位置的 token。

Rolling Buffer Cache 是 SWA 产生的缓存受益。回顾 decoder 层与层之间的数据迁移,上一层不断的输出单个 token,原有的注意力机制,必须保存所有的 token,以便最后一个token的注意力计算;但对于 SWA 来说,由于窗口的大小为 WW,对于前 WW 个token,正常进入到缓存,但是第 W+1W+1 个,第 11 个数据是不会再使用了的,因此可以将 W+1W+1 个数据放在第 11 个的位置(keys and values)。

Screenshot 2024-02-02 at 15.28.18.png

Pre-fill and Chunking 也是 SWA 产生的缓存受益,相比于 Rolling Buffer Cache,它指的是模型开头的处理。对于过长的 sequence,可以按照滑动窗口长度进行切块(chunking),这样一个块的计算,只需要考虑该块和它前面的一个块。
(我觉得这个没有特别大作用,但也许是我没实际训练过模型的原因我体会不到)

Screenshot 2024-02-02 at 15.37.58.png