dome wangpeng1478.github.io/web/tensorf…
1、定义模型
//定义线性回归的模型
const model = tf.sequential();
//输入层 28*28 = 784
model.add(tf.layers.inputLayer({inputShape: [784]}));
//输出空间,输出空间为训练范围的长度 (0-9和A-z的长度)
model.add(tf.layers.dense({units: 36}));
//输出空间所有值之和为1
model.add(tf.layers.softmax());
//准备训练模型:指定损失和优化器。
model.compile({
optimizer: 'sgd', //优化器
loss: 'categoricalCrossentropy', //损失函数
metrics: ['accuracy'] //logs里的acc
})
2、开始训练
const XData = []; // 手写数据 (白色背景 黑色文字 28x28)
const YData = tf.oneHot(parseInt('训练数字'), 36).arraySync(); //张量
await this.model.fit(tf.tensor(XData), tf.tensor(YData), {
epochs:10, //训练次数
callbacks: { //每次的回调
onEpochEnd(epoch, logs) {
// acc 准确率
// loss 损失函数
console.log('onEpochEnd:',logs);
}
}
});
3、预测
// 手写数据 (白色背景 黑色文字 28x28)
let data = [];
//结果也是tensor
let predictions = this.model.predict(tf.tensor([data]));
//获取tensor第一层里的最大值的index,正好就是数字本身
let result = predictions.argMax(1).arraySync();
//预测结果
let predictResults = predictions.arraySync()[0];
console.log(result[0]);
4、在线预览
识别0-9和A-Z wangpeng1478.github.io/web/tensorf…
6、文档
机器学习库 tensorflow.google.cn/js?hl=zh-cn
机器学习库可视化 js.tensorflow.org/api_vis/1.5…