PS:点赞,评论,收藏,分享 防止迷路
有人的地方就有江湖,有编程的地方就有js。大模型不仅仅是服务端的事情,今天我们使用js 大模型库来实现通过描述明星的特征识别具体是哪个明星,坐稳发车
一 Brain.js 训练神经网络的 JavaScript 库
Brain.js 是一个用于训练神经网络的 JavaScript 库,可以在浏览器或 Node.js 环境中运行。它主要被用来进行机器学习任务,如分类、回归分析等。以下是 Brain.js 的几个关键特点:
- 易于使用:Brain.js 提供了简单易用的 API,使得开发者即使没有深厚的机器学习背景也能快速上手。
- 灵活性:支持自定义网络结构,可以根据具体需求调整网络层和参数。
- 性能优化:利用 WebGL 进行硬件加速,提高训练速度。
- 社区活跃:拥有活跃的开源社区,提供了丰富的示例和文档。
安装
在 Node.js 环境中安装 Brain.js 可以通过 npm 来完成:
npm install brain.js
基本使用示例
以下是一个简单的示例,展示如何使用 Brain.js 训练一个基本的神经网络模型:
const brain = require('brain.js');
// 创建一个新的神经网络实例
const net = new brain.NeuralNetwork();
// 准备训练数据
const trainingData = [
{ input: [0, 0], output: [0] },
{ input: [0, 1], output: [1] },
{ input: [1, 0], output: [1] },
{ input: [1, 1], output: [0] }
];
// 训练神经网络
net.train(trainingData);
// 使用训练好的模型进行预测
const output = net.run([1, 0]); // 输出接近 1
console.log(output);
常见应用场景
- 图像识别:虽然 Brain.js 主要适用于较小规模的数据集,但它也可以用于简单的图像识别任务。
- 文本分类:可以用于情感分析、垃圾邮件过滤等文本分类任务。
- 时间序列预测:可用于股票价格预测、天气预报等时间序列数据分析。
二 实战
步骤
- 定义明星的名字和对应的标签。
- 将文本描述转换为数值特征。
- 准备训练数据。
- 训练神经网络。
- 使用训练好的模型进行预测。
- 将预测结果转换回明星的名字。
const brain = require("brain.js");
// 定义明星的名字和对应的标签
const stars = ["刘德华", "张学友", "郭富城"];
// 创建一个新的神经网络实例,禁用 GPU 加速
const net = new brain.NeuralNetwork({
gpu: false, //是否开启gpu加速
});
// 定义特征词典
const featureDictionary = {
高鼻梁: [1, 0, 0],
大眼睛: [0, 1, 0],
小眼睛: [0, 0, 1],
短发: [1, 0, 0],
长发: [0, 1, 0],
圆脸: [0, 0, 1],
方脸: [1, 0, 0],
尖下巴: [0, 1, 0],
厚嘴唇: [0, 0, 1],
// 可以根据需要添加更多特征
};
// 将文本描述转换为数值特征
function textToFeatures(text) {
const features = [0, 0, 0, 0, 0, 0, 0, 0]; // 初始化特征向量
const words = text.split(" ");
words.forEach((word) => {
if (featureDictionary[word]) {
featureDictionary[word].forEach((value, index) => {
features[index] += value;
});
}
});
return features;
}
// 准备训练数据
const trainingData = [
{ input: textToFeatures("高鼻梁 大眼睛"), output: [1, 0, 0] }, // 刘德华
{ input: textToFeatures("高鼻梁 小眼睛"), output: [0, 1, 0] }, // 张学友
{ input: textToFeatures("短发 圆脸"), output: [0, 0, 1] }, // 郭富城
{ input: textToFeatures("长发 方脸"), output: [1, 0, 0] }, // 刘德华
{ input: textToFeatures("尖下巴 厚嘴唇"), output: [0, 1, 0] }, // 张学友
{ input: textToFeatures("高鼻梁 短发"), output: [0, 0, 1] }, // 郭富城
];
net.train(trainingData, {
errorThresh: 0.005, // error threshold to reach 训练过程中允许的最大误差阈值。当误差低于这个值时,训练停止。
iterations: 20000, // maximum training iterations 最大训练迭代次数。即使误差没有达到 `errorThresh`,也会在达到最大迭代次数时停止训练。
log: true, // console.log() progress periodically 是否在控制台中记录训练进度。
logPeriod: 10, // number of iterations between logging 记录训练进度的频率,单位为迭代次数。例如,每 10 次迭代记录一次。
learningRate: 0.3, // learning rate 学习率,控制权重更新的速度。较大的学习率可能导致训练不稳定,较小的学习率可能导致训练缓慢。
});
// 使用训练好的模型进行预测
const inputText = "高鼻梁 大眼睛";
const inputFeatures = textToFeatures(inputText);
const output = net.run(inputFeatures);
console.log(`预测特征输出: ${output}`);
// 将预测结果转换回明星的名字
const predictedStarIndex = output.indexOf(Math.max(...output));
const predictedStarName = stars[predictedStarIndex];
console.log(`预测的明星是: ${predictedStarName}`);
运行后数据
node .\brainjs\index.js
三 参数解释
net.train()
方法用于训练神经网络,它接受两个参数:训练数据和可选的训练选项。
参数详细解释:
-
trainingData:
- 类型: Array
- 描述: 包含训练样本的数组。每个样本是一个对象,包含
input
和output
属性。 - 示例:
const trainingData = [ { input: [0, 0], output: [0] }, { input: [0, 1], output: [1] }, { input: [1, 0], output: [1] }, { input: [1, 1], output: [0] } ];
-
options (可选):
- 类型: Object
- 描述: 一个包含训练选项的对象。这些选项可以控制训练过程的行为。
- 常用选项:
- errorThresh:
- 类型: Number
- 默认值: 0.005
- 描述: 训练过程中允许的最大误差阈值。当误差低于这个值时,训练停止。
- iterations:
- 类型: Number
- 默认值: 20000
- 描述: 最大训练迭代次数。即使误差没有达到
errorThresh
,也会在达到最大迭代次数时停止训练。
- log:
- 类型: Boolean
- 默认值: false
- 描述: 是否在控制台中记录训练进度。
- logPeriod:
- 类型: Number
- 默认值: 10
- 描述: 记录训练进度的频率,单位为迭代次数。例如,每 10 次迭代记录一次。
- learningRate:
- 类型: Number
- 默认值: 0.3
- 描述: 学习率,控制权重更新的速度。较大的学习率可能导致训练不稳定,较小的学习率可能导致训练缓慢。
- momentum:
- 类型: Number
- 默认值: 0.1
- 描述: 动量项,有助于加速梯度下降过程,特别是在平坦区域。
- callback:
- 类型: Function
- 默认值: null
- 描述: 在每次迭代后调用的回调函数。可以用于自定义日志记录或其他操作。
- callbackPeriod:
- 类型: Number
- 默认值: 10
- 描述: 调用回调函数的频率,单位为迭代次数。
- timeout:
- 类型: Number
- 默认值: Infinity
- 描述: 训练超时时间,单位为毫秒。如果训练时间超过这个值,训练将被终止。
- errorThresh:
示例
代码中,net.train()
的调用如下:
net.train(trainingData, {
errorThresh: 0.005, // 错误阈值
iterations: 20000, // 最大训练迭代次数
log: true, // 控制台记录训练进度
logPeriod: 10, // 每 10 次迭代记录一次
learningRate: 0.3, // 学习率
});
解释
- errorThresh: 0.005: 当训练误差低于 0.005 时,训练停止。
- iterations: 20000: 最多进行 20000 次迭代,即使误差没有达到 0.005。
- log: true: 在控制台中记录训练进度。
- logPeriod: 10: 每 10 次迭代记录一次训练进度。
- learningRate: 0.3: 学习率为 0.3,控制权重更新的速度。
希望这篇文章对你有帮助。如果有其他问题或需要进一步的说明,评论区交流!
PS:摸鱼创作不易 学会了记得,点赞,评论,收藏,分享