文/ 阿里淘系 F(x) Team - 天可
本文将会介绍使用 Tensorflow.js 进行大规模训练时候的一个技巧:分批读取训练数据。我们将会用:
- JS 的 Generator 函数
- Tensorflow.js 的 tf.data.generator 和 model.fitDataset 两个 API
来实现分批训练。
前置知识
在阅读本文前,请确保你拥有以下知识,否则读起来可能比较吃力:
- 机器学习/深度学习基础,推荐阅读:机器学习速成教程。
- 基于 Tensorflow.js 的小规模训练经验,推荐阅读:Tensorflow.js 官方教程。
背景问题:线性增长的内存占用
对于深度学习来说,训练集越大,模型学习的效果越好。所以,在真实的工作中,我们的训练集往往会上 G,这时候,如果把所有训练集都读到内存中,那么还没等到训练,计算机的内存就吃不消,被爆掉了。
比如,你有这样一段代码,将所有训练集一次性读取到内存中:
// trainSet 是一个数组,里面包含了所有的训练集的硬盘路径信息和分类信息。
// getTensor 函数用来读取训练集到内存,并生成 tensor
const { xs, ys } = getTensor(trainSet);
await model.fit(xs, ys, { epochs: 20 });那么,你在读取训练集到 Tensor 的过程中,内存占用一定是线性增加的。如图所示:

如果你的计算机内存够大能够吃掉所有的训练集还好,否则就会崩溃掉,抛出 OOM (Out Of Memory)错误。
解法思路
那么如何解决这个问题呢?当然是分批读取训练数据。我们先读取一部分训练数据到内存里转为 Tensor,让模型训练,然后再读取下一批数据。这个“批”在深度学习里叫 batch,一“批”的数量叫 batchSize。
那么如何在 Tensorflow.js 中分批读取训练数据呢?Tensorflow.js 提供了一个 tf.data.generator 方法,可以让我们写一个 JS 的生成器函数(Genertaor)作为参数放进去,生成一个支持分批读取数据的数据集(Dataset)。然后,我们使用 model.fitDataset 方法进行训练即可。
让我们详细演示一下吧!
Generator 简介
在 JS 中,Generator 函数调用后会生成一个 Generator 对象,Generator 对象有一个 next 方法,可以执行“被 yield 隔开的代码”,并返回一个对象,这个对象的 value 是 yield 后面的表达式。比如:
function* infinite() {
let index = 0;
while (true) {
yield index++;
}
}
const generator = infinite(); // "Generator { }"
console.log(generator.next().value); // 0
console.log(generator.next().value); // 1
console.log(generator.next().value); // 2使用 Generator 生成分批数据集
那么如何利用 Generator 生成分批数据集呢?我们直接上代码吧:
// 设置批次大小为 32
const batchSize = 32;
// 利用 tf.data.generator 创建 dataset
const ds = tf.data.generator(function* () {
// for 循环里面是读取每批数据的代码,start 和 end 是每批的开始与结束下标
for (let start = 0; start < trainSet.length; start += batchSize) {
// 数据集不一定是 batchSize 的整数倍,所以要保证 end 最大为训练集的大小
const end = Math.min(start + batchSize, trainSet.length);
// 返回分批的 xs 和 ys。
const { xs, ys } = getTensor(trainSet.slice(start, end));
return { xs, ys };
}
});分批训练最佳实践:随机化
由于分批读取训练数据,无法考虑训练集的总体情况,为了保证训练效果,我们需要让样本(训练集)分布足够的随机化,这样分批采样(即读取数据)时候才更有代表性。
tf.util.shuffle(trainSet);
// 在后面执行生成 ds 的逻辑使用 fitDataset 进行训练
生成了 dataset 之后,我们只需要调用 fitDataset 方法就可以实现分批读取训练了。
await model.fitDataset(ds, { epochs: 20 });就是这么简单!
优化结果:内存占用不再线性增长
如此以来,内存就变成了常量级别,不会再线性增长了。

皆大欢喜!