微信小程序端上如何直接运行小模型

265 阅读5分钟

一、引言

作为前端开发,老是会被别人嘲讽为“抠图仔”。我时常想,前端平时除了写界面还能做啥。于是我决定试试使用端上算法开发一款证件照生成应用。由于官方的文档(传送门)对于没有算法经验的前端来说看起来还是很吃力,代码仓库示例也无法直接运行,我决定将自己的开发心得以及踩过的坑跟大家分享一下。

二、产品效果展示

原图

image.png

生成效果

7ae58956a42c35e7e613dfb45871f588.jpg

三、技术背景

作为一款证件照自动生成工具,需要解决两个问题:

  1. 找到图片中人脸的位置,并将人脸裁剪出来。
  2. 将裁剪后的图片中的人像抠出来。

其中第一点,微信小程序有自带人脸检测的API。第二点需要一款抠图模型,这里我们采用开源的modnet模型

四、实战演练

4.1 环境准备

本项目采用微信官方wepy框架开发,版本如下:

  1. node: 14
  2. wepy: 2.1.0

4.2核心功能实现

4.2.1 获取人脸区域

获取人脸区域没什么好讲,直接调用官方api就行。直接上代码!

/**
 * 获取人脸区域
 * @param {*} imageInfo 图片信息
 * @param {number} timeout 超时时间,毫秒
 * @returns { x: number, y: number, width: number, height: number } 人脸区域
 */
export function getFaceOrigin(imageInfo, timeout = 2000) {
  const { width, height, imgData } = imageInfo
  return new Promise((resolve, reject) => {
    const session = wx.createVKSession({
      track: {
        plane: {
          mode: 1
        },
        face: { mode: 2 } // mode: 1 - 使用摄像头;2 - 手动传入图像
      },
      version: 'v1'
    })

    session.on('updateAnchors', (anchors) => {
      session.stop()
      resolve(anchors[0])
    })

    session.start((errno) => {
      if (errno) {
        // 如果失败,将返回 errno
        reject(errno)
      } else {
        // 否则,返回null,表示成功
        session.detectFace({
          scoreThreshold: 0.5, // 评分阈值
          sourceType: 1,
          modelModel: 1,
          frameBuffer: imgData.buffer,
          width: width,
          height: height
        })
      }
    })

    setTimeout(() => {
      session.stop()
      reject(new Error('timeout'))
    }, timeout)
  })
}
4.2.2 抠图

这一步是关键点,下面我将详细讲解步骤

4.2.2.1 加载模型文件

由于小程序代码包大小有限制,模型文件只能从云端下载。推荐将模型文件上传到阿里魔搭平台,上传后能得到下载链接,这里有个坑,由于魔搭下载链接会重定向,所以在配置小程序服务器域名时需要添加多个魔搭域名:

  1. cdn-lfs-cn-1.modelscope.cn
  2. modelscope.cn
4.2.2.2 模型输入参数预处理

模型一般无法直接处理图片数据,需要我们预处理一下。具体怎么处理,需要看onnx模型的定义,这里我们可以通过在线网站netron查看onnx模型文件定义。步骤如下:

1.打开网站上传模型 image.png

2.查看模型输入格式

image.png

  • input代表输入参数名.
  • tensor代表输入参数格式,这里的格式是float32[batch_size,3,height,width]
  • batch_size代表批处理图片数量,通常我们一次只处理一张图片。
  • 3代表图片通道数,我们通过canvas api获取到ImageData是Uint8ClampedArray数组,数据示例[r1,g1,b1,a1,r2,g2,b2,a2...],包含RGBA四个通道,通常视觉算法需要去掉a通道,只保留rgb三个通道。
  • height、width分别代表图片高宽,有的模型会指定图片高宽,未指定的我们可以直接使用原图的高宽。
  • 最后我们将图片数据转为float32格式,这个在算法领域有个专门的名词,叫归一化(Normalization)。实现起来很简单,rgb颜色取值范围是[0, 255],归一化就是将[0, 255]转化为[0, 1]。以下是代码示例:
/**
 * 将图片数据归一化处理
 * @param {Uint8ClampedArray} data 图片数据
 * @param {number} width 图片宽度
 * @param {number} height 图片高度
 * @returns {Float32Array} 转换后的数据
 */
export function normalization(data, width, height) {
  // 3 代表通道数
  const dstInput = new Float32Array(3 * width * height)

  // 遍历每个像素点
  for (let i = 0; i < width * height; i++) {
    const r = data[i * 4] / 255 // Red通道归一化
    const g = data[i * 4 + 1] / 255 // Green通道归一化
    const b = data[i * 4 + 2] / 255 // Blue通道归一化

    // 这里我们只需要RGB三个通道数据
    dstInput[i] = r // Red channel (C=0)
    dstInput[i + width * height] = g // Green channel (C=1)
    dstInput[i + 2 * width * height] = b // Blue channel (C=2)
  }
  return dstInput
}
4.2.2.3 模型推理

这一步直接调用小程序推理AI,直接上代码!

/**
 * 模型推理
 * @param {*} normalizeData  归一化数据
 * @returns {Promise<Float32Array>} 推理后的数据
 */
export async function process(normalizeData) {
  return new Promise(async (resolve, reject) => {
    const session = wx.createInferenceSession({
      model: '模型本地地址',
      precisionLevel: 4,
      allowNPU: false, // 是否使用 NPU 推理,仅针对 IOS 有效
      allowQuantize: false,
      // 模型参数定义
      typicalShape: { input: [1, 3, 512, 512] }
    })

    // 监听error事件
    session.onError((error) => {
      console.error(mattingPath + ' load error: ', error)
      return reject(error)
    })

    await new Promise((resolve) => {
      session.onLoad(() => {
        console.log(mattingPath + ' model session load')
        resolve()
      })
    })

    // 运行模型
    const [runErr, tensor] = await to(
      session.run({
        input: {
          shape: [1, 3, 512, 512],
          data: normalizeData.buffer,
          type: 'float32'
        }
      })
    )

    if (runErr) {
      console.error('run model error:', runErr)
      return reject(runErr)
    }

    // 释放模型
    session.destroy()

    return resolve(new Float32Array(tensor.output.data))
  })
}
4.2.2.4 模型输出后处理

模型输出的数据是黑白的掩码图片数据

image.png

我们需要遍历掩码图片的像素点,如果是像素点是白色就需要将原图对应的像素点提取出来,否则就填充透明像素点,具体代码如下:

/**
 * 模型图片进行抠图
 * @param {Uint8ClampedArray} maskData  模型数据
 * @param {Uint8ClampedArray} originalData  原始图片数据
 * @returns { Uint8ClampedArray } 抠图后的图片数据
 */
export function getTargetImageData(maskData, originalData) {
  const resultData = new Uint8ClampedArray(maskData.byteLength)

  for (let i = 0; i < maskData.length; i += 4) {
    const red = maskData[i]
    const green = maskData[i + 1]
    const blue = maskData[i + 2]

    // 判断掩码图片像素点是否是黑色
    if (red !== 0 && green !== 0 && blue !== 0) {
      resultData[i] = originalData[i]
      resultData[i + 1] = originalData[i + 1]
      resultData[i + 2] = originalData[i + 2]
      resultData[i + 3] = originalData[i + 3]
    } else {
      resultData[i + 3] = 0
    }
  }
  return resultData
}

至此,大功告成了!!!

五、项目地址

  1. 本项目代码已经开源,github传送门
  2. 微信预览
image.png

原创内容,禁止转载。