一、引言
作为前端开发,老是会被别人嘲讽为“抠图仔”。我时常想,前端平时除了写界面还能做啥。于是我决定试试使用端上算法开发一款证件照生成应用。由于官方的文档(传送门)对于没有算法经验的前端来说看起来还是很吃力,代码仓库示例也无法直接运行,我决定将自己的开发心得以及踩过的坑跟大家分享一下。
二、产品效果展示
原图
生成效果
三、技术背景
作为一款证件照自动生成工具,需要解决两个问题:
- 找到图片中人脸的位置,并将人脸裁剪出来。
- 将裁剪后的图片中的人像抠出来。
其中第一点,微信小程序有自带人脸检测的API。第二点需要一款抠图模型,这里我们采用开源的modnet模型。
四、实战演练
4.1 环境准备
本项目采用微信官方wepy框架开发,版本如下:
- node: 14
- wepy: 2.1.0
4.2核心功能实现
4.2.1 获取人脸区域
获取人脸区域没什么好讲,直接调用官方api就行。直接上代码!
/**
* 获取人脸区域
* @param {*} imageInfo 图片信息
* @param {number} timeout 超时时间,毫秒
* @returns { x: number, y: number, width: number, height: number } 人脸区域
*/
export function getFaceOrigin(imageInfo, timeout = 2000) {
const { width, height, imgData } = imageInfo
return new Promise((resolve, reject) => {
const session = wx.createVKSession({
track: {
plane: {
mode: 1
},
face: { mode: 2 } // mode: 1 - 使用摄像头;2 - 手动传入图像
},
version: 'v1'
})
session.on('updateAnchors', (anchors) => {
session.stop()
resolve(anchors[0])
})
session.start((errno) => {
if (errno) {
// 如果失败,将返回 errno
reject(errno)
} else {
// 否则,返回null,表示成功
session.detectFace({
scoreThreshold: 0.5, // 评分阈值
sourceType: 1,
modelModel: 1,
frameBuffer: imgData.buffer,
width: width,
height: height
})
}
})
setTimeout(() => {
session.stop()
reject(new Error('timeout'))
}, timeout)
})
}
4.2.2 抠图
这一步是关键点,下面我将详细讲解步骤
4.2.2.1 加载模型文件
由于小程序代码包大小有限制,模型文件只能从云端下载。推荐将模型文件上传到阿里魔搭平台,上传后能得到下载链接,这里有个坑,由于魔搭下载链接会重定向,所以在配置小程序服务器域名时需要添加多个魔搭域名:
4.2.2.2 模型输入参数预处理
模型一般无法直接处理图片数据,需要我们预处理一下。具体怎么处理,需要看onnx模型的定义,这里我们可以通过在线网站netron查看onnx模型文件定义。步骤如下:
1.打开网站上传模型
2.查看模型输入格式
input代表输入参数名.tensor代表输入参数格式,这里的格式是float32[batch_size,3,height,width]batch_size代表批处理图片数量,通常我们一次只处理一张图片。3代表图片通道数,我们通过canvas api获取到ImageData是Uint8ClampedArray数组,数据示例[r1,g1,b1,a1,r2,g2,b2,a2...],包含RGBA四个通道,通常视觉算法需要去掉a通道,只保留rgb三个通道。height、width分别代表图片高宽,有的模型会指定图片高宽,未指定的我们可以直接使用原图的高宽。- 最后我们将图片数据转为
float32格式,这个在算法领域有个专门的名词,叫归一化(Normalization)。实现起来很简单,rgb颜色取值范围是[0, 255],归一化就是将[0, 255]转化为[0, 1]。以下是代码示例:
/**
* 将图片数据归一化处理
* @param {Uint8ClampedArray} data 图片数据
* @param {number} width 图片宽度
* @param {number} height 图片高度
* @returns {Float32Array} 转换后的数据
*/
export function normalization(data, width, height) {
// 3 代表通道数
const dstInput = new Float32Array(3 * width * height)
// 遍历每个像素点
for (let i = 0; i < width * height; i++) {
const r = data[i * 4] / 255 // Red通道归一化
const g = data[i * 4 + 1] / 255 // Green通道归一化
const b = data[i * 4 + 2] / 255 // Blue通道归一化
// 这里我们只需要RGB三个通道数据
dstInput[i] = r // Red channel (C=0)
dstInput[i + width * height] = g // Green channel (C=1)
dstInput[i + 2 * width * height] = b // Blue channel (C=2)
}
return dstInput
}
4.2.2.3 模型推理
这一步直接调用小程序推理AI,直接上代码!
/**
* 模型推理
* @param {*} normalizeData 归一化数据
* @returns {Promise<Float32Array>} 推理后的数据
*/
export async function process(normalizeData) {
return new Promise(async (resolve, reject) => {
const session = wx.createInferenceSession({
model: '模型本地地址',
precisionLevel: 4,
allowNPU: false, // 是否使用 NPU 推理,仅针对 IOS 有效
allowQuantize: false,
// 模型参数定义
typicalShape: { input: [1, 3, 512, 512] }
})
// 监听error事件
session.onError((error) => {
console.error(mattingPath + ' load error: ', error)
return reject(error)
})
await new Promise((resolve) => {
session.onLoad(() => {
console.log(mattingPath + ' model session load')
resolve()
})
})
// 运行模型
const [runErr, tensor] = await to(
session.run({
input: {
shape: [1, 3, 512, 512],
data: normalizeData.buffer,
type: 'float32'
}
})
)
if (runErr) {
console.error('run model error:', runErr)
return reject(runErr)
}
// 释放模型
session.destroy()
return resolve(new Float32Array(tensor.output.data))
})
}
4.2.2.4 模型输出后处理
模型输出的数据是黑白的掩码图片数据
我们需要遍历掩码图片的像素点,如果是像素点是白色就需要将原图对应的像素点提取出来,否则就填充透明像素点,具体代码如下:
/**
* 模型图片进行抠图
* @param {Uint8ClampedArray} maskData 模型数据
* @param {Uint8ClampedArray} originalData 原始图片数据
* @returns { Uint8ClampedArray } 抠图后的图片数据
*/
export function getTargetImageData(maskData, originalData) {
const resultData = new Uint8ClampedArray(maskData.byteLength)
for (let i = 0; i < maskData.length; i += 4) {
const red = maskData[i]
const green = maskData[i + 1]
const blue = maskData[i + 2]
// 判断掩码图片像素点是否是黑色
if (red !== 0 && green !== 0 && blue !== 0) {
resultData[i] = originalData[i]
resultData[i + 1] = originalData[i + 1]
resultData[i + 2] = originalData[i + 2]
resultData[i + 3] = originalData[i + 3]
} else {
resultData[i + 3] = 0
}
}
return resultData
}
至此,大功告成了!!!
五、项目地址
- 本项目代码已经开源,github传送门
- 微信预览
原创内容,禁止转载。