《前端开发者的人工智能第一课》第二节:用React状态管理训练线性回归
🌟 本节重点
- 将UI状态管理迁移到模型训练
- 实现可视化梯度下降
- 用Hooks封装训练流程
- 理解损失函数与组件更新的相似性
第二节:用React状态管理训练线性回归
2.1 课前准备
# 创建React+TF.js项目
npm create vite@latest react-regression -- --template react
cd react-regression
npm install @tensorflow/tfjs @tensorflow/tfjs-react-vis
2.2 数据可视化:散点图即组件状态
// 模拟数据集生成
const useTrainingData = () => {
const [data, setData] = useState(() => {
return Array.from({ length: 100 }, () => ({
x: Math.random() * 10,
y: Math.random() * 10 + 2
}))
});
// 类似拖拽交互更新数据集
const addPoint = (e) => {
const rect = e.target.getBoundingClientRect();
setData([...data, {
x: (e.clientX - rect.left) / rect.width * 10,
y: (e.clientY - rect.top) / rect.height * 10
}]);
};
return [data, addPoint];
};
2.3 模型定义:权重即组件状态
const useLinearModel = () => {
const [weights, dispatch] = useReducer(
(state, action) => {
// 类似Redux的reducer处理参数更新
switch (action.type) {
case 'UPDATE_WEIGHTS':
return { ...state, ...action.payload };
case 'RESET':
return { w: Math.random(), b: Math.random() };
default:
return state;
}
},
{ w: Math.random(), b: Math.random() }
);
// 前向传播类似表单校验
const predict = useCallback((x) => {
return weights.w * x + weights.b;
}, [weights]);
return { weights, predict, dispatch };
};
2.4 训练循环:useEffect替代Epoch迭代
const useModelTraining = ({ data, weights, predict }) => {
const [lossHistory, setLossHistory] = useState([]);
const lr = 0.01; // 类似动画的duration参数
useEffect(() => {
if (data.length < 2) return;
// 类似requestAnimationFrame的训练循环
const timer = setInterval(() => {
// 损失计算:均方误差
const loss = data.reduce((sum, {x, y}) => {
return sum + Math.pow(predict(x) - y, 2);
}, 0) / data.length;
// 梯度计算(手动推导版)
const dw = data.reduce((sum, {x, y}) =>
sum + 2 * (predict(x) - y) * x, 0) / data.length;
const db = data.reduce((sum, {x, y}) =>
sum + 2 * (predict(x) - y), 0) / data.length;
// 参数更新:类似setState批量更新
setLossHistory(h => [...h, loss]);
dispatch({
type: 'UPDATE_WEIGHTS',
payload: {
w: weights.w - lr * dw,
b: weights.b - lr * db
}
});
}, 100); // 类似动画帧率控制
return () => clearInterval(timer);
}, [data, weights]);
return { lossHistory };
};
2.5 实时可视化:损失曲线与回归线
// 回归线绘制组件
const RegressionLine = ({ w, b }) => {
const path = useMemo(() => {
return `M 0 ${b} L 100 ${w*10 + b}`;
}, [w, b]);
return <path d={path} stroke="#ff4757" strokeWidth="2" />;
};
// 使用tfjs-vis显示损失曲线
const LossChart = ({ lossHistory }) => {
const containerRef = useRef();
useEffect(() => {
if (!containerRef.current) return;
tfvis.render.linechart(
containerRef.current,
{ values: [lossHistory], series: ['loss'] },
{ xLabel: 'Epoch', yLabel: 'Loss' }
);
}, [lossHistory]);
return <div ref={containerRef} style={{ width: '400px', height: '300px' }} />;
};
📚 课后实践
- 尝试修改学习率观察训练震荡现象
- 给数据添加噪声后观察模型表现
- 实现动量优化器(类似动画缓动函数)
💡 下节预告
《用Redux管理卷积神经网络训练》
- Action分派训练进度
- Store存储模型参数
- Middleware处理异步训练任务