用TensorFlowjs 在浏览器中搭建神经网络 识别0~9和A~Z

605 阅读1分钟

dome wangpeng1478.github.io/web/tensorf…

1、定义模型

    //定义线性回归的模型
    const model = tf.sequential();

    //输入层 28*28 = 784
    model.add(tf.layers.inputLayer({inputShape: [784]}));

    //输出空间,输出空间为训练范围的长度 (0-9和A-z的长度)
    model.add(tf.layers.dense({units: 36}));

    //输出空间所有值之和为1
    model.add(tf.layers.softmax());

    //准备训练模型:指定损失和优化器。
    model.compile({
        optimizer: 'sgd', //优化器
        loss: 'categoricalCrossentropy', //损失函数
        metrics: ['accuracy'] //logs里的acc
    })

2、开始训练

    const XData = []; // 手写数据 (白色背景 黑色文字 28x28)
    const YData = tf.oneHot(parseInt('训练数字'), 36).arraySync(); //张量
    
    await this.model.fit(tf.tensor(XData), tf.tensor(YData), {
        epochs:10, //训练次数
        callbacks: { //每次的回调
            onEpochEnd(epoch, logs) {
                // acc 准确率
                // loss 损失函数
                console.log('onEpochEnd:',logs);
            }
        }
    });

3、预测

 // 手写数据 (白色背景 黑色文字 28x28)
 let data = [];
 
 //结果也是tensor
 let predictions = this.model.predict(tf.tensor([data]));
 
 //获取tensor第一层里的最大值的index,正好就是数字本身
 let result = predictions.argMax(1).arraySync();
 
 //预测结果
 let predictResults = predictions.arraySync()[0]; 
 console.log(result[0]);

4、在线预览

识别0-9和A-Z wangpeng1478.github.io/web/tensorf…

image.png

6、文档

机器学习库 tensorflow.google.cn/js?hl=zh-cn

机器学习库可视化 js.tensorflow.org/api_vis/1.5…