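Preface
There are already plenty of introductions to neural networks online, so this article skips the theory and focuses on the concrete steps plus one simple worked example. The example is built with vite + react, and the neural network library is @tensorflow/tfjs.
(Screenshot: the final page.)
Steps
Let's first run through the steps of training a neural network: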
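- Prepare the dataset
- Define the model
  - Number of layers
  - Number of nodes per layer
  - Activation function of each layer
- Compile the model
  - Learning rate
  - Loss function
- Train the model
- Use the model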
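Now let's build a neural network step by step. The example in this article is iris classification.
Iris classification is a classic machine learning example. The iris dataset records four measurements per flower: the length and width of the sepal and the length and width of the petal.
The dataset contains 150 records, each labeled with one iris species. There are three species: Iris setosa, Iris versicolor, and Iris virginica. Sample rows look like this: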
5.1,3.3,1.7,0.5,setosa
5.7,2.8,4.1,1.3,versicolor
6.3,3.3,6.0,2.5,virginica
We need to train a model that can determine the species of a flower from the length and width of its sepal and petal.
Preparing the dataset
- Read the txt file, parse it into an array of objects, and shuffle it
const parseIrisData = (data: string): IrisData[] => {
  return data
    .split(/\r?\n/)
    .filter((line) => line.trim())
    .map((line) => {
      const [sepalL, sepalW, petalL, petalW, species] = line.split(",");
      return {
        features: [sepalL, sepalW, petalL, petalW].map(Number),
        label: species.trim() as IrisSpecies,
      };
    });
};
const response = await fetch("src/assets/iris.txt");
const data = await response.text();
const parsedData = parseIrisData(data); // array of 150 objects shaped like { features, label }
tf.util.shuffle(parsedData); // shuffle in place
setTestSamples(parsedData.slice(140)); // keep the last 10 as the test set
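The parser above trusts every line: an empty field silently becomes 0 and an unknown species name is cast straight to `IrisSpecies`. A minimal sketch of a stricter variant follows, assuming the same comma-separated layout. The names `parseIrisLineSafe` and `KNOWN_SPECIES` are hypothetical and do not appear in the original code.

```typescript
type IrisSpecies = "setosa" | "versicolor" | "virginica";

interface IrisData {
  features: number[];
  label: IrisSpecies;
}

// Hypothetical list of accepted labels (same values as the article's SPECIES).
const KNOWN_SPECIES: readonly string[] = ["setosa", "versicolor", "virginica"];

// Hypothetical helper: parse one line, returning null for malformed rows.
const parseIrisLineSafe = (line: string): IrisData | null => {
  const parts = line.split(",").map((part) => part.trim());
  if (parts.length !== 5) return null;
  const rawFeatures = parts.slice(0, 4);
  // Number("") is 0, so empty fields have to be rejected explicitly.
  if (rawFeatures.some((part) => part === "")) return null;
  const features = rawFeatures.map(Number);
  if (features.some((value) => !Number.isFinite(value))) return null;
  const species = parts[4];
  if (KNOWN_SPECIES.indexOf(species) === -1) return null;
  return { features, label: species as IrisSpecies };
};

const sampleRows = [
  "5.1,3.3,1.7,0.5,setosa",
  "5.7,2.8,4.1,1.3,versicolor",
  "oops,1,2,3,tulip", // deliberately malformed row
];
const validRows = sampleRows
  .map(parseIrisLineSafe)
  .filter((row): row is IrisData => row !== null);
console.log(validRows.length); // 2
```

For a clean, well-known dataset like this one the extra checks are optional; they matter more once the file comes from user uploads or hand-edited data.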
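- Convert the data into two tensors (features and labels), one-hot encoding the labels
What is one-hot? Suppose our vocabulary has five words: ['I', 'you', 'he', 'she', 'it']. Then 'I' can be represented as [1, 0, 0, 0, 0], 'you' as [0, 1, 0, 0, 0], and so on. Here the three species play the role of the vocabulary.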
const SPECIES: IrisSpecies[] = ["setosa", "versicolor", "virginica"];
const NUM_CLASSES = SPECIES.length;
const convertToTensor = (data: IrisData[]) => ({
  xs: tf.tensor2d(data.map((d) => d.features)),
  ys: tf.oneHot(
    tf.tensor1d(
      data.map((d) => SPECIES.indexOf(d.label)),
      "int32"
    ),
    NUM_CLASSES
  ),
});
convertToTensor(parsedData.slice(0, 140)); // use the first 140 records for training and validation
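To make the one-hot step less of a black box, here is a small sketch of the same encoding in plain TypeScript, without tfjs. It only illustrates what `tf.oneHot` produces for these three labels; the helper names `oneHot` and `fromOneHot` are made up for this example.

```typescript
const SPECIES = ["setosa", "versicolor", "virginica"] as const;
type IrisSpecies = (typeof SPECIES)[number];

// One-hot encode a label: 1 at the label's index, 0 everywhere else.
const oneHot = (label: IrisSpecies): number[] =>
  SPECIES.map((species) => (species === label ? 1 : 0));

// Inverse operation: the position of the 1 gives the label back.
const fromOneHot = (vector: number[]): IrisSpecies =>
  SPECIES[vector.indexOf(1)];

console.log(oneHot("setosa")); // [ 1, 0, 0 ]
console.log(oneHot("versicolor")); // [ 0, 1, 0 ]
console.log(fromOneHot([0, 0, 1])); // virginica
```

The order of `SPECIES` defines which position means which species, so the same array has to be used later when turning predictions back into labels.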
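Defining the model
const createModel = () => {
  // tf.sequential creates a sequential model, in which each layer's output is the next layer's input
  return tf.sequential({
    // Any number of layers is allowed, but one or two is usually enough; four are written here as an example
    layers: [
      tf.layers.dense({
        units: 6, // number of neurons; pick roughly 1-3x the number of inputs
        activation: "relu", // activation function; any non-linear function works, relu is the usual default
        inputShape: [4], // the first layer must declare the input size, here 4 features
      }),
      tf.layers.dense({
        units: MODEL_CONFIG.hiddenUnits, // MODEL_CONFIG is defined in the full source (hiddenUnits: 6)
        activation: "relu",
      }),
      tf.layers.dense({
        units: MODEL_CONFIG.hiddenUnits,
        activation: "relu",
      }),
      // Output layer
      tf.layers.dense({
        units: 3, // for classification, the output layer has one unit per class: 3 species, so 3
        activation: "softmax", // for classification, the last activation should be softmax, which yields class probabilities
      }),
    ],
  });
};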
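The output layer above relies on softmax to turn three raw scores into probabilities. As a rough illustration of what that activation computes, here is a standalone sketch in plain TypeScript. The input scores are made up, and this is not how tfjs implements it internally.

```typescript
// Softmax: exponentiate each score, then normalize so the results sum to 1.
const softmax = (scores: number[]): number[] => {
  const maxScore = Math.max(...scores); // subtracting the max avoids overflow in Math.exp
  const exps = scores.map((score) => Math.exp(score - maxScore));
  const total = exps.reduce((sum, value) => sum + value, 0);
  return exps.map((value) => value / total);
};

// Made-up raw outputs for the three species.
const probabilities = softmax([2.0, 1.0, 0.1]);
console.log(probabilities.map((p) => p.toFixed(3))); // [ '0.659', '0.242', '0.099' ]
```

The largest score always gets the largest probability, which is why the prediction step later only needs to pick the index of the maximum.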
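Compiling the model
const model = createModel();
model.compile({
  optimizer: tf.train.adam(0.01), // learning rate; start at 0.01 and tune: multiply by 10 if learning is too slow, divide by 10 if the steps are so large that results get worse
  loss: "categoricalCrossentropy", // loss function for multi-class classification
  metrics: ["accuracy"], // metric to report; accuracy is the fraction of correct predictions
});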
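For intuition about the `categoricalCrossentropy` loss chosen above, the sketch below computes it by hand for a single sample: the negative log of the probability the model assigned to the true class. This is a simplified illustration in plain TypeScript with made-up predictions; the small `epsilon` clamp is an assumption of this sketch.

```typescript
// Categorical cross-entropy for one sample: -sum(yTrue[i] * log(yPred[i])).
const categoricalCrossentropy = (yTrue: number[], yPred: number[]): number => {
  const epsilon = 1e-7; // clamp predictions so Math.log never receives 0
  return -yTrue.reduce(
    (sum, y, i) => sum + y * Math.log(Math.max(yPred[i], epsilon)),
    0
  );
};

const target = [0, 1, 0]; // one-hot label for versicolor
const confidentLoss = categoricalCrossentropy(target, [0.05, 0.9, 0.05]);
const unsureLoss = categoricalCrossentropy(target, [0.4, 0.3, 0.3]);
console.log(confidentLoss.toFixed(3)); // 0.105
console.log(unsureLoss.toFixed(3)); // 1.204
```

A confident correct prediction yields a loss near 0, while spreading probability over the wrong classes drives the loss up, which is the signal the optimizer uses during training.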
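Training the model
await model.fit(train.xs, train.ys, {
  epochs: 100, // number of passes over the data; start with 100, raise it if the model is undertrained, lower it if it converges early
  batchSize: 32, // samples per gradient update; 32 is a common default
  validationSplit: 0.2, // fraction of the data held out for validation; 0.2 means 20%
  callbacks: {
    onEpochEnd: (epoch, logs) => {
      if (!logs) return;
      // Record a log entry at the end of every epoch
      setTrainingLogs((prev) => [
        ...prev,
        {
          epoch: epoch + 1,
          loss: Number(logs.loss?.toFixed(4)) || 0,
          accuracy: Number(logs.acc?.toFixed(4)) || 0, // tfjs logs the "accuracy" metric under the key acc
        },
      ]);
    },
  },
});
setModel(model); // keep the trained model for later use
// Export
const exportModel = async () => {
  if (!model) return;
  // Triggers a browser download of iris-model.json and iris-model.weights.bin
  await model.save("downloads://iris-model");
};
// Import (the weights file must sit next to the json file)
const model = await tf.loadLayersModel("src/assets/iris-model.json");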
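It can help to work out what the `validationSplit`, `batchSize`, and `epochs` settings passed to `model.fit` mean in numbers. The sketch below does that arithmetic for the 140 records used here. It assumes the validation samples are set aside before batching and that a final partial batch still counts as one update; `planFit` is a hypothetical helper, not a tfjs function.

```typescript
interface FitPlan {
  trainCount: number;
  valCount: number;
  batchesPerEpoch: number;
  totalUpdates: number;
}

// Hypothetical helper: estimate how the data is divided during training.
const planFit = (
  samples: number,
  validationSplit: number,
  batchSize: number,
  epochs: number
): FitPlan => {
  const valCount = Math.round(samples * validationSplit);
  const trainCount = samples - valCount;
  const batchesPerEpoch = Math.ceil(trainCount / batchSize);
  return {
    trainCount,
    valCount,
    batchesPerEpoch,
    totalUpdates: batchesPerEpoch * epochs,
  };
};

const plan = planFit(140, 0.2, 32, 100);
console.log(plan);
// { trainCount: 112, valCount: 28, batchesPerEpoch: 4, totalUpdates: 400 }
```

So with these settings each epoch is only about four gradient updates, which is part of why a hundred epochs still finish quickly in the browser.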
Using the model
const predict = (features: number[]) => {
  const currentModel = model;
  if (!currentModel) return;
  const input = tf.tensor2d([features]);
  const prediction = currentModel.predict(input) as tf.Tensor; // a classification model returns a tensor of per-class probabilities
  const probabilities = Array.from(prediction.dataSync()); // convert the tensor into a plain array
  const indexTensor = prediction.argMax(1);
  const predictedIndex = indexTensor.dataSync()[0]; // index of the label with the highest probability
  // Show the result
  setPrediction({
    label: SPECIES[predictedIndex],
    confidence: Math.round(probabilities[predictedIndex] * 100),
  });
  tf.dispose([input, prediction, indexTensor]); // free the tensors created for this prediction
};
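The last few lines of `predict` boil down to "find the largest probability and report its label". Here is that logic as a standalone sketch in plain TypeScript, using a made-up probability array in place of the model output; `interpret` is a hypothetical name.

```typescript
const SPECIES = ["setosa", "versicolor", "virginica"] as const;

// Hypothetical helper: pick the most probable class and express its probability as a percentage.
const interpret = (probabilities: number[]) => {
  let best = 0;
  for (let i = 1; i < probabilities.length; i++) {
    if (probabilities[i] > probabilities[best]) best = i;
  }
  return {
    label: SPECIES[best],
    confidence: Math.round(probabilities[best] * 100),
  };
};

const result = interpret([0.02, 0.91, 0.07]); // made-up model output
console.log(result); // { label: 'versicolor', confidence: 91 }
```

Note that the "confidence" here is just the softmax probability of the winning class; it says how strongly the model prefers that class, not how likely the model is to be right.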
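At this point you have trained a complete neural network.
Why bother
Looking back over the process, it is not as hard as you might expect. The library quietly takes care of most of the complicated parts for you.
Many of the choices, such as the number of layers, the activation function, the number of nodes, and the number of epochs, are largely empirical and have no fixed standard. Pick rough values, run training a couple of times, and adjust based on the loss and accuracy you see. The goal is to get the best model for the least compute and time.
At this point someone will ask: imoo, imoo, what is the point of a front-end developer learning all this?
I don't know. I hope that one day I can build something impressive too. Maybe no step along the way is easy, but you have to keep exploring, and you have to take the shot.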
Appendix: full source code
import { useState } from "react";
import * as tf from "@tensorflow/tfjs";
interface IrisData {
  features: number[];
  label: IrisSpecies;
}
type IrisSpecies = "setosa" | "versicolor" | "virginica";
// Species list and class count
const SPECIES: IrisSpecies[] = ["setosa", "versicolor", "virginica"];
const NUM_CLASSES = SPECIES.length;
// Model configuration constants
const MODEL_CONFIG = {
  hiddenUnits: 6,
  learningRate: 0.05,
  epochs: 100,
  batchSize: 32,
  validationSplit: 0.2,
};
// Parse the raw text into IrisData records
const parseIrisData = (data: string): IrisData[] => {
  return data
    .split(/\r?\n/)
    .filter((line) => line.trim())
    .map((line) => {
      const [sepalL, sepalW, petalL, petalW, species] = line.split(",");
      return {
        features: [sepalL, sepalW, petalL, petalW].map(Number),
        label: species.trim() as IrisSpecies,
      };
    });
};
// Convert records into a feature tensor and a one-hot label tensor
const convertToTensor = (data: IrisData[]) => ({
  xs: tf.tensor2d(data.map((d) => d.features)),
  ys: tf.oneHot(
    tf.tensor1d(
      data.map((d) => SPECIES.indexOf(d.label)),
      "int32"
    ),
    NUM_CLASSES
  ),
});
function App() {
  const [model, setModel] = useState<tf.Sequential>();
  const [prediction, setPrediction] = useState<{
    label: string;
    confidence: number;
  }>();
  const [trainingLogs, setTrainingLogs] = useState<
    { epoch: number; loss: number; accuracy: number }[]
  >([]);
  const [testSamples, setTestSamples] = useState<IrisData[]>([]);
  const [selectedTestIndex, setSelectedTestIndex] = useState<number>(-1);
  const [loadedModel, setLoadedModel] = useState<tf.LayersModel>();
  const [isLoadingModel, setIsLoadingModel] = useState(false);
  const [loadError, setLoadError] = useState<string>();
  const [selectedModelType, setSelectedModelType] = useState<
    "trained" | "loaded"
  >("trained");
  // Inline style constants
  const styles = {
    container: {
      maxWidth: 1200,
      margin: "0 auto",
      padding: 40,
      fontFamily: "'Segoe UI', sans-serif",
    },
    title: {
      color: "#2c3e50",
      textAlign: "center",
      marginBottom: 40,
    },
    controlSection: {
      display: "flex",
      gap: 16,
      alignItems: "center",
      marginBottom: 40,
    },
    trainButton: {
      backgroundColor: "#3498db",
      color: "white",
      border: "none",
      padding: "12px 24px",
      borderRadius: 4,
      cursor: "pointer",
    },
    statusIndicator: {
      width: 10,
      height: 10,
      backgroundColor: "#27ae60",
      borderRadius: "50%",
    },
    section: {
      backgroundColor: "#f8f9fa",
      padding: 24,
      borderRadius: 8,
      marginBottom: 40,
      boxShadow: "0 2px 4px rgba(0,0,0,0.1)",
    },
    logItem: {
      display: "flex",
      gap: 16,
    },
    sampleSelect: {
      padding: 8,
      border: "1px solid #ddd",
      borderRadius: 4,
      minWidth: 250,
      marginTop: 16,
    },
    predictionResult: (isCorrect: boolean) => ({
      marginTop: 16,
      padding: 16,
      borderRadius: 4,
      backgroundColor: "#fff",
      borderLeft: `4px solid ${isCorrect ? "#27ae60" : "#e74c3c"}`,
    }),
    comparison: {
      display: "grid",
      gridTemplateColumns: "1fr 1fr",
      gap: 16,
      marginTop: 16,
    },
    valueText: {
      fontWeight: "bold",
      color: "#2c3e50",
      fontSize: "1.1rem",
    },
    loadButton: {
      backgroundColor: "#9b59b6",
      color: "white",
      border: "none",
      padding: "12px 24px",
      borderRadius: 4,
      cursor: "pointer",
    },
    errorText: {
      color: "#e74c3c",
      marginLeft: 8,
    },
  };
  // Shared data loading: fetch, parse, and shuffle the dataset
  const loadAndPrepareData = async () => {
    const response = await fetch("src/assets/iris.txt");
    const data = await response.text();
    const parsedData = parseIrisData(data);
    tf.util.shuffle(parsedData);
    // Test samples are fixed on the first load; a later reshuffle may put some of them into the training slice
    if (testSamples.length === 0) {
      setTestSamples(parsedData.slice(140));
      setSelectedTestIndex(0);
    }
    return parsedData;
  };
  // Load the data and convert the first 140 records into training tensors
  const loadData = async () => {
    const parsedData = await loadAndPrepareData();
    return convertToTensor(parsedData.slice(0, 140));
  };
  // Build the model: three hidden relu layers and a softmax output layer
  const createModel = () => {
    return tf.sequential({
      layers: [
        tf.layers.dense({
          units: MODEL_CONFIG.hiddenUnits,
          activation: "relu",
          inputShape: [4],
        }),
        tf.layers.dense({
          units: MODEL_CONFIG.hiddenUnits,
          activation: "relu",
        }),
        tf.layers.dense({
          units: MODEL_CONFIG.hiddenUnits,
          activation: "relu",
        }),
        tf.layers.dense({
          units: NUM_CLASSES,
          activation: "softmax",
        }),
      ],
    });
  };
  const trainModel = async () => {
    setTrainingLogs([]);
    const model = createModel();
    model.compile({
      optimizer: tf.train.adam(MODEL_CONFIG.learningRate),
      loss: "categoricalCrossentropy",
      metrics: ["accuracy"],
    });
    const train = await loadData();
    await model.fit(train.xs, train.ys, {
      epochs: MODEL_CONFIG.epochs,
      batchSize: MODEL_CONFIG.batchSize,
      validationSplit: MODEL_CONFIG.validationSplit,
      callbacks: {
        onEpochEnd: (epoch, logs) => {
          if (!logs) return;
          setTrainingLogs((prev) => [
            ...prev,
            {
              epoch: epoch + 1,
              loss: Number(logs.loss?.toFixed(4)) || 0,
              accuracy: Number(logs.acc?.toFixed(4)) || 0, // tfjs logs the "accuracy" metric under the key acc
            },
          ]);
        },
      },
    });
    setModel(model);
  };
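  // Load a previously exported model from the assets folder
  const loadExternalModel = async () => {
    setIsLoadingModel(true);
    setLoadError(undefined);
    try {
      const model = await tf.loadLayersModel("src/assets/iris-model.json");
      const parsedData = await loadAndPrepareData();
      if (parsedData.length > 0) {
        // Note: predict() reads model/loadedModel from the state captured by this render,
        // so the model loaded just above is not the one used by this call
        predict(parsedData[140].features);
      }
      setLoadedModel(model);
      alert("Model loaded successfully!");
    } catch (error) {
      console.error("Failed to load model:", error);
      setLoadError("Loading failed; check that the model file exists");
    }
    setIsLoadingModel(false);
  };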
  const predict = (features: number[], modelType = selectedModelType) => {
    const currentModel = modelType === "trained" ? model : loadedModel;
    if (!currentModel) return;
    const input = tf.tensor2d([features]);
    const prediction = currentModel.predict(input) as tf.Tensor; // a classification model returns per-class probabilities
    const probabilities = Array.from(prediction.dataSync()); // convert the probabilities into a plain array
    const indexTensor = prediction.argMax(1);
    const predictedIndex = indexTensor.dataSync()[0]; // index of the class with the highest probability
    setPrediction({
      label: SPECIES[predictedIndex],
      confidence: Math.round(probabilities[predictedIndex] * 100),
    });
    tf.dispose([input, prediction, indexTensor]); // free the tensors created for this prediction
  };
  // Export the trained model
  const exportModel = async () => {
    if (!model) return;
    // Save the model as a file download (iris-model.json plus iris-model.weights.bin)
    await model.save("downloads://iris-model");
  };
  // Status indicator component
  const StatusIndicator = ({
    color,
    label,
  }: {
    color: string;
    label: string;
  }) => (
    <div style={{ display: "flex", alignItems: "center", gap: 8 }}>
      <div style={{ ...styles.statusIndicator, backgroundColor: color }} />
      <span>{label}</span>
    </div>
  );
  return (
    <div style={styles.container}>
      <h1>Iris Classifier</h1>
      <div style={styles.controlSection}>
        <button style={styles.trainButton} onClick={trainModel}>
          {model ? "Retrain model" : "Start training"}
        </button>
        {model && (
          <>
            <button
              style={{ ...styles.trainButton, backgroundColor: "#2ecc71" }}
              onClick={exportModel}
            >
              Export model
            </button>
            <StatusIndicator color="#27ae60" label="Model ready" />
          </>
        )}
        <button
          style={styles.loadButton}
          onClick={loadExternalModel}
          disabled={isLoadingModel}
        >
          {isLoadingModel ? "Loading model..." : "Load external model"}
        </button>
        {loadError && <span style={styles.errorText}>{loadError}</span>}
        {loadedModel && (
          <StatusIndicator color="#9b59b6" label="External model loaded" />
        )}
      </div>
      <div style={styles.section}>
        <h2>Training progress</h2>
        <div style={{ maxHeight: "200px", overflowY: "auto" }}>
          {trainingLogs.map((log) => (
            <div key={log.epoch} style={styles.logItem}>
              <div>Epoch {log.epoch}</div>
              <div>Loss: {log.loss}</div>
              <div>Accuracy: {(log.accuracy * 100).toFixed(1)}%</div>
            </div>
          ))}
          {!trainingLogs.length && <div>Training has not started yet</div>}
        </div>
      </div>
      <div style={styles.section}>
        <h2>Model testing</h2>
        <div style={{ display: "flex", gap: 16, alignItems: "center" }}>
          <div>
            <label>Test model:</label>
            <select
              style={styles.sampleSelect}
              value={selectedModelType}
              onChange={(e) => {
                const type = e.target.value as "trained" | "loaded";
                setSelectedModelType(type);
                if (testSamples.length > 0 && selectedTestIndex !== -1) {
                  predict(testSamples[selectedTestIndex].features, type);
                }
              }}
            >
              <option value="trained" disabled={!model}>
                Self-trained model{model ? "" : " (not trained)"}
              </option>
              <option value="loaded" disabled={!loadedModel}>
                External model{loadedModel ? "" : " (not loaded)"}
              </option>
            </select>
          </div>
          <div>
            <label>Test sample:</label>
            <select
              style={styles.sampleSelect}
              value={selectedTestIndex}
              onChange={(e) => {
                const index = Number(e.target.value);
                setSelectedTestIndex(index);
                if (testSamples.length > 0 && index !== -1) {
                  predict(testSamples[index].features, selectedModelType);
                }
              }}
              disabled={!testSamples.length}
            >
              {testSamples.length > 0 ? (
                testSamples.map((sample, index) => (
                  <option key={index} value={index}>
                    Sample #{index + 1} ({sample.label})
                  </option>
                ))
              ) : (
                <option value="-1">Start training first</option>
              )}
            </select>
          </div>
        </div>
        {prediction && (
          <div
            style={styles.predictionResult(
              prediction.label === testSamples[selectedTestIndex].label
            )}
          >
            <div style={{ display: "flex", justifyContent: "space-between" }}>
              <h3>
                Prediction (using the{" "}
                {selectedModelType === "trained" ? "self-trained" : "external"} model)
              </h3>
              <span>{prediction.confidence}% confidence</span>
            </div>
            <div style={styles.comparison}>
              <div>
                <label>Model prediction</label>
                <div style={styles.valueText}>{prediction.label}</div>
              </div>
              <div>
                <label>Actual species</label>
                <div style={styles.valueText}>
                  {testSamples[selectedTestIndex].label}
                </div>
              </div>
            </div>
          </div>
        )}
      </div>
    </div>
  );
}
export default App;