《前端开发者的人工智能第一课》第二节:用React状态管理训练线性回归

49 阅读2分钟

《前端开发者的人工智能第一课》第二节:用React状态管理训练线性回归

🌟 本节重点

  • 将UI状态管理迁移到模型训练
  • 实现可视化梯度下降
  • 用Hooks封装训练流程
  • 理解损失函数与组件更新的相似性

第二节:用React状态管理训练线性回归

2.1 课前准备

# 创建React+TF.js项目
npm create vite@latest react-regression -- --template react
cd react-regression
npm install @tensorflow/tfjs @tensorflow/tfjs-react-vis

2.2 数据可视化:散点图即组件状态

// 模拟数据集生成
const useTrainingData = () => {
  const [data, setData] = useState(() => {
    return Array.from({ length: 100 }, () => ({
      x: Math.random() * 10,
      y: Math.random() * 10 + 2
    }))
  });
  
  // 类似拖拽交互更新数据集
  const addPoint = (e) => {
    const rect = e.target.getBoundingClientRect();
    setData([...data, {
      x: (e.clientX - rect.left) / rect.width * 10,
      y: (e.clientY - rect.top) / rect.height * 10
    }]);
  };

  return [data, addPoint];
};

2.3 模型定义:权重即组件状态

const useLinearModel = () => {
  const [weights, dispatch] = useReducer(
    (state, action) => {
      // 类似Redux的reducer处理参数更新
      switch (action.type) {
        case 'UPDATE_WEIGHTS':
          return { ...state, ...action.payload };
        case 'RESET':
          return { w: Math.random(), b: Math.random() };
        default:
          return state;
      }
    },
    { w: Math.random(), b: Math.random() }
  );

  // 前向传播类似表单校验
  const predict = useCallback((x) => {
    return weights.w * x + weights.b;
  }, [weights]);

  return { weights, predict, dispatch };
};

2.4 训练循环:useEffect替代Epoch迭代

const useModelTraining = ({ data, weights, predict }) => {
  const [lossHistory, setLossHistory] = useState([]);
  const lr = 0.01; // 类似动画的duration参数

  useEffect(() => {
    if (data.length < 2) return;
    
    // 类似requestAnimationFrame的训练循环
    const timer = setInterval(() => {
      // 损失计算:均方误差
      const loss = data.reduce((sum, {x, y}) => {
        return sum + Math.pow(predict(x) - y, 2);
      }, 0) / data.length;

      // 梯度计算(手动推导版)
      const dw = data.reduce((sum, {x, y}) => 
        sum + 2 * (predict(x) - y) * x, 0) / data.length;
      const db = data.reduce((sum, {x, y}) => 
        sum + 2 * (predict(x) - y), 0) / data.length;

      // 参数更新:类似setState批量更新
      setLossHistory(h => [...h, loss]);
      dispatch({ 
        type: 'UPDATE_WEIGHTS',
        payload: {
          w: weights.w - lr * dw,
          b: weights.b - lr * db
        }
      });
    }, 100); // 类似动画帧率控制

    return () => clearInterval(timer);
  }, [data, weights]);

  return { lossHistory };
};

2.5 实时可视化:损失曲线与回归线

// 回归线绘制组件
const RegressionLine = ({ w, b }) => {
  const path = useMemo(() => {
    return `M 0 ${b} L 100 ${w*10 + b}`;
  }, [w, b]);

  return <path d={path} stroke="#ff4757" strokeWidth="2" />;
};

// 使用tfjs-vis显示损失曲线
const LossChart = ({ lossHistory }) => {
  const containerRef = useRef();
  
  useEffect(() => {
    if (!containerRef.current) return;
    tfvis.render.linechart(
      containerRef.current,
      { values: [lossHistory], series: ['loss'] },
      { xLabel: 'Epoch', yLabel: 'Loss' }
    );
  }, [lossHistory]);

  return <div ref={containerRef} style={{ width: '400px', height: '300px' }} />;
};

📚 课后实践

  1. 尝试修改学习率观察训练震荡现象
  2. 给数据添加噪声后观察模型表现
  3. 实现动量优化器(类似动画缓动函数)

💡 下节预告

《用Redux管理卷积神经网络训练》

  • Action分派训练进度
  • Store存储模型参数
  • Middleware处理异步训练任务