[AI 入门] 手写数字识别-获取数据集

770 阅读4分钟

2022年后,AI 以极其迅猛的速度发展,无论是从科技还是商业角度,大模型(LLM) 都具备无可估量的价值。但是,随着模型越发庞大,其推理所消耗的资源也越来越大,传统的由服务器资源推理成本是越来越高,但是,如今用户的设备也是越来越好了,为减轻服务器压力和开销,边缘计算 的价值越来越高,场景越来越重要。笔者作为一个页面仔小白曾思考:除了 wasm 这套技术可以加持端侧 ai 能力,那我浏览器亲儿子:JavaScript 难道就不能做做同样是脚本语言 Python 的事情?

答案是肯定的,在神经网络领域,JS 社区中已经有了 TensorFlow.js, transformer.js 的框架支持,最近笔者发现了另一个库 : JS-PyTorch upload_p8yhmrhwawiw95fixjs14s51fojbqam8.png

由于笔者之前用 PyTorch 做过一些模型研究,看它的语法非常亲切😶‍🌫️,比 TensorFlow 好多了。 upload_4u589v1jwtgn7az28h66oevmdzeouf0f.png

看这个库未来还会支持 GPU?,那就决定用它试试水,看看它能不能跑通 CNN 的 hello world - 识别手写数字 upload_ur9muu5k16orez7v9wxafdlnvphrz9zy.png

数据集准备

  1. 下载地址 yann.lecun.com/exdb/mnist/ upload_w7nmosw3r2dv2tmevmsb034a30kbrlcl.png 其中,train-images 为训练集图片,train-labels 为训练集标签,t10k-images 为测试集图片,t10k-labels 为测试集标签,训练集和测试集没有交集。
  2. 解压后的数据为二进制数据,本文主要介绍对下载后的数据进行处理并验证数据

数据处理

解压后的二进制数据需要处理,本文在实战刚过程中就没有处理🫥,浪费了so much time。 一般来说,在不压缩数据的情况下 图像被存储成二进制后,文件大小的计算公式如下: length=widthimageheightimagechannelimagenumlength = width_{image} * height_{image} * channel_{image} * num 即大小等于 图像的宽高 * 图像的通道数 * 图像的个数。 拿训练的图像数据为例: upload_sdxho41869dwqhgbzaxru8ruywc4ljng.png upload_uaqv931crko9pg8ecrgnhnhj0zz5wiqo.png 它的大小是 47040016 字节,但是,根据公式大小应为 28* 28* 1* 60000 = 47040000 ,多了 16 个字节(128位)。(一度怀疑自我ing🤺),翻阅了几篇用 Python 跑识别数字的文章 ,其中一篇是这么做的: upload_11qh1mlf828hpk4fvopju0aaz1lp3ef9.png 其实在官网的下方提及了这几个变量的含义。(看文档的重要性) upload_lmljrz9dqvk9p4l0g6lumfkx7dfuxigs.png 之后我们便可以把原始的图像数据拿到了

  const images = [];
  let index = 16;
  
  //* bufferSize: 47040016
  
  while (index < buffer.byteLength) {
    const array = new Float32Array(28 * 28 * 1);
    for (let i = 0; i < 28 * 28 * 1; i++) {
    
      //* 图片的原始值为 (0,255) 一个8位无符号整型 就能放的下
      
      array[i] = buffer.readUInt8(index++);

    }

    images.push(new Picture(array));

  }

label 数据也是这样处理,不过它的头除了魔数外,只有长度了,打印下 训练集 label 的数据: upload_bjjzc4cyb5e4adu4kexaohgc0oamvpl9.png

label 数据很清晰,但对于图像数据来说,目前只是一个 (0, 255) 范围的数字数组,不能直观的感受图像是否是正确的,因此下一步工作是把数字数组,还原成一个图片来看数据是否加载成功。

还原图片

将数组还原成图像,是用 canvas 这个库,把数组中每一个点画在 canvas 上面再进行保存查看。在 RGB模型 中,R=G=B 为一种灰度值,灰度值越大,颜色越偏白,越小则越偏黑。如下图所示: upload_150qab0ovbnv6nduwo2lwey2ctmv2du8.png

upload_k9sjnc9f11tgaf11vtgpdr7iwm3jjf78.png 因此还原过程中,只需把 3 通道的值都设置为图像数组里的值即可,对于 PNG 这种格式,需要再加一个透明通道即可。还原图像代码如下:

class Picture {

    image = null;

    constructor(image, width = 28, height = 28, depth =1){

        this.width = width;

        this.height = height;

        this.depth = depth;

        this.imgSize = width * height;

        this.image = image;

    }

    savePicture(path){

        const canvas = createCanvas(this.width, this.height);

        const context = canvas.getContext('2d');

        const imageData = context.createImageData(this.width, this.height);

        for(let i = 0; i< this.imgSize; i++){

            const value = this.image[i];

            imageData.data[(4 * i )] = value;     //0 4 8

            imageData.data[(4 * i ) + 1] = value; //1 5 9

            imageData.data[(4 * i ) + 2] = value; //2 6 10

            imageData.data[(4 * i ) + 3] = 255;   //3 7 11

        }

        context.putImageData(imageData, 0, 0);

        const out = fs.createWriteStream(path);

        const stream = canvas.createPNGStream();

        stream.pipe(out);

        out.on('finish', () => {console.log('还原图像成功')})

    }

}

在加载训练图像数据完成后,保存第一张图片: upload_mhi6lpgtbyr2hcpm2ykabbq0eqzvf2fx.png 运行结果: upload_cnk69xm4ydyjb139sh5gguocu1nnlf8v.png 和标签数据中的第一个数据一致! 加载数据完成,下一步,开始训练👋!

项目环境

版本
nodev19.1.0
js-PyTorch0.3.6
canvas2.11.2

参考

深度学习实战入门——CNN实现MNIST手写数字识别:


hi, there👋

i am siroi

一个接触前端 2.5 年的新人,欢迎关注我的微信公众号:siroi的前端手册

让我们一起愉快的玩耍吧🤓