GPT横空出世的时代,人工智能还远嘛

105 阅读3分钟

什么是机器学习?

从翻译应用到自动驾驶汽车,我们利用一些最重要的技术为机器学习提供支持。从根本上来讲,机器学习是对一种软件(称为模型)进行训练的过程,用于进行实用的预测或通过数据生成内容。

机器学习系统根据学习进行预测或生成内容的方式,分为以下一个或多个类别:

  • 监督式学习
  • 非监督式学习
  • 强化学习
  • 生成式 AI

网站机器学习的应用场景

  • 增强现实AR
  • 基于手势或肢体的交互
  • 语音识别
  • 无障碍网站
  • 语义分析
  • 智能会话

浏览器机器学习的优势

  • 保护用户隐私
  • 减少流量
  • 提高用户体验
  • 降低服务器 成本

代码实战

1、前端页面设计

首先,采用BootStrap框架来美化页面的样式

<div class="container" style="width: 50%;margin-left: 8%;">
  <div class="header" style="text-align: center;margin-top: 8%;margin-bottom: 5%">
    <h2>手写数字数据集</h2>
  </div>

  <label class="form-label">请在手写板中输入您想预测的数字:</label><br>
  <div style="text-align: center;margin-top: 2%">
    <canvas width="300" height="300" style="border: 2px solid black"></canvas>
  </div>
  <div style="text-align: center">
    <button onclick="window.clear()" style="margin: 4px" class="btn btn-danger">清除</button>
    <button onclick="window.predict()" style="margin: 4px;" class="btn btn-primary">预测</button>
    <button onclick="downloadModel()" class="btn btn-dark">下载模型</button>
  </div>
  <div class="blockquote-footer" style="margin-top: 30%;text-align: center">
    <span>@made by <a href="https://gitee.com/yao-wen_long">姚文龙</a></span>
  </div>
</div>

1)画布

在页面的HTML代码里面定义

定义完成之后,实现页面中的清除功能按钮的逻辑功能,当用户点击,就可以清理画布里的内容

window.clear = () => {
  const ctx = 画布.getContext("2d");
  ctx.fillStyle = "rgb(0,0,0)"
  ctx.fillRect(0,0,300,300)
}

然后,实现画布的核心功能用户按下鼠标可以在画布上绘制图形

const 画布 = document.querySelector("canvas")
            画布.addEventListener("mousemove",(e)=>{
                if(e.buttons === 1){
                    const ctx = 画布.getContext("2d");
                    ctx.fillStyle = "rgb(255,255,255)"
                    ctx.fillRect(e.offsetX,e.offsetY,10,10)
                }
            })

2、数据预处理

const MNIST图像精灵路径 =
            'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
const MNIST标签路径 =
            'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';

然后,开始划分训练数据集和测试数据集

      const [训练Xs,训练Ys] = tf.tidy(() => {
                const d = 数据.nextTrainBatch(3000);
                return [
                    d.xs.reshape([3000,28,28,1]),
                    d.labels
                ]
            })
            const [测试Xs,测试Ys] = tf.tidy(() => {
                const d = 数据.nextTestBatch(500);
                return [
                    d.xs.reshape([500,28,28,1]),
                    d.labels
                ]
            })

3、自定义模型

借助TensorFlow.js官方提供tf.sequential()方法创建一个顺序模型,并通过tf.sequential.add()方法给模型添加两个卷积层和一个全连接层。(卷积层用于提取特征,例如图像中的边缘和纹理,而连接层用于将这些特征组合起来以进行更高级别的图像识别或分类)

     // 模型结构搭建
            const 模型 = tf.sequential();

            // 添加卷积层
            模型.add(tf.layers.conv2d({
                inputShape:[28,28,1],
                kernelSize:5,
                filters:8,
                strides:1,
                activation:"relu",
                kernelInitializer:"varianceScaling"
            }))
            模型.add(tf.layers.batchNormalization())
            模型.add(tf.layers.maxPool2d({
                poolSize:[2,2],
                strides:[2,2]
            }))

            // 添加更多卷积层和池化层...

            模型.add(tf.layers.flatten())

            // 添加全连接层
            模型.add(tf.layers.dense({
                units:10,
                activation:"softmax",
                kernelInitializer:"varianceScaling"
            }))

网络模型结构

4、配置模型

在训练模型前,还需要进行一些配置

分别指定了训练过程中所用到的损失函数、优化器及评价指标

   模型.compile({
                loss:"categoricalCrossentropy",
                optimizer:tf.train.adam(0.01),
                metrics:"accuracy"
            })

5、训练模型

完成模型训练过程中的参数配置后,就可以将准备好的数据送入模型并开始训练了

// 训练模型
await 模型.fit(训练Xs,训练Ys,{
  validationData:[测试Xs,测试Ys],
  epochs:50,
  callbacks:tfvis.show.fitCallbacks(
    {name:"训练过程",tab:"训练数据"},
    ["loss","val_loss","acc","val_acc"],
    {callbacks:["onEpochEnd"]}
  )
})

6、模型预测

训练好模型后,在画布上绘制0~9的数字并通过模型进行预测

// 预测手写数字
window.predict = () => {
  const 输入 = tf.tidy(()=>{
    return tf.image.resizeBilinear(
      tf.browser.fromPixels(画布),
      [28,28],
      true
    )
      .slice([0,0,0],[28,28,1])
      .toFloat()
      .div(255)
      .reshape([1,28,28,1])
  })

  const 输出 = 模型.predict(输入).argMax(1).dataSync()[0]
  alert(`预测结果为:${输出}`)
}

7、效果展示