【本文正在参加金石计划附加挑战赛——第三期命题】
用 Rust 编写 LLM: 寻找高效的矩阵乘法
以下是我学到的经验, 以及我是如何编写 llm.rust 和解决矩阵乘法问题的
从 Karpathy llm.c 开始, 我就在想"我能用 Rust 写这个吗?" 以下是我学到的经验教训, 以及我是如何编写 llm.rust 的. 在这一篇文章中, 让我们来解决矩阵乘法问题.
矩阵乘法可能是机器学习中最重要的运算. 我还记得, 当我还是一名工科学生时, 在第一堂线性代数课上, 老师开始讲解矩阵, 特征向量, 基和正态基. 我当时非常困惑, 脑子过了一会儿才开始明白为什么我们要这么费劲地讲矩阵和基集, 以及一个好的基对我们的世界意味着什么. 从那时起, 我一直觉得线性代数是如此迷人, 而且, 从纯计算机科学的角度来看, 那些试图越来越高效地处理矩阵的算法是多么令人惊叹.
特别是, 我们知道矩阵-向量乘积非常简单, 但当我们有矩阵-矩阵或张量-张量乘积时, 情况就变得越来越复杂了. 从这里开始, 许多方法都被用来优化矩阵乘法. 这个问题仍然让我非常着迷, 看到Karpathy的llm.c, 我感到非常有趣和高兴.
事实上, attention 算法--以及所有 ML 算法--的核心部分当然是 矩阵乘法. 在我的项目中, 我从 Karpathy 软件仓库的一个早期提交开始(这里是 矩阵乘法). 大部分时间都花在了这个函数上, 因此优化这个计算无疑会帮助我们降低 LLM 的培训成本. 公式 1 显示了我们在 LLM 中处理的公式:
公式 1: attention矩阵乘法公式.
我们有一个输出张量out, 其维度为B, 即批次索引, 定义范围为0至B-1; 时间步长t, 定义范围为0至T-1; 输出通道o, 定义范围为0至OC-1. 在attention机制中, 矩阵乘法在 Q, K 和 V 计算中发挥作用. 给定一个嵌入输入X, 可通过线性变换将嵌入投射到查询Q, 键K和值V向量中:
公式 2: attention算法中的 Q, K, V
其中, W 表示查询(下划线 Q), 键(下划线 K)和值(下划线 V)权重, b 是相关偏差.
同样, 矩阵乘法也出现在反向传播步骤中, 我们在这里运行的是反向矩阵乘法. 后向矩阵乘法计算与输入, 权重和偏置相关的梯度, 返回与输出相关的损失梯度.
公式 3: 后向矩阵. 首先, 我们计算相对于输入的损失梯度(dinp); 其次, 计算相对于权重的损失梯度; 最后, 计算相对于任何偏置的损失梯度.
公式 3 概括了后向矩阵乘法. dinp是相对于输入嵌入的损失梯度, 即公式 1 中的inp. 该公式通过累加输出梯度与相应权重的乘积来更新 dinp. 然后, 我们计算相对于权重的损失梯度, 累加输出梯度与相应输入梯度的乘积. 最后, 如果存在任何偏差, 我们将计算相对于偏差的损失梯度, 将所有批次 B 的输出梯度相加, 并将每个输出通道 OC 的时间步骤 T 相加.
有了这段令人惊叹的代码, 我想我是否可以用 Rust 做一些类似的事情, 以帮助我越来越多地学习这种编程语言, 并尝试在我的 MacBook 上实现某种训练. 本文涉及的代码都可以在此处**找到. 请注意, *代码正在编写中, 因此可能会逐日更改.
本文并不想比较实现速度, 因为这取决于多个变量(我们可以使用 GPU, 数据分片, 矢量化, 量化, 解析). 我想做的是找到在我的 Rust LLM 实现中使用的最佳方法, 并尝试在我的 MacBook M2 上运行我的代码来训练 LLM.
我对 Rust 的选择
如果你赶时间, 以下是我选择的 Rust 最佳实现, 可以在 MacBook M2 Pro 上运行类似 GPT-2 的 LLM 训练.
表 1 以秒为单位比较了在 8 线程上运行 OpenMP 的 C 语言实现(C OpenMp), Rust 的基本实现(Rust base), 使用 Rayon 的 Rust 实现(Rust Rayon)和 Blas 的 Rust 实现(Rust Blas)之间的平均性能时间. 输入维度为 B = 64, T = 1024, C = 768, OC = 768, 对应的输入和输出张量大小为 50'331'648 个元素.
总体而言, 正如预期的那样, Blas 执行正向矩阵乘法的平均时间为 0.05 s. 同样, 在 Rust 版中, Blas 的后向矩阵乘法运算速度最快, 为 0.19s.
我还尝试将这两项计算推向极限, 将批量大小从 4 修改为 128, 同样将时间步长从 64 增加到 2048, 将通道和输出通道从 48 增加到 1536. 这意味着输入和输出张量从 12饸元素增加到 402髥餐元素. 图 1 和图 2 以对数标度表示这些输入值的 Matmul 正向和反向性能. 对于 matmul 前向操作, 我们从平均一微秒到最大 0.58 +/- 0.01 s. 同样, 对于后向运算, 我们的平均时间从一微秒到 2.54 +/- 0.05 s. 由此得出的结论是, Blas 经过高度优化, 可以处理非常大的矩阵. 事实上, 在非常小的尺度(B = 4)上, 范围差异很大, 从 1.20 毫秒到 0.4 毫秒不等.
图 1: Rust BLAS matmul 正向性能与批量大小 B(从 4 到 128)的对数图. 批量大小为 4 意味着输入和输出张量大小为 12饸, 批量大小为 128 意味着输入/输出张量大小为 402髥餐.
图 2: Rust BLAS matmul 反向性能与批量大小 B(从 4 到 128)的对数图. 批量大小为 4 意味着输入和输出张量大小为 12饸, 批量大小为 128 意味着输入/输出张量大小为 402髥餐.
C 语言中的简单矩阵乘法
我知道很多人可能对 C 和 C++ 有过敏反应, 但请听我说, 在本例中, 我们简化了很多问题, 并尝试使用 OpenMP 实现矩阵乘法--请记住, 实现过程遵循公式 1, 这里是 C 代码
void matmul_forward(float* out,
float* inp,
float* weight,
float* bias,
int B, int T, int C, int OC) {
#pragma omp parallel for collapse(2)
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
float* out_bt = out + b * T * OC + t * OC;
float* inp_bt = inp + b * T * C + t * C;
for (int o = 0; o < OC; o++) {
float val = (bias != NULL) ? bias[o] : 0.0f;
float* wrow = weight + o * C;
for (int i = 0; i < C; i++) {
val += inp_bt[i] * wrow[i];
}
out_bt[o] = val;
}
}
}
}
让我们看看这段代码中发生了什么:
- 起点是调用 openMP 并行性:
#pragma omp parallel for collapse(2)omp parallel for是一个指令, 它结合了omp parallel和omp for指令. 它定义了一个有并行 for 并必须并行运行的区域. collapse(2)指示编译器将一些嵌套循环折叠成一个大的迭代. 通常,collapse`创建的单个循环的迭代次数至少比原始嵌套循环多两个数量级. - 然后, 我们要做一些"奇怪"的事情, 比如
float* out_bt = out + b*T*OC + t*OC;这是 C 语言中的指针运算, 即计算访问元素的正确索引. 在这里, 我们计算的是当前批次和时间步长的起点, 因此后面的索引都是相对于这个位置的. 此外, 这还允许我们对多维输入进行矢量化, 因此我们将多维输入扁平化为一维数组, 以提高性能. 例如, 在这里float* out_bt = out + b*T*OC + t*OC我们处理的是张量out. 这个张量的维数为B x T x OC. 偏移计算执行以下操作: 1) 使用b*T*OC移动到批次b中; 2) 使用t*OC移动到批次b中的时间步长t. - 为了进一步理解指针运算, 请考虑这种情况: b = 2, t = 3, c = 4, oc = 5. 要访问批次
1, 时间步长2, 输入通道3的输入数据inp, 我们可以计算 1) 批量偏移量b*T*C = 1*3*4 = 12; 2) 时间步偏移量t*C = 2*4 = 8; 3) 总偏移量12+8 = 20. 在最后一个循环中, 我们迭代索引i, 对于输入i=3, 总偏移量等于23. 因此,input[23]对应于输入input[1][2][3].
有一点需要注意, 如果你在 MacOS 上运行, 可能需要安装 llvm (所以 brew install llvm )并导出路径. 下面是我编译和运行代码的过程:
#!/bin/bash
export OMP_NUM_THREADS=4
export LDFLAGS="-L/opt/homebrew/opt/llvm/lib"
export CPPFLAGS="-I/opt/homebrew/opt/llvm/include"
/opt/homebrew/opt/llvm/bin/clang -O2 -fopenmp $LDFLAGS $CPPFLAGS -o matmul_example matmul_example.c
echo "Run"
./matmul_example 64 1024 768 768
其中 matmul_example.c 是 C 代码的名称.
Rust 中的naive方法
在 Rust 中的naive方法的源代码(和cargo构建)可在 这里 找到.
让我们来看看主函数:
fn matmul_forward_standard(
out: &mut [f32],
inp: &[f32],
weight: &[f32],
bias: Option<&[f32]>,
b: usize,
t: usize,
c: usize,
oc: usize,
) {
for bb in 0..b {
for tt in 0..t {
let out_offset = (bb * t + tt) * oc;
let inp_offset = (bb * t + tt) * c;
let inp_bt = &inp[inp_offset..inp_offset + c];
for oo in 0..oc {
let mut val = if let Some(bias_vec) = bias {
bias_vec[oo]
} else {
0.0
};
let weight_offset = oo * c;
let w_row = &weight[weight_offset..weight_offset + c];
for i in 0..c {
val += inp_bt[i] * w_row[i];
}
out[out_offset + oo] = val;
}
}
}
}
在 Rust 中, 将多维数组表示为一维数组可以充分利用连续的内存存储. 由于缓存的本地性和计算开销的减少, 这种方法大大提高了性能. 同样, 输入数组的大小为 [B][T][C]. 扁平化操作通过偏移进行, 如 inp_offset = (bb * t + tt) * oc:
bb*t将索引移至批次, 跳过每个批次的t个时间步;+tt移动到批次中正确的时间步长*c调整每个时间步的通道数
然后我们进行切片, 即inp_bt = &inp[inp_offset...inp_offset + c];, 因此我们在切片内执行顺序访问, 以提高空间定位性能.
这段代码中没有其他奇怪的地方, 我们可以识别一些常见的 Rust 特性, 如所有权, 借用和可变性. 在函数中, 我们有
- 使用
&f[32]的不可变引用, 因此输入数组不会被修改 - 使用
&mut [f32], 为输出张量提供可变引用 - 选项处理, 我们可能没有
bias, 因此定义为Option<&f[32]>. 在函数的最后一步, 我们通过Some(bias_vec)来考虑.
让我们把事情做得更好一些: Rayon
第二种方法是使用 Rayon. Rayon 是一个允许数据并行的 Rust 库, 可以将顺序计算(比如我们的例子)转换为并行计算. 我们可以使用 Rayon 的 ParallelIterator 和 par_sort 等高级并行结构, 也可以使用 join, scope 和 ThreadPoolBuilder 等自定义结构.
函数定义如下:
fn matmul_forward_rayon(
out: &mut [f32],
inp: &[f32],
weight: &[f32],
bias: Option<&[f32]>,
B: usize,
T: usize,
C: usize,
OC: usize,
) {
out.par_chunks_mut(T * OC)
.zip(inp.par_chunks(T * C))
.for_each(|(out_b, inp_b)| {
for time_idx in 0..T {
let inp_bt = &inp_b[time_idx * C..(time_idx + 1) * C];
let out_bt = &mut out_b[time_idx * OC..(time_idx + 1) * OC];
for o in 0..OC {
let mut val = bias.map_or(0.0, |b| b[o]);
let w_row = &weight[o * C..(o + 1) * C];
for i in 0..C {
val += inp_bt[i] * w_row[i];
}
out_bt[o] = val;
}
}
});
}
我们首先创建两个并行迭代器: out.par_chunks_mut 和 inp.par_chunks. 前者从 out 数组中创建每次最多包含 T*OC 元素的块, 后者对包含 T*C 元素的 inp 数组创建同样的块. zip会将两个迭代器合并成一对迭代器, 这样, out的每个块都有其对应的inp块(for_each(|(out_b, inp_b)| {} )). 假设 B=2, T=3, C=4 和 OC=5, 那么 inp 将有 24 个元素, 其形状为 [2][3][4], 而 out 将有 30 个元素, 其形状为 [2][3][5]. 分块是这样工作的:
- 在输出端
T*OC将得到3*5=15个元素, 所以最初是0至14元素的所有切片(out[0]), 然后是15至29元素的另一批切片(out[1]). - 在输入的
T*C中, 将有3*4=12个元素, 所以最初的一批包含从0到11的元素, 然后第二批包含从12到23的元素:
inp (flattened):
Batch 0:
[ inp[0][0][0], inp[0][0][1], ..., inp[0][0][3],
inp[0][1][0], ..., inp[0][1][3],
inp[0][2][0], ..., inp[0][2][3] ] // Total 12 elements
Batch 1:
[ inp[1][0][0], ..., inp[1][0][3],
inp[1][1][0], ..., inp[1][1][3],
inp[1][2][0], ..., inp[1][2][3] ] // Total 12 elements
Similarly for out:
out (flattened):
Batch 0:
[ out[0][0][0], ..., out[0][0][4],
out[0][1][0], ..., out[0][1][4],
out[0][2][0], ..., out[0][2][4] ] // Total 15 elements
Batch 1:
[ out[1][0][0], ..., out[1][0][4],
out[1][1][0], ..., out[1][1][4],
out[1][2][0], ..., out[1][2][4] ] // Total 15 elements
这些数据块会在一个外循环中被摄取, 该循环会经过时间步, 然后在输出值循环中被摄取.
作为一个启示, Rayon 在将输入分割成并行化的数据块方面很有帮助, 而且每个数据块的计算都是独立的, 因此一切都可以并行计算. 同样, 我们正在利用顺序数据访问, 并在连续的内存块上工作.
我的最佳方法: Blas
我测试的最后一种方法是使用 Blas. Blas 本身是用 Fortran 编写的, 但也有 Rust 绑定. 它为数学计算提供了多种方法, 其中一种是 sgemm, 它可以根据公式以单精度执行矩阵乘法(单精度通用矩阵乘法):
公式 4: SGEMM 矩阵乘法公式
这里, A 是 M x K 矩阵, B 是 K x N 矩阵, C 是 M x N - 输出矩阵. 参数 alfa 和 Berta 是单精度浮点数或scalar, 因此它们是矩阵乘法器. op 是对给定矩阵的运算, 因此我们可以得到转置或复共轭. 在编码方面, 矩阵乘法可定义为:
fn matmul_blas(
out: &mut [f32],
inp: &[f32],
weight: &[f32],
bias: Option<&[f32]>,
b: usize,
t: usize,
c: usize,
oc: usize,
) {
// inp size: m x k = ( (BT) x C)
// weight size: n x k = (OC x C) --> transposed (C x OC)
let m = (b * t) as i32; // output rows for C
let k = c as i32; // number of columns for A and rows for B
let n = oc as i32; // number of columns for C
// Leading dimensions for Row-Major layout
let lda = k; // lda >= K
let ldb = k; // ldb >= N
let ldc = n; // ldc >= N
unsafe {
sgemm(
Layout::RowMajor,
Transpose::None, // Transpose of A ('N' for no transpose)
Transpose::Ordinary, // Transpose of B
m,
n,
k,
1.0,
inp,
lda,
weight,
ldb,
0.0,
out,
ldc,
);
}
// Add bias if present
if let Some(bias) = bias {
out.par_chunks_mut(oc)
.for_each(|row| {
for (o, val) in row.iter_mut().enumerate() {
*val += bias[o];
}
});
}
}
sgemm需要以下内容:
Layout::RowMajor表示我们以行大序存储输入矩阵, 因此一行中的连续元素彼此相邻.transa: Transpose::None这里的输入是矩阵 A,None表示我们不想对这个矩阵进行转置处理transb: Transpose::Ordinary表示矩阵 B 将被转置m是结果矩阵 C 的行数, 即b*Tn是 C 中的列数, 即ock是共享维度, 因此是通道数c是输入矩阵 A 的列数alpha=1.0是第一个标量, 在我们的例子中是 1a=inp是输入矩阵- 由于我们使用的是 RowMajor 顺序, 而不是转置, 因此这相当于 A 的列数;
weight表示我们的矩阵 B- ldb
是矩阵 B 的前导维数, 也就是k` beta=0.0因为我们在计算中不需要 betaout是矩阵 Cldc是 C 的前导维数, 也就是我们输出中的列数
如果我们将其与公式 4 结合起来, 就不难发现我们正在计算矩阵 A 乘以 B 的转置.
从 Rust 的角度来看, 我们可以看到unsafe, 这是怎么回事呢? Rust 的设计默认是内存安全的, 以防止出现空指针引用等错误. unsafe块允许用户告诉 Rust 编译器"小心, 这可能不安全, 但不用担心". 这里需要使用unsafe, 因为我们使用的sgemm是一个通过绑定或"外来函数接口"(FFI)进行接口的函数. 因此, 我们有责任传递有效的指针, 并对长度和大小进行检查. 因此, 我们可以在代码中添加一些断言, 例如
assert!(inp.len() >= (b * t * c), "Input slice is too small.");
assert!(weight.len() >= (oc * c), "Weight slice is too small.");
assert!(out.len() >= (b * t * oc), "Output slice is too small.");
以确保输入矩阵的长度至少与需要的长度一样大, 并对空指针进行检查
assert!(!inp.is_empty(), "Input slice is empty.");
assert!(!weight.is_empty(), "Weight slice is empty.");
assert!(!out.is_empty(), "Output slice is empty.");
总结一下
我想我已经为今天的文章整理了许多细节. 在这篇文章中, 我想与大家分享我在寻找用 Rust 实现矩阵乘法运算的最佳方法时的经验教训, 以获得与 Karpathy 的 llm.c 类似的代码.
在本文中, 我们探索了:
- 使用 OpenMP 在 C 语言中的naive实现
- 比较 OpenMP 和 Rust 的性能. 比较是在批量大小为
B=64, 时间步长为T=1024, 通道大小和输出通道大小为C 和 OC = 768的情况下进行的. 具体来说, 我向大家介绍了
- 将 C 代码简单翻译为 Rust 代码. 在这里, 我们了解了指针运算以及从 C 到 Rust 的简单转换.
- 使用功能更强大的板块--Rayon. 这里的重点是从输出和输入数组中创建分块, 并与这些分块并行工作, 这样我们就可以运行独立进程, 加快整体计算速度. 如表 1 所示, Rayon 处理前向乘法和后向乘法大约需要 4 秒;
- 如何在 Rust 中使用 Blas 实现矩阵乘法, 以获得更好的性能. Blas 是最好的方法, 基准时间为毫秒级. 此外, 图 1 和图 2 显示了在不同输入/输出大小(从
B=4...128,T=64...2048, 到C/B=4...128和T=64...2048)下的正向和反向乘法性能.2048和C / OC = 48...1536.
好了, 今天的内容就分享到这里吧!
一家之言, 欢迎拍砖!
Happy Coding! Stay GOLDEN!