这是我参与更文挑战的第 22 天,活动详情查看: 更文挑战
搭建环境
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>TensorFlow.js Tutorial</title>
<!-- Import TensorFlow.js -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js"></script>
<!-- Import tfjs-vis -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis@1.0.2/dist/tfjs-vis.umd.min.js"></script>
<!-- Import the data file -->
<script src="data.js" type="module"></script>
<!-- Import the main script file -->
<script src="script.js" type="module"></script>
</head>
<body>
</body>
</html>
我们提供了代码,从一个特殊的 sprite 文件(约10MB)中加载这些图像,有关 sprite 做过前端或者做个游戏应该对着个不会陌生,因为加载一张大图片要比多次加载多张小图片效率更高,所以将这些图片手写数字图像按一定方式拼接成一张大图片,一并加载到浏览器端。
有关迷你手写数字集,太经典了这里就不想多说,一搜一大把,大家如果还不算了解这个数据集可以自行搜索一下了解一下迷你手写数据集的构成,这里 data.js 文件由来加载和解析数据,可以自己定加载数据集方式。
在文件提供了 MnistData类,有两个公共方法。
- nextTrainBatch(batchSize):从训练集中返回一批随机的图像及其对应的标签
- nextTestBatch(batchSize):返回一批来自测试集的图像和它们的标签。
MnistData类还做了洗牌和规范化数据的。
总共有 65,000张图片,用最多55,000张图片来训练模型,保留 10,000张图片做测试,
加载数据集以及数据集进行处理
我们先重点说一说 data.js 如何从网络获取图片资源,并将其读取。读取作者这段代码好处是能够给我们一些提示如何在浏览器以请求方式加载远端的数据集,其中可以注意下一些在浏览器端的技巧。
先说一下大概思路,然后我们就一些关键点给大家分享具体作者是如何实现的。就是读取数据写入到实现 canvas 保存数据,然后再用 getImageData 方法将数据从 canvas 中提取出来,
export class MnistData {
constructor() {
this.shuffledTrainIndex = 0;
this.shuffledTestIndex = 0;
}
async load() {
// Make a request for the MNIST sprited image.
// 通过 Image 对象来发起请求
const img = new Image();
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');
const imgRequest = new Promise((resolve, reject) => {
// 跨域请求图片
img.crossOrigin = '';
img.onload = () => {
img.width = img.naturalWidth;
img.height = img.naturalHeight;
const datasetBytesBuffer =
new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);
const chunkSize = 5000;
canvas.width = img.width;
canvas.height = chunkSize;
for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
const datasetBytesView = new Float32Array(
datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
IMAGE_SIZE * chunkSize);
ctx.drawImage(
img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
chunkSize);
const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
for (let j = 0; j < imageData.data.length / 4; j++) {
// All channels hold an equal value since the image is grayscale, so
// just read the red channel.
datasetBytesView[j] = imageData.data[j * 4] / 255;
}
}
this.datasetImages = new Float32Array(datasetBytesBuffer);
resolve();
};
img.src = MNIST_IMAGES_SPRITE_PATH;
});
const labelsRequest = fetch(MNIST_LABELS_PATH);
const [imgResponse, labelsResponse] =
await Promise.all([imgRequest, labelsRequest]);
this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());
// 对数据进行打乱排序然后拆分到测试集合训练集
this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);
// 将图片和标签切分训练数据集和测试数据集
this.trainImages =
this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
this.trainLabels =
this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
this.testLabels =
this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
}
nextTrainBatch(batchSize) {
return this.nextBatch(
batchSize, [this.trainImages, this.trainLabels], () => {
this.shuffledTrainIndex =
(this.shuffledTrainIndex + 1) % this.trainIndices.length;
return this.trainIndices[this.shuffledTrainIndex];
});
}
nextTestBatch(batchSize) {
return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
this.shuffledTestIndex =
(this.shuffledTestIndex + 1) % this.testIndices.length;
return this.testIndices[this.shuffledTestIndex];
});
}
nextBatch(batchSize, data, index) {
const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);
for (let i = 0; i < batchSize; i++) {
const idx = index();
const image =
data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
batchImagesArray.set(image, i * IMAGE_SIZE);
const label =
data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
batchLabelsArray.set(label, i * NUM_CLASSES);
}
const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);
return {xs, labels};
}
}
const idx = index();
index 是函数类型,返回时候需要进行计算一下,因为对数据进行洗牌处理
() => {
this.shuffledTrainIndex =
(this.shuffledTrainIndex + 1) % this.trainIndices.length;
return this.trainIndices[this.shuffledTrainIndex];
}
nextBatch 是每一次从数据提取 batchsize 个数提供模型进行训练,这里 tensor2d
接受 typedArray ,第二个参数传入一个数组用于定义 tensor 的形状。标签类型为 Uint8Array
而图像数据类型为 Float32Array
。
nextBatch(batchSize, data, index) {
const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);
for (let i = 0; i < batchSize; i++) {
const idx = index();
const image =
data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
batchImagesArray.set(image, i * IMAGE_SIZE);
const label =
data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
batchLabelsArray.set(label, i * NUM_CLASSES);
}
const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);
return {xs, labels};
}
Image
<img width="100" height="200" src="picture.jpg">
其实 Image 对应 img 标签,如何我们 src 指定服务端存储图片地址,img 标签内置一些功能,可以请求图片数据并将其加载。Image() 构造方法来创建 img 实例都等同于 ,``document.createElement('img')`
var htmlImageElement = new Image(width, height);
在图片加载完成会有一个回调函数 onload
, 我们可以在这里处理数据
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');
BufferArry
javascript 中数组比较重,内部比较复杂。对于 XHR、File API、Canvas 等等各种地方,读取了一大串字节流,如果用 javascript 里的 Array 去存,又浪费,又低效。
通常我们不会使用 ArrayBuffer ,可以这样理解,使用 ArrayBuffer 只是开辟了一块内存空间,需要通过视图来操作这块内存空间,因为不同数据类型占用内存空间。所以要想对这块内存进行操作我们还需定义 type Array 来操作这块内容空间
const buffer = new ArrayBuffer(8);
const view = new Int32Array(buffer);
这句话表示我们开辟一段内存空间来接受图片字节流数据,这里最后 4 表示
const datasetBytesBuffer =
new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);
import {MnistData} from './data.js';
// 这个方法用于将图片样本显示在 devtool 工具上
async function showExamples(data) {
// 创建可以可视化容器
const surface =
tfvis.visor().surface({ name: '输入数据样本', tab: '输入数据'});
// 获取批量样本
const examples = data.nextTestBatch(20);
const numExamples = examples.xs.shape[0];
// 为每一个样本创建 canvas 用于将图片渲染出来
for (let i = 0; i < numExamples; i++) {
const imageTensor = tf.tidy(() => {
// 将数据形状转换为图片格式 28x28 px
return examples.xs
.slice([i, 0], [1, examples.xs.shape[1]])
.reshape([28, 28, 1]);
});
//为每张图片有一个 canvas 来绘制图片
const canvas = document.createElement('canvas');
canvas.width = 28;
canvas.height = 28;
canvas.style = 'margin: 4px;';
// imageTensor 数据转换为 canvas
await tf.browser.toPixels(imageTensor, canvas);
// 然后将 canvas 添加绘制区域
surface.drawArea.appendChild(canvas);
// 销毁掉 imageTensor
imageTensor.dispose();
}
}
async function run() {
const data = new MnistData();
await data.load();
await showExamples(data);
}
document.addEventListener('DOMContentLoaded', run);