理解神经网络:Brain.js 背后的核心思想(二)

803 阅读22分钟

温馨提示

这篇文章篇幅较长,主要是为后续内容做铺垫和说明。如果你觉得文字太多,可以:

  1. 先收藏,等后面文章遇到不懂的地方再回来查阅。
  2. 直接跳读,重点关注加粗或高亮的部分。

放心,这种“文字轰炸”不会常有的,哈哈~ 感谢你的耐心阅读!😊

欢迎来到 brain.js 的学习之旅!

无论你是零基础的新手,还是已经有一定编程经验的开发者,这个系列都将为你提供一个系统、全面的学习路径。我们将从最基础的概念开始,逐步深入到实际应用和高级技巧,最终让你能够自信地构建和训练自己的神经网络模型。

以下是我们的学习路线图:

brainJS-roadmap

这一系列文章从入门到进阶,涵盖了 brain.js 的核心功能、技术细节以及实际应用场景。不仅适合初学者学习和实践,也为有一定基础的开发者提供了更多扩展和深入的思考方向。接下来,我们进入系列的第一部分:基础篇。

Snipaste_2025-01-24_12-01-41

一、什么是神经网络?

1.1 神经网络的定义

神经网络(Neural Network),全称人工神经网络(Artificial Neural NetworkANN),是一种受生物神经系统启发的计算模型。它通过模拟人脑神经元的连接和工作方式,完成数据的处理和预测任务。

  • 直观理解:神经网络就像一个会学习的系统。它通过处理输入数据,生成与之对应的输出。例如,输入一张手写数字图片,神经网络会将其分类为 0 到 9 中的某一个数字。
  • 核心功能:神经网络擅长模式识别和预测。无论是语音识别、图像分类,还是文本生成,它都展现了强大的适应性和学习能力。

1.2 为什么神经网络如此强大?

传统的算法需要人为设计规则,而神经网络通过训练可以自动学习规则。这使它能够适应各种数据模式和任务场景。总结来说,神经网络是一个“通用函数近似器”,可以从数据中学习规律,并利用这些规律进行推理和预测。

1.3 举例说明:神经网络如何工作?

假设你有一组手写数字图片,神经网络的任务是识别这些数字。它的工作过程如下:

  1. 接收输入数据:将每张图片转化为像素矩阵,例如一个 28x28 的灰度图片会被展平为一个 784 维的向量。
  2. 特征提取:通过隐藏层逐步提取图片中的边缘、形状等关键信息。
  3. 生成输出:根据提取的特征,将图片分类为相应的数字(0 到 9)。

神经网络的强大之处在于它的“学习能力”。它能够从海量数据中提取特征并构建复杂的映射关系。它的核心结构类似于人脑,由大量的“神经元”组成,并通过“权重”连接,形成一个可以自我优化的网络。

二、神经网络的灵感来自大脑

2.1 人脑与人工神经网络的对比

神经网络的概念来源于对人脑结构和工作方式的模拟。人脑是由大约 860亿个神经元数百万亿个突触 组成的复杂网络。它通过这些神经元的连接和协作,完成思维、学习和决策等复杂任务。

人工神经网络(Artificial Neural NetworkANN)通过数学建模,抽象出人脑的部分功能。尽管人工神经网络远不及人脑复杂,但它能够在特定任务中表现得非常出色。

以下是人脑和人工神经网络的一些对比:

特性人脑人工神经网络
基本单元神经元(Neuron人工神经元(Node
信息传递方式电化学信号通过突触传播数值信号通过权重传播
学习能力通过强化学习调整突触连接强度通过训练调整权重和偏置
灵活性高度灵活,能处理多任务通常针对特定任务设计,灵活性较低
能耗高效低耗能能耗高,尤其在大规模训练时

2.2 生物神经元的工作原理

生物神经元由以下三个主要部分组成:

  1. 树突(Dendrite :接收其他神经元传来的信号。
  2. 细胞体(Soma :对接收到的信号进行整合和处理,并决定是否激活神经元。
  3. 轴突(Axon :如果神经元被激活,轴突将信号传递给下一个神经元。

信号传播过程

  • 树突接收到多个信号,并传递到细胞体。
  • 如果信号强度超过某个阈值,神经元会“激活”,产生动作电位
  • 动作电位通过轴突传播到突触,影响下一个神经元。

2.3 生物神经元的工作原理

人工神经元是对生物神经元的数学抽象模型,它通过以下方式工作:

  1. 接收输入信号(Input :每个输入信号代表一个特征值,例如房价预测中的面积或房间数。

  2. 加权求和(Weighted Sum :输入信号会根据重要性赋予不同的权重(Weight)。权重值越大,表示信号对最终结果的影响越大。

    输入信号会根据重要性赋予不同的权重(Weight)。权重值越大,表示信号对最终结果的影响越大。

    z=i=1nwixi+bz = \sum_{i=1}^{n} w_i \cdot x_i + b

    其中:

    • xi:第 i 个输入信号
    • wi:对应的权重
    • b:偏置(Bias),用于调整计算结果
  3. 激活函数(Activation Function :加权求和结果会通过一个激活函数,决定神经元是否激活,以及激活后的输出值。

  4. 输出信号(Output :激活函数的结果被传递到下一层神经元,直到输出层。

2.4 从生物神经元到人工神经元的转化

生物神经元部分对应的人工神经元组件功能
树突(Dendrite输入信号接收外界输入数据,传递到神经元中处理
细胞体(Soma加权求和和激活函数处理输入信号,并根据阈值决定是否激活
轴突(Axon输出信号将激活后的信号传递给下一层的神经元

人工神经元的设计灵感虽然来源于生物神经元,但它的目标是高效计算和任务专用化,而不是完全复制生物神经元的复杂性。通过输入信号、权重、偏置和激活函数的协作,人工神经元能够处理复杂的数据模式并生成输出。

三、人工神经网络的基础结构

人工神经网络由多个神经元按照层级结构排列而成,通常包括以下三个部分:输入层隐藏层输出层。每一层都有特定的功能,它们协同工作以实现数据处理和任务预测。

NNS


3.1 输入层(Input Layer

作用:输入层是神经网络的起点,用于接收外部数据,并将这些数据传递给网络的下一层。

  • 每个输入节点对应一个特征值。例如,在房价预测中,特征可能包括面积、房间数、地理位置等。
  • 输入层本身不对数据进行任何处理,只是将数据作为信号传递到隐藏层。

假设我们有以下数据用于预测房价:

  • 面积:120 平方米
  • 房间数:3
  • 距市中心距离:5 公里

这些数据会作为输入信号传递到网络中。


3.2 隐藏层(Hidden Layer

作用:隐藏层是神经网络的核心计算部分,用于提取数据特征并执行复杂的数学运算。

  • 隐藏层节点会接收来自上一层的输入信号,经过权重计算和激活函数处理后,将结果传递给下一层。
  • 隐藏层的数量和每层节点数可以根据问题的复杂性调整。

隐藏层的计算过程

  1. 对输入信号进行加权求和:

    z=i=1nwixi+bz = \sum_{i=1}^{n} w_i \cdot x_i + b
    • wi:输入信号的权重
    • xi:输入信号
    • b:偏置,用于调整计算结果
  2. 应用激活函数:

    a=f(z)a = f(z)
    • f:激活函数,用于引入非线性,使网络能够学习复杂的模式。

在房价预测中,隐藏层中的一个节点可能专注于“面积”和“房间数”之间的关系,而另一个节点可能关注“距离市中心”和“房价”的关系。通过多层隐藏层的计算,神经网络能够提取到输入数据中的深层次特征。


3.3 输出层(Output Layer

作用:输出层是神经网络的终点,用于生成最终结果。

  • 输出的形式取决于任务类型:

    • 分类任务:输出层通常包含多个节点,每个节点表示一个类别的概率。
    • 回归任务:输出层通常只有一个节点,表示连续值的预测结果。

在手写数字识别中,输出层有 10 个节点,表示数字 0-9 的概率。例如:输出层结果:[0.1, 0.05, 0.8, 0.05, ...],这里第 3 个节点的值为 0.8,表示神经网络预测输入图片是数字“2”。

在房价预测中,输出层可能直接输出房价,例如 200 万元


3.4 激活函数(Activation Function

作用:激活函数是每个节点的“开关”,决定神经元是否被激活,以及如何传递信号到下一层。

  • 如果没有激活函数,神经网络只能处理简单的线性关系,无法应对复杂的非线性问题。
  • 数学意义:激活函数引入非线性,使神经网络具备学习复杂映射关系的能力。

常见激活函数

  1. Sigmoid

    f(x)=11+exf(x) = \frac{1}{1 + e^{-x}}

    将输出限制在 (0, 1) 范围内,适合二分类任务。

  2. ReLURectified Linear Unit

    f(x)=max(0,x)f(x) = \max(0, x)

    简单高效,广泛用于隐藏层节点。

  3. Tanh(双曲正切函数)

    f(x)=exexex+exf(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}

    输出范围为 (-1, 1),对称性比 Sigmoid 更适合有正负值的输入。

  4. Softmax

    f(xi)=exij=1nexjf(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}

    通常用于输出层,用于多分类任务,将所有输出值转化为概率分布。

image.png


3.5 权重和偏置(Weight and Bias

权重(Weight

  • 权重是神经网络的核心参数,用于衡量每个输入信号的重要性。
  • 在训练过程中,神经网络会不断调整权重值,以最小化预测误差。

偏置(Bias

  • 偏置是一个额外的参数,用于调整神经元的激活值。
  • 作用类似于线性方程中的截距,可以提高模型的灵活性。

四、神经网络的训练过程

神经网络的训练过程是一个逐步优化的过程,其目标是使网络能够准确地从数据中学习规律。通过不断调整权重和偏置,网络逐渐提高预测精度,最终生成可靠的输出结果。


4.1 数据预处理

在开始训练之前,数据的质量和格式非常重要。神经网络对输入数据的要求通常包括:

  1. 归一化(Normalization

    将输入数据的值缩放到一个较小的范围(如 [0, 1][-1, 1]),以便提高训练的稳定性。

    示例:在房价预测中,如果面积从 50 到 500 平方米,归一化后可以转换为 [0.1, 1.0]

  2. 数据分割(Data Splitting

    将数据分为三部分:

    • 训练集:用于更新网络权重。
    • 验证集:用于调整超参数和选择最佳模型。
    • 测试集:用于评估模型的最终性能。

    常见比例:70% 训练集,20% 验证集,10% 测试集。


4.2 前向传播(Forward Propagation

前向传播 是神经网络的第一步,数据从输入层传递到输出层,并逐层计算出每一层的输出值(激活值)。

  1. 接收输入信号:数据从输入层流入,作为信号传递到隐藏层。

  2. 加权求和:每个隐藏层节点计算加权和:

    z=i=1nwixi+bz = \sum_{i=1}^{n} w_i \cdot x_i + b

    其中:

    • wi:权重
    • xi:输入信号
    • b:偏置
  3. 激活函数:使用激活函数将加权和转换为非线性输出:

    a=f(z)a = f(z)
  4. 输出信号:每层的输出信号会传递到下一层,最终到达输出层,生成最终预测值。


4.3 损失函数(Loss Function

损失函数 是衡量网络预测输出与实际目标之间差距的指标。通过最小化损失函数,网络能够提高预测精度。

常见损失函数

  1. 均方误差(MSE :常用于回归任务,公式为:

    MSE=1ni=1n(yiy^i)2MSE = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2
    • yi:实际值
    • y^i:预测值
  2. 交叉熵损失(Cross-Entropy Loss :常用于分类任务,公式为:

    L=1ni=1n[yilog(y^i)+(1yi)log(1y^i)]L = - \frac{1}{n} \sum_{i=1}^{n} \left[ y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right]

    用于衡量概率分布之间的差异。

网络训练的目标是找到一组权重和偏置,使损失函数的值尽可能小。


4.4 反向传播(Backward Propagation

反向传播 是神经网络学习的核心过程。它通过计算误差的梯度,反向调整网络中的权重和偏置。

  1. 计算误差(Error):对比预测值和真实值,计算损失函数的值。

  2. 计算梯度(Gradient Calculation :根据损失函数对权重和偏置求偏导数,得到梯度。梯度表示误差随权重变化的方向和幅度。

  3. 更新参数(Weight Update

    使用梯度下降算法,根据梯度值调整权重和偏置:

    w=wηLww = w - \eta \cdot \frac{\partial L}{\partial w}
    • η:学习率,决定更新的步幅大小。
    • ∂L/∂w:损失函数对权重的偏导数。
  4. 迭代优化:重复前向传播和反向传播,逐步减小损失函数的值。


4.5 学习率(Learning Rate

学习率 是训练中的关键超参数,它决定了每次更新权重的幅度。

  • 学习率过大:更新步伐过快,可能错过全局最优解。容易导致训练过程不稳定。
  • 学习率过小:更新步伐太慢,训练过程耗时较长。

解决方法:采用动态学习率(Learning Rate Decay),在训练过程中逐渐降低学习率,从而平衡收敛速度与稳定性。


4.6 训练的终止条件

训练何时结束?以下是常见的停止标准:

  1. 达到最大迭代次数(Epochs :设定一个固定的迭代次数。
  2. 损失值收敛:如果损失值在多次迭代后不再明显下降,则认为模型已接近最佳状态。
  3. 验证集误差上升:如果验证集误差开始上升,可能出现过拟合,此时可以提前停止训练。

五、如何用 brain.js 构建神经网络

在前面的内容中,我们已经介绍了神经网络的基本结构和训练过程。接下来,我们将进入实战环节,使用 brain.js 构建一个简单的神经网络模型,并通过示例代码演示如何训练和预测。如果你对这部分内容感兴趣,可以参考我之前写的详细教程:

👉 从零开始:使用 Brain.js 创建你的第一个神经网络(一)

六、神经网络中的训练效果评估

训练神经网络并获得输出结果后,下一步就是评估模型的性能。这不仅能帮助我们了解当前模型的表现,还可以指导我们改进模型的结构和训练过程。


6.1 衡量训练效果的标准

  1. 损失值(Loss

    • 损失值是衡量模型输出与目标输出差异的指标。
    • 训练过程中,损失值的下降趋势可以直观反映网络在学习任务上的改进。
    • 例如,在使用 brain.js 时,可以通过 log 选项查看每次迭代的损失值变化。
  2. 准确率(Accuracy

    • 准确率是分类任务的常用指标,表示预测正确的样本占比。

    • 示例:

      • 总测试样本:100
      • 预测正确样本:90
      • 准确率 = (90 / 100) × 100% = 90%
  3. 误差(Error Rate

    • 误差率表示预测错误的样本占比。

    • 示例:

      • 总测试样本:100
      • 预测错误样本:10
      • 误差率 = (10 / 100) × 100% = 10%
  4. 其他指标

    • 精确率(Precision :预测为正的样本中有多少是真正的正例。
    • 召回率(Recall :所有实际正例中有多少被正确预测。
    • F1 分数(F1-Score :精确率和召回率的调和平均值。

6.2 验证集与测试集

为了更准确地评估模型性能,通常会将数据集分为以下三部分:

  1. 训练集(Training Set :用于调整模型的参数,让网络学习到数据的特征。

  2. 验证集(Validation Set

    • 训练过程中用于评估模型的泛化能力,帮助调整超参数(如隐藏层数量、学习率等)。
    • 如果验证集误差开始上升,可能表明模型出现了过拟合。
  3. 测试集(Test Set

    • 训练完成后,用于评估模型在全新数据上的表现。
    • 通过测试集误差,可以衡量模型的实际预测能力。

6.3 可视化训练过程

通过训练日志监控损失值的变化是直观了解模型学习效果的好方法。

  1. 在训练时启用日志记录:

    net.train(trainingData, {
      log: true,
      logPeriod: 100,
    });
    
  2. 记录损失值:

    const trainingLog = [];
    net.train(trainingData, {
      log: (error) => trainingLog.push(error),
      logPeriod: 10,
    });
    
  3. 可视化:

    使用常见绘图工具(如 matplotlib 或在线图表工具)绘制损失值变化曲线:

    • 横轴:迭代次数
    • 纵轴:损失值

6.4 改进模型性能的策略

如果模型性能不理想,可以尝试以下优化方法:

  1. 数据增强:增加数据量,或者对现有数据进行扩充(如增加随机噪声、改变数据排列顺序)。

  2. 调整网络结构

    • 增加或减少隐藏层数量。
    • 改变每层的节点数。
  3. 修改超参数

    • 调整学习率。
    • 增加训练的迭代次数。
  4. 引入正则化:使用 L1L2 正则化来防止过拟合。

  5. 提前停止(Early Stopping :当验证集误差开始上升时,立即停止训练。

七、优化与调整:如何提升神经网络的性能

在实际项目中,构建一个神经网络仅仅是开始,真正的挑战在于优化它的性能。通过合理的调整和优化策略,我们可以显著提升模型的训练速度、预测准确率以及泛化能力。


7.1 学习率调整

学习率(Learning Rate 决定了每次梯度下降时参数更新的步幅大小。选择合适的学习率是模型训练成功的关键之一。

  1. 学习率过高

    • 参数更新幅度过大,容易跳过最优解。
    • 可能导致损失值震荡甚至不收敛。
  2. 学习率过低

    • 更新步幅过小,训练速度非常慢。
    • 需要更多的迭代次数才能接近最优解。

优化策略

  • 动态学习率:在训练初期使用较大的学习率,加快收敛速度;训练后期逐渐降低学习率,精细优化模型。

  • 学习率衰减(Learning Rate Decay :随着迭代次数增加,按一定比例逐渐减小学习率。例如:

    ηt=η0×11+kt\eta_t = \eta_0 \times \frac{1}{1 + k \cdot t}
    • η0:初始学习率
    • k:衰减率
    • t:当前迭代次数

7.2 正则化

正则化是一种有效防止过拟合的技术。通过在损失函数中引入惩罚项,限制模型参数的大小,正则化能提高模型的泛化能力。

  1. L1 正则化

    • 惩罚权重的绝对值:

      L=Loss+λwL = Loss + \lambda \sum |w|
    • 会使部分权重变为零,从而稀疏化模型。

  2. L2 正则化(最常用):

    • 惩罚权重的平方:

      L=Loss+λw2L = Loss + \lambda \sum w^2
    • 使权重尽可能小,但不会完全变为零。

选择

  • L1 正则化适合希望模型较为稀疏的情况。
  • L2 正则化更常见,用于降低复杂度且保持较高的模型表达能力。

7.3 调整网络结构

  1. 增加或减少隐藏层数量

    • 增加隐藏层:网络能学习更复杂的特征,但也可能增加过拟合风险和训练难度。
    • 减少隐藏层:网络更简单,训练速度更快,但可能不足以处理复杂任务。
  2. 调整每层节点数

    • 更多节点:更强的表达能力,能捕获更复杂的模式。
    • 更少节点:减少计算量,同时降低过拟合风险。

经验法则:从简单的结构开始,逐渐增加复杂度,直到性能不再显著提高为止。


7.4 提前停止(Early Stopping

提前停止 是防止过拟合的一种有效方法。

原理:在训练过程中观察验证集的损失值。当验证损失值开始上升时,停止训练,防止模型继续学习训练集的噪声。

优势

  • 不需要额外的正则化项。
  • 可以节省训练时间。

实现

  • 每个 epoch 后计算验证集的损失值。
  • 设定一个耐心值(patience),例如连续 10 次验证损失值未降低,则终止训练。

7.5 数据增强

数据增强(Data Augmentation 是通过对现有数据集进行变换来生成更多样本的方法。它能显著提升模型的泛化能力,尤其是在样本数量有限的情况下。

  1. 图像数据

    • 随机旋转、平移、裁剪、水平翻转。
    • 改变亮度、对比度、饱和度。
  2. 文本数据

    • 替换同义词、改变句子结构。
    • 添加噪声(例如随机插入或删除单词)。
  3. 数值数据

    • 添加少量随机噪声。
    • 对数值特征进行比例变换。

好处

  • 增强样本多样性。
  • 减少过拟合。
  • 提高模型的鲁棒性。

7.6 使用 GPUTPU 加速

神经网络的训练通常非常耗时,尤其在大规模数据集上。利用 GPUTPU 可以显著加快训练速度。

方法

  • 在服务器上启用 GPU 加速。
  • 如果使用云平台(如 Google CloudAWS),选择支持 GPU 的实例。
  • 本地测试时,确保安装了合适的 GPU 驱动和库。

注意

  • 小型模型或少量数据集在 CPU 上可能已经足够快。
  • 大规模数据和深层网络更适合用 GPU 加速。

7.7 实践示例:改进 XOR 模型

以下是一个针对 XOR 模型的优化示例:

const brain = require('brain.js');
​
// 创建神经网络并设置隐藏层
const net = new brain.NeuralNetwork({
  hiddenLayers: [5, 5], // 调整隐藏层结构
});
​
// 定义训练数据
const trainingData = [
  { input: [0, 0], output: [0] },
  { input: [0, 1], output: [1] },
  { input: [1, 0], output: [1] },
  { input: [1, 1], output: [0] },
];
​
// 训练网络
net.train(trainingData, {
  iterations: 10000,
  learningRate: 0.01,
  log: true,
  errorThresh: 0.001,
});
​
// 测试网络
const output = net.run([1, 0]);
console.log(`预测结果: ${output}`);

改进点

  • 增加隐藏层节点数:从 3 增加到 5,提升了模型的表达能力。
  • 降低误差阈值:从 0.005 调整到 0.001,提高了模型精度。
  • 适当减少学习率:从 0.05 降低到 0.01,保证了训练的稳定性。

八、实际应用:brain.js 在项目中的使用

在了解了 brain.js 的基本功能和优化方法之后,接下来我们将聚焦于如何将 brain.js 应用于实际项目。这部分内容通过简单的代码示例和实际场景,展示 brain.js 在分类、回归和时间序列预测等任务中的应用。


8.1 分类任务:垃圾邮件检测

场景描述: 垃圾邮件检测是一个典型的分类问题。输入是一封邮件的特征(例如是否包含特定关键词),输出是“垃圾邮件”或“正常邮件”。

代码示例

const brain = require('brain.js');
​
// 创建神经网络
const net = new brain.NeuralNetwork({
  hiddenLayers: [3], // 一个简单的隐藏层
});
​
// 定义训练数据
const trainingData = [
  { input: { spamWords: 1, containsLink: 1 }, output: { spam: 1 } },
  { input: { spamWords: 0, containsLink: 1 }, output: { spam: 0 } },
  { input: { spamWords: 1, containsLink: 0 }, output: { spam: 1 } },
  { input: { spamWords: 0, containsLink: 0 }, output: { spam: 0 } },
];
​
// 训练网络
net.train(trainingData, {
  iterations: 2000,
  log: true,
  errorThresh: 0.01,
});
​
// 测试网络
const output = net.run({ spamWords: 1, containsLink: 1 });
console.log(`是否垃圾邮件: ${output.spam > 0.5 ? '是' : '否'}`);

分析

  • 输入数据包括两个特征:

    • spamWords:邮件是否包含垃圾词汇(1 或 0)
    • containsLink:邮件是否包含链接(1 或 0)。
  • 输出结果是一个概率值,接近 1 表示垃圾邮件,接近 0 表示正常邮件。


8.2 回归任务:房价预测

场景描述: 在回归任务中,输出是一个连续值。例如,输入房屋的特征,输出房屋的预测价格。

代码示例

const brain = require('brain.js');
​
// 创建神经网络
const net = new brain.NeuralNetwork({
  hiddenLayers: [5],
});
​
// 定义训练数据
const trainingData = [
  { input: { size: 0.5, rooms: 0.8, distance: 0.2 }, output: [0.7] }, // 房价 70 万
  { input: { size: 0.3, rooms: 0.5, distance: 0.6 }, output: [0.5] }, // 房价 50 万
  { input: { size: 0.8, rooms: 1.0, distance: 0.1 }, output: [0.9] }, // 房价 90 万
];
​
// 训练网络
net.train(trainingData, {
  iterations: 3000,
  log: true,
  errorThresh: 0.01,
});
​
// 测试网络
const output = net.run({ size: 0.6, rooms: 0.9, distance: 0.3 });
console.log(`预测房价: ${Math.round(output[0] * 100)} 万`);

分析

  • 输入特征包括房屋面积、房间数和到市中心的距离。
  • 输出是一个归一化的价格(例如 0.7 表示 70 万),可以通过反归一化还原为实际值。

8.3 时间序列任务:股票价格预测

场景描述: 时间序列任务需要预测未来的趋势。brain.js 支持使用 LSTM(长短时记忆网络)处理时间序列数据,例如预测股票价格的下一步走势。

代码示例

const brain = require('brain.js');
​
// 创建 LSTM 网络
const net = new brain.recurrent.LSTMTimeStep();
​
// 定义训练数据
const trainingData = [
  [10, 20, 30, 40],  // 数据集 1
  [15, 25, 35, 45],  // 数据集 2
  [20, 30, 40, 50],  // 数据集 3
];
​
// 训练网络
net.train(trainingData, {
  iterations: 2000,
  log: true,
  errorThresh: 0.01,
});
​
// 测试网络
const output = net.run([25, 35, 45]); // 输入过去三天的价格
console.log(`预测下一天的价格: ${output}`);

分析

  • 输入数据是过去几天的价格序列,网络通过学习这些时间依赖关系,预测出下一天的价格。
  • LSTM 适合处理时间序列中复杂的依赖模式。

8.4 集成 brain.js 到实际项目中

在实际项目中,神经网络通常需要与其他系统结合使用。例如,可以将 brain.js 模型嵌入到一个 Web 应用中,实现实时预测。

代码示例:结合 Express 实现分类任务

const express = require('express');
const brain = require('brain.js');
const app = express();
​
// 创建神经网络
const net = new brain.NeuralNetwork();
net.train([
  { input: { feature1: 1, feature2: 0 }, output: { category1: 1 } },
  { input: { feature1: 0, feature2: 1 }, output: { category2: 1 } },
]);
​
// 提供预测 API
app.get('/predict', (req, res) => {
  const input = { feature1: parseFloat(req.query.feature1), feature2: parseFloat(req.query.feature2) };
  const output = net.run(input);
  res.send(output);
});
​
// 启动服务
app.listen(3000, () => {
  console.log('服务已启动,访问 http://localhost:3000/predict');
});

运行步骤

  1. 启动服务:node index.js
  2. 在浏览器访问:http://localhost:3000/predict?feature1=1&feature2=0

九、展望

尽管 brain.js 在入门和轻量化应用中表现出色,但在面对更复杂的深度学习任务时,其功能和性能可能会受到限制。因此,在深入学习神经网络时,可以考虑以下方向:

  1. 探索更强大的框架

    • 学习 TensorFlow.jsPyTorch 等框架,它们提供了更加丰富的神经网络结构和工具支持。
    • 尝试卷积神经网络(CNN)、生成对抗网络(GAN)等复杂模型,扩展对神经网络的理解和应用范围。
  2. 深入理解优化方法

    • 学习更高级的优化算法,例如 AdamRMSProp
    • 探索超参数搜索(Hyperparameter Tuning)的方法,找到更优的网络配置。
  3. 实际项目实践

    • 将神经网络应用于真实业务场景,如推荐系统、图像处理、文本生成等。
    • 结合领域知识,开发针对性更强的模型。
  4. 关注神经网络的发展趋势

    • 跟踪最新研究成果,例如自监督学习(Self-Supervised Learning)、元学习(Meta-Learning)。
    • 探索神经网络与其他技术的结合点,如增强学习、迁移学习等。

最后的话

从理解神经网络的基本原理,到使用 brain.js 实现模型,再到优化和扩展,本文为读者提供了一个清晰的学习路径。无论你是初学者还是希望深入探索神经网络的开发者,这些内容都可以作为一个起点。

如果你觉得这篇文章对你有帮助,欢迎点赞、收藏或分享给更多朋友!也欢迎在评论区留下你的想法或问题,我们一起交流讨论!😊

系列文章推荐

如果你想更系统地学习如何使用 brain.js 构建神经网络,可以查看我的系列教程:

👉 从零开始:使用 Brain.js 创建你的第一个神经网络(一)