Just Two Minutes, and Front-End Developers Can Get Started with TensorFlow Neural Networks?


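Preface

There are already plenty of introductions to neural networks online, so this article won't repeat them. It focuses on the concrete steps and one simple worked example. The example uses vite + react, with @tensorflow/tfjs as the neural network library. The complete source for the final page is in the appendix.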

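Steps

Let's first walk through the steps of training a neural network:

  1. Prepare the dataset

  2. Define the model

    1. Number of layers
    2. Number of nodes per layer
    3. Activation function for each layer

  3. Compile the model

    1. Learning rate
    2. Loss function

  4. Train the model

  5. Use the model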

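Now let's build a neural network step by step. The example in this article is iris classification.

Iris classification is a classic machine learning example. The iris dataset records four measurements for each flower: the length and width of the sepal and the length and width of the petal.

The dataset contains 150 records. Each record describes one iris flower, and every flower belongs to one of three species: Setosa, Versicolor, and Virginica. Sample rows look like this: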

5.1,3.3,1.7,0.5,setosa

5.7,2.8,4.1,1.3,versicolor

6.3,3.3,6.0,2.5,virginica

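We need to train a model that can tell the species of a flower from the length and width of its sepal and petal.

Prepare the dataset

  1. Read the txt file, parse it into an array of objects, and shuffle it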
const parseIrisData = (data: string): IrisData[] => {
  return data
    .split(/\r?\n/)
    .filter((line) => line.trim())
    .map((line) => {
      const [sepalL, sepalW, petalL, petalW, species] = line.split(",");
      return {
        features: [sepalL, sepalW, petalL, petalW].map(Number),
        label: species.trim() as IrisSpecies,
      };
    });
};

const response = await fetch("src/assets/iris.txt");
const data = await response.text();
const parsedData = parseIrisData(data); // an array of 150 objects shaped like { features, label }
tf.util.shuffle(parsedData); // shuffle in place

setTestSamples(parsedData.slice(140)); // keep the last 10 rows as the test set
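
As a quick sanity check, the sketch below runs the same parsing logic on the three sample rows shown earlier. The types are repeated so the snippet runs on its own, and the names `sampleText` and `sampleRows` are made up for illustration.

```typescript
type IrisSpecies = "setosa" | "versicolor" | "virginica";

interface IrisData {
  features: number[];
  label: IrisSpecies;
}

// Same parsing logic as above, repeated so the snippet is self-contained.
const parseIrisData = (data: string): IrisData[] =>
  data
    .split(/\r?\n/)
    .filter((line) => line.trim())
    .map((line) => {
      const [sepalL, sepalW, petalL, petalW, species] = line.split(",");
      return {
        features: [sepalL, sepalW, petalL, petalW].map(Number),
        label: species.trim() as IrisSpecies,
      };
    });

// The three sample rows, plus a blank line to show that empty lines are skipped.
const sampleText = [
  "5.1,3.3,1.7,0.5,setosa",
  "",
  "5.7,2.8,4.1,1.3,versicolor",
  "6.3,3.3,6.0,2.5,virginica",
].join("\n");

const sampleRows = parseIrisData(sampleText);
console.log(sampleRows.length); // 3
console.log(sampleRows[0]); // { features: [ 5.1, 3.3, 1.7, 0.5 ], label: 'setosa' }
```

Note that the parser trusts its input: a malformed row would produce NaN features or an unknown label, so a real project would probably validate each row.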
  2. Convert the data into two tensors (features and labels) and one-hot encode the labels

What is one-hot encoding? Suppose our vocabulary has 5 words: ['I', 'you', 'he', 'she', 'it']. Then 'I' can be represented as [1, 0, 0, 0, 0], 'you' as [0, 1, 0, 0, 0], and so on.
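
Applied to the three iris species, the same idea can be sketched in plain TypeScript without any tensors. The names `SPECIES_LABELS` and `oneHotEncode` are hypothetical and exist only for this illustration; the article's own code relies on `tf.oneHot` instead.

```typescript
const SPECIES_LABELS = ["setosa", "versicolor", "virginica"] as const;
type SpeciesLabel = (typeof SPECIES_LABELS)[number];

// Build a vector of zeros with a single 1 at the position of the given label.
const oneHotEncode = (label: SpeciesLabel): number[] =>
  SPECIES_LABELS.map((candidate) => (candidate === label ? 1 : 0));

console.log(oneHotEncode("setosa")); // [ 1, 0, 0 ]
console.log(oneHotEncode("versicolor")); // [ 0, 1, 0 ]
console.log(oneHotEncode("virginica")); // [ 0, 0, 1 ]
```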

const SPECIES: IrisSpecies[] = ["setosa", "versicolor", "virginica"];
const NUM_CLASSES = SPECIES.length;

const convertToTensor = (data: IrisData[]) => ({
  xs: tf.tensor2d(data.map((d) => d.features)),
  ys: tf.oneHot(
    tf.tensor1d(
      data.map((d) => SPECIES.indexOf(d.label)),
      "int32"
    ),
    NUM_CLASSES
  ),
});

const train = convertToTensor(parsedData.slice(0, 140)); // use the first 140 rows for training and validation

Define the model

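const createModel = () => {
  // tf.sequential creates a sequential model, where each layer's output is the next layer's input
  return tf.sequential({
    // Any number of layers is possible, but one or two are usually enough; this example uses 4 for illustration
    layers: [
      tf.layers.dense({
        units: 6, // number of neurons; a rough rule of thumb is 1 to 3 times the number of inputs
        activation: "relu", // activation function; any nonlinear function works, and relu is the usual default
        inputShape: [4], // the first layer must declare the number of input features, which is 4
      }),
      tf.layers.dense({
        units: MODEL_CONFIG.hiddenUnits,
        activation: "relu",
      }),
      tf.layers.dense({
        units: MODEL_CONFIG.hiddenUnits,
        activation: "relu",
      }),
      // Output layer
      tf.layers.dense({
        units: 3, // in a classifier the output layer has one unit per class; 3 species means 3 units
        activation: "softmax", // in a classifier the final activation should be softmax, which yields class probabilities
      }),
    ],
  });
};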
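
To make the two activation functions less abstract, here is a minimal plain TypeScript sketch of what relu and softmax compute on ordinary number arrays. It is an illustration under simple assumptions, not how @tensorflow/tfjs implements them internally, and the input scores are arbitrary.

```typescript
// ReLU keeps positive values and clamps negative values to zero.
const relu = (x: number): number => Math.max(0, x);

// Softmax turns raw scores into probabilities that sum to 1.
// Subtracting the largest score first is a common trick to keep Math.exp from overflowing.
const softmax = (scores: number[]): number[] => {
  const maxScore = Math.max(...scores);
  const exps = scores.map((s) => Math.exp(s - maxScore));
  const total = exps.reduce((sum, e) => sum + e, 0);
  return exps.map((e) => e / total);
};

console.log([-2, 0, 3].map((x) => relu(x))); // [ 0, 0, 3 ]

const classProbabilities = softmax([2, 1, 0.1]);
console.log(classProbabilities.map((p) => p.toFixed(3))); // [ '0.659', '0.242', '0.099' ]
```

The largest score always receives the largest probability, which is why the output layer can be read as "how likely is each species".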

Compile the model

const model = createModel();

model.compile({
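  optimizer: tf.train.adam(0.01), // learning rate; start at 0.01 and tune: multiply by 10 if learning is too slow, divide by 10 if the steps are too large and results get worse
  loss: "categoricalCrossentropy", // loss function for multi-class classification
  metrics: ["accuracy"], // metric to report; accuracy is the fraction of correct predictions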
});
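
For intuition about the loss, the sketch below computes categorical cross-entropy for a single sample in plain TypeScript. The function name and the `epsilon` guard value are assumptions made for this illustration, not the library's internals, and the probabilities are arbitrary inputs.

```typescript
// Categorical cross-entropy for one sample: -sum(y_i * log(p_i)).
// With a one-hot target this reduces to -log(probability assigned to the true class).
const crossEntropyForSample = (oneHotTarget: number[], predicted: number[]): number => {
  const epsilon = 1e-7; // guard against log(0); the exact value is an assumption
  return -oneHotTarget.reduce(
    (sum, y, i) => sum + y * Math.log(Math.max(predicted[i], epsilon)),
    0
  );
};

const versicolorTarget = [0, 1, 0]; // one-hot label for "versicolor"

// A confident, correct prediction gives a small loss.
console.log(crossEntropyForSample(versicolorTarget, [0.1, 0.8, 0.1]).toFixed(4)); // 0.2231

// A confident, wrong prediction gives a much larger loss.
console.log(crossEntropyForSample(versicolorTarget, [0.7, 0.2, 0.1]).toFixed(4)); // 1.6094
```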

Train the model

await model.fit(train.xs, train.ys, {
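  epochs: 100, // start with 100 epochs; add more if that is not enough, cut back if it is more than enough
  batchSize: 32, // no need to think about this one; 32 is a fine default
  validationSplit: 0.2, // fraction of the data held out as the validation set; 0.2 means 20%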
  callbacks: {
    onEpochEnd: (epoch, logs) => {
      if (!logs) return;
      // record a log entry at the end of every epoch
      setTrainingLogs((prev) => [
        ...prev,
        {
          epoch: epoch + 1,
          loss: Number(logs.loss?.toFixed(4)) || 0,
          accuracy: Number(logs.acc?.toFixed(4)) || 0,
        },
      ]);
    },
  },
});
setModel(model); // keep the trained model for later use
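
It can help to see what those fit() settings mean in row counts. The sketch below is plain arithmetic with a hypothetical helper, `describeSplit`; it assumes the training share is rounded down to a whole number of rows, and it does not call the library.

```typescript
// Hypothetical helper: how many rows go to training and validation,
// and how many batches one epoch contains.
const describeSplit = (
  totalSamples: number,
  validationSplit: number,
  batchSize: number
) => {
  // Assumption: the training share is rounded down to a whole number of rows.
  const trainingSamples = Math.floor(totalSamples * (1 - validationSplit));
  const validationSamples = totalSamples - trainingSamples;
  // The last batch of an epoch may be smaller than batchSize, hence the ceil.
  const batchesPerEpoch = Math.ceil(trainingSamples / batchSize);
  return { trainingSamples, validationSamples, batchesPerEpoch };
};

console.log(describeSplit(140, 0.2, 32));
// { trainingSamples: 112, validationSamples: 28, batchesPerEpoch: 4 }
```

So with the 140 rows passed to fit(), roughly 112 are used for weight updates and 28 only for validation metrics. As far as I understand the tfjs documentation, the validation rows are taken from the end of the data, which is one more reason to shuffle before converting to tensors.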

// Export
const exportModel = async () => {
    if (!model) return;
    await model.save("downloads://iris-model");
};

// Import
const model = await tf.loadLayersModel("src/assets/iris-model.json");

Use the model

  const predict = (features: number[]) => {
    const currentModel = model;
    if (!currentModel) return;

    const input = tf.tensor2d([features]);
    const prediction = currentModel.predict(input) as tf.Tensor; // this is a classifier, so the result is a tensor of per-class probabilities
    const probabilities = Array.from(prediction.dataSync()); // convert the tensor into a plain array
    const predictedIndex = prediction.argMax(1).dataSync()[0]; // index of the label with the highest probability

    // Show the result
    setPrediction({
      label: SPECIES[predictedIndex],
      confidence: Math.round(probabilities[predictedIndex] * 100),
    });
  };
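
The post-processing step can also be tried without a model. The sketch below takes a hand-written probability array (an arbitrary input, not model output) and picks the label and confidence the same way; `SPECIES_NAMES` and `interpretPrediction` are names invented for this example.

```typescript
const SPECIES_NAMES = ["setosa", "versicolor", "virginica"];

// Pick the index of the largest probability and report it as a whole percentage.
const interpretPrediction = (probabilities: number[]) => {
  const predictedIndex = probabilities.indexOf(Math.max(...probabilities));
  return {
    label: SPECIES_NAMES[predictedIndex],
    confidence: Math.round(probabilities[predictedIndex] * 100),
  };
};

console.log(interpretPrediction([0.02, 0.91, 0.07]));
// { label: 'versicolor', confidence: 91 }
```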

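With that, you have trained a complete neural network.

Why bother

Looking back at the process, it is not as hard as you might imagine. The toolkit takes care of most of the complicated parts for you.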

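Many of the choices, such as how many layers, which activation function, how many nodes, and how many epochs, are largely empirical and have no fixed standard. Pick rough values, run training a couple of times, and adjust based on the loss and accuracy you observe. The goal is to get the best model for the least compute and time.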
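
One way to make "watch the loss and adjust" a little more systematic is to check when the loss stops improving. The sketch below is a hypothetical helper over log entries shaped like the ones collected during training; `findPlateauEpoch`, `patience`, and `minDelta` are names and defaults chosen for this illustration, and the log values are made up, not results from this article.

```typescript
interface EpochLog {
  epoch: number;
  loss: number;
  accuracy: number;
}

// Once the loss has failed to improve by at least `minDelta` for `patience`
// consecutive epochs, return the epoch that had the best loss so far.
// Returns undefined if the loss never plateaus within the given logs.
const findPlateauEpoch = (
  logs: EpochLog[],
  patience = 3,
  minDelta = 0.001
): number | undefined => {
  let bestLoss = Infinity;
  let bestEpoch: number | undefined;
  let staleEpochs = 0;
  for (const log of logs) {
    if (log.loss < bestLoss - minDelta) {
      bestLoss = log.loss;
      bestEpoch = log.epoch;
      staleEpochs = 0;
    } else {
      staleEpochs += 1;
      if (staleEpochs >= patience) return bestEpoch;
    }
  }
  return undefined;
};

// Made-up log values for illustration only.
const exampleLogs: EpochLog[] = [
  { epoch: 1, loss: 1.0, accuracy: 0.4 },
  { epoch: 2, loss: 0.6, accuracy: 0.7 },
  { epoch: 3, loss: 0.3, accuracy: 0.9 },
  { epoch: 4, loss: 0.3, accuracy: 0.9 },
  { epoch: 5, loss: 0.31, accuracy: 0.9 },
  { epoch: 6, loss: 0.3, accuracy: 0.9 },
];

console.log(findPlateauEpoch(exampleLogs)); // 3
```

If the plateau arrives long before the last epoch, the epoch count can probably be reduced; if it never arrives, more epochs or a different learning rate may be worth trying.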

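At this point someone will ask: imoo, imoo, as a front-end developer, what is the point of learning all this?

I don't know. I hope that one day I can build something impressive too. Maybe no step along the way will be easy, but you have to explore, and you have to take your shot.

Appendix: complete source code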

import { useState } from "react";
import * as tf from "@tensorflow/tfjs";

interface IrisData {
  features: number[];
  label: IrisSpecies;
}

type IrisSpecies = "setosa" | "versicolor" | "virginica";

// Label constants for the three species
const SPECIES: IrisSpecies[] = ["setosa", "versicolor", "virginica"];
const NUM_CLASSES = SPECIES.length;

// Model configuration constants
const MODEL_CONFIG = {
  hiddenUnits: 6,
  learningRate: 0.05,
  epochs: 100,
  batchSize: 32,
  validationSplit: 0.2,
};

// Parse the raw text into an array of { features, label } objects
const parseIrisData = (data: string): IrisData[] => {
  return data
    .split(/\r?\n/)
    .filter((line) => line.trim())
    .map((line) => {
      const [sepalL, sepalW, petalL, petalW, species] = line.split(",");
      return {
        features: [sepalL, sepalW, petalL, petalW].map(Number),
        label: species.trim() as IrisSpecies,
      };
    });
};

// Convert parsed records into a feature tensor and a one-hot label tensor
const convertToTensor = (data: IrisData[]) => ({
  xs: tf.tensor2d(data.map((d) => d.features)),
  ys: tf.oneHot(
    tf.tensor1d(
      data.map((d) => SPECIES.indexOf(d.label)),
      "int32"
    ),
    NUM_CLASSES
  ),
});

function App() {
  const [model, setModel] = useState<tf.Sequential>();
  const [prediction, setPrediction] = useState<{
    label: string;
    confidence: number;
  }>();
  const [trainingLogs, setTrainingLogs] = useState<
    { epoch: number; loss: number; accuracy: number }[]
  >([]);
  const [testSamples, setTestSamples] = useState<IrisData[]>([]);
  const [selectedTestIndex, setSelectedTestIndex] = useState<number>(-1);
  const [loadedModel, setLoadedModel] = useState<tf.LayersModel>();
  const [isLoadingModel, setIsLoadingModel] = useState(false);
  const [loadError, setLoadError] = useState<string>();
  const [selectedModelType, setSelectedModelType] = useState<
    "trained" | "loaded"
  >("trained");

  // Inline style constants
  const styles = {
    container: {
      maxWidth: 1200,
      margin: "0 auto",
      padding: 40,
      fontFamily: "'Segoe UI', sans-serif",
    },
    title: {
      color: "#2c3e50",
      textAlign: "center",
      marginBottom: 40,
    },
    controlSection: {
      display: "flex",
      gap: 16,
      alignItems: "center",
      marginBottom: 40,
    },
    trainButton: {
      backgroundColor: "#3498db",
      color: "white",
      border: "none",
      padding: "12px 24px",
      borderRadius: 4,
      cursor: "pointer",
    },
    statusIndicator: {
      width: 10,
      height: 10,
      backgroundColor: "#27ae60",
      borderRadius: "50%",
    },
    section: {
      backgroundColor: "#f8f9fa",
      padding: 24,
      borderRadius: 8,
      marginBottom: 40,
      boxShadow: "0 2px 4px rgba(0,0,0,0.1)",
    },
    logItem: {
      display: "flex",
      gap: 16,
    },
    sampleSelect: {
      padding: 8,
      border: "1px solid #ddd",
      borderRadius: 4,
      minWidth: 250,
      marginTop: 16,
    },
    predictionResult: (isCorrect: boolean) => ({
      marginTop: 16,
      padding: 16,
      borderRadius: 4,
      backgroundColor: "#fff",
      borderLeft: `4px solid ${isCorrect ? "#27ae60" : "#e74c3c"}`,
    }),
    comparison: {
      display: "grid",
      gridTemplateColumns: "1fr 1fr",
      gap: 16,
      marginTop: 16,
    },
    valueText: {
      fontWeight: "bold",
      color: "#2c3e50",
      fontSize: "1.1rem",
    },
    loadButton: {
      backgroundColor: "#9b59b6",
      color: "white",
      border: "none",
      padding: "12px 24px",
      borderRadius: 4,
      cursor: "pointer",
    },
    errorText: {
      color: "#e74c3c",
      marginLeft: 8,
    },
  };

  // Fetch, parse, and shuffle the dataset; keep the last 10 rows as test samples
  const loadAndPrepareData = async () => {
    const response = await fetch("src/assets/iris.txt");
    const data = await response.text();
    const parsedData = parseIrisData(data);

    tf.util.shuffle(parsedData);
    if (testSamples.length === 0) {
      setTestSamples(parsedData.slice(140));
      setSelectedTestIndex(0);
    }

    return parsedData;
  };

  // Load the data and convert the first 140 rows into training tensors
  const loadData = async () => {
    const parsedData = await loadAndPrepareData();
    return convertToTensor(parsedData.slice(0, 140));
  };

  // Build the sequential model
  const createModel = () => {
    return tf.sequential({
      layers: [
        tf.layers.dense({
          units: MODEL_CONFIG.hiddenUnits,
          activation: "relu",
          inputShape: [4],
        }),
        tf.layers.dense({
          units: MODEL_CONFIG.hiddenUnits,
          activation: "relu",
        }),
        tf.layers.dense({
          units: MODEL_CONFIG.hiddenUnits,
          activation: "relu",
        }),
        tf.layers.dense({
          units: NUM_CLASSES,
          activation: "softmax",
        }),
      ],
    });
  };

  const trainModel = async () => {
    setTrainingLogs([]);
    const model = createModel();

    model.compile({
      optimizer: tf.train.adam(MODEL_CONFIG.learningRate),
      loss: "categoricalCrossentropy",
      metrics: ["accuracy"],
    });

    const train = await loadData();

    await model.fit(train.xs, train.ys, {
      epochs: MODEL_CONFIG.epochs,
      batchSize: MODEL_CONFIG.batchSize,
      validationSplit: MODEL_CONFIG.validationSplit,
      callbacks: {
        onEpochEnd: (epoch, logs) => {
          if (!logs) return;
          setTrainingLogs((prev) => [
            ...prev,
            {
              epoch: epoch + 1,
              loss: Number(logs.loss?.toFixed(4)) || 0,
              accuracy: Number(logs.acc?.toFixed(4)) || 0, // tfjs reports the accuracy metric under the key "acc"
            },
          ]);
        },
      },
    });

    setModel(model);
  };

  // Load a previously exported model from the assets folder
  const loadExternalModel = async () => {
    setIsLoadingModel(true);
    setLoadError(undefined);

    try {
      const model = await tf.loadLayersModel("src/assets/iris-model.json");
      const parsedData = await loadAndPrepareData();

      if (parsedData.length > 0) {
        predict(parsedData[140].features);
      }

      setLoadedModel(model);
      alert("Model loaded successfully!");
    } catch (error) {
      console.error("Failed to load model:", error);
      setLoadError("Loading failed. Check that the model file exists.");
    }
    setIsLoadingModel(false);
  };

  const predict = (features: number[], modelType = selectedModelType) => {
    const currentModel = modelType === "trained" ? model : loadedModel;
    if (!currentModel) return;

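    const input = tf.tensor2d([features]);
    const prediction = currentModel.predict(input) as tf.Tensor; // this is a classifier, so the result holds one probability per class
    const probabilities = Array.from(prediction.dataSync()); // copy the probabilities into a plain array
    const predictedIndexTensor = prediction.argMax(1);
    const predictedIndex = predictedIndexTensor.dataSync()[0]; // index of the class with the highest probability

    setPrediction({
      label: SPECIES[predictedIndex],
      confidence: Math.round(probabilities[predictedIndex] * 100),
    });

    // Release every tensor created here, not only the input
    tf.dispose([input, prediction, predictedIndexTensor]);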
  };

  // Export the trained model
  const exportModel = async () => {
    if (!model) return;

    // Save the model as downloadable files
    await model.save("downloads://iris-model");
  };

  // Small status indicator component
  const StatusIndicator = ({
    color,
    label,
  }: {
    color: string;
    label: string;
  }) => (
    <div style={{ display: "flex", alignItems: "center", gap: 8 }}>
      <div style={{ ...styles.statusIndicator, backgroundColor: color }} />
      <span>{label}</span>
    </div>
  );

  return (
    <div style={styles.container}>
      <h1>Iris Classifier</h1>
      <div style={styles.controlSection}>
        <button style={styles.trainButton} onClick={trainModel}>
          {model ? "Retrain model" : "Start training"}
        </button>

        {model && (
          <>
            <button
              style={{ ...styles.trainButton, backgroundColor: "#2ecc71" }}
              onClick={exportModel}
            >
              Export model
            </button>
            <StatusIndicator color="#27ae60" label="Model ready" />
          </>
        )}

        <button
          style={styles.loadButton}
          onClick={loadExternalModel}
          disabled={isLoadingModel}
        >
          {isLoadingModel ? "Loading model..." : "Load external model"}
        </button>

        {loadError && <span style={styles.errorText}>{loadError}</span>}
        {loadedModel && (
          <StatusIndicator color="#9b59b6" label="External model loaded" />
        )}
      </div>

      <div style={styles.section}>
        <h2>Training progress</h2>
        <div style={{ maxHeight: "200px", overflowY: "auto" }}>
          {trainingLogs.map((log) => (
            <div key={log.epoch} style={styles.logItem}>
              <div>Epoch {log.epoch}</div>
              <div>Loss: {log.loss}</div>
              <div>Accuracy: {(log.accuracy * 100).toFixed(1)}%</div>
            </div>
          ))}
          {!trainingLogs.length && <div>Training has not started yet</div>}
        </div>
      </div>

      <div style={styles.section}>
        <h2>Model testing</h2>
        <div style={{ display: "flex", gap: 16, alignItems: "center" }}>
          <div>
            <label>Select model to test:</label>
            <select
              style={styles.sampleSelect}
              value={selectedModelType}
              onChange={(e) => {
                const type = e.target.value as "trained" | "loaded";
                setSelectedModelType(type);
                if (testSamples.length > 0 && selectedTestIndex !== -1) {
                  predict(testSamples[selectedTestIndex].features, type);
                }
              }}
            >
              <option value="trained" disabled={!model}>
                Self-trained model{model ? "" : " (not trained)"}
              </option>
              <option value="loaded" disabled={!loadedModel}>
                External model{loadedModel ? "" : " (not loaded)"}
              </option>
            </select>
          </div>
          <div>
            <label>Select test sample:</label>
            <select
              style={styles.sampleSelect}
              value={selectedTestIndex}
              onChange={(e) => {
                const index = Number(e.target.value);
                setSelectedTestIndex(index);
                if (testSamples.length > 0 && index !== -1) {
                  predict(testSamples[index].features, selectedModelType);
                }
              }}
              disabled={!testSamples.length}
            >
              {testSamples.length > 0 ? (
                testSamples.map((sample, index) => (
                  <option key={index} value={index}>
                    Sample #{index + 1} ({sample.label})
                  </option>
                ))
              ) : (
                <option value="-1">Start training first</option>
              )}
            </select>
          </div>
        </div>

        {prediction && (
          <div
            style={styles.predictionResult(
              prediction.label === testSamples[selectedTestIndex].label
            )}
          >
            <div style={{ display: "flex", justifyContent: "space-between" }}>
              <h3>
                Prediction (using the{" "}
                {selectedModelType === "trained" ? "self-trained" : "external"}{" "}
                model)
              </h3>
              <span>{prediction.confidence}% confidence</span>
            </div>
            <div style={styles.comparison}>
              <div>
                <label>Model prediction</label>
                <div style={styles.valueText}>{prediction.label}</div>
              </div>
              <div>
                <label>Actual species</label>
                <div style={styles.valueText}>
                  {testSamples[selectedTestIndex].label}
                </div>
              </div>
            </div>
          </div>
        )}
      </div>
    </div>
  );
}

export default App;