「这是我参与11月更文挑战的第10天,活动详情查看:2021最后一次更文挑战」
这是分享是参考老外一个视频,希望自己不仅是翻译的搬运工,而且原分享的基础融入自己内容。给大家带来更多更好玩的 AI 项目。
<script src="https://cdn.jsdelivr.net/npm/p5@1.4.0/lib/p5.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.4/dist/tf.min.js"></script>
强化学习(reinforcement learning) 通常都会用于 game,让智能体去玩游戏,因为游戏环境相对比较简单和固定,所以适合研究强化学习。
设计思路
这里我们设计一个神经网络,输入不是图片,而是画面信息数据作为向量,然后这个可以反映当前游戏的状态信息(5 维向量,具体哪些信息可以参考下面的输入内容),然后经过一个隐藏层(8单元神经元组成隐藏层)后输出向上和向下概率。也就是一个分类问题
输入
- bird y(y 坐标)
- bird y vel(y 轴速度)
- top pipe 距离最近的上方管道位置
- bottom pipe 距离最近下方管道的位置
- x 距离管道的距离
let inputs = [];
inputs[0] = this.y / height;
inputs[1] = closest.top / height;
inputs[2] = closest.bottom / height;
inputs[3] = closest.x / width;
inputs[4] = this.velocity / 10;
其实理想输入一张图片,然后经过卷积神经网络提取图像的特征。
输出
输出两个动作分别是向下和向上的概率
- 向下
- 向上
if (output[0] > output[1]) {
this.up();
}
代码实现
引入 tensorflow.js
tf.memory()
{unreliable: true, reasons: Array(1), numTensors: 1000, numDataBuffers: 1000, numBytes: 66000}
创建 nn.js
this.brain = new NeuralNetwork(5, 8, 2);
这里构造函数接收 3 参数为 a, b, c 分别表示输入数、隐藏层和输出节点数
class NeuralNetwork{
constructor(a,b,c){
}
}
class NeuralNetwork{
constructor(a,b,c){
this.input_nodes = a;
this.hidden_nodes = b;
this.output_nodes = c;
}
}
定义模型
接下来创建一个模型 this.createModel(),用于创建模型,这里用 tf.sequential 来创建序列容器,如果用于过 keras 应该对这个 API 不会陌生。
class NeuralNetwork{
constructor(a,b,c){
this.input_nodes = a;
this.hidden_nodes = b;
this.output_nodes = c;
this.createModel();
}
createModel(){
this.model = tf.sequential()
}
}
定义隐藏层
createModel(){
this.model = tf.sequential();
const hiddn = tf.layers.dense({
units:this.hidden_nodes,
inputShape:[this.input_nodes],
activation:'sigmoid'
})
this.model.add(hiddn)
}
这里定义一个隐藏层为全连接层,tf.layers.dense接收一个对象,
- units: 指定该层神经元个数
- inputShape: 输入节点数
- activation: 指定激活函数 如果你对神经网络还不了解,激活函数是给神经网络层之间添加非线性变换
定义输出层
const output = tf.layers.dense({
units:this.output_nodes,
activation:'softmax'
});
this.model.add(output)
预测
在预测过程中,需要组件将 JavaScript 的数组对象转换为 tensorflow tensor 张量,
predict(inputs){
const xs = tf.tensor2d([inputs]);
const ys = this.model.predict(xs);
const outputs = ys.dataSync();
console.log(outputs);
return outputs;
}